// SPDX-License-Identifier: GPL-2.0-only

/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */

#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"

#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
				* needs to be a power of 2 as we use it as
				* a circular buffer
				*/

#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
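
/*
 * NEXT()/PREV() wrap an index around the circular entry buffer. Because
 * HID_BPF_MAX_PROGS is a power of 2, masking with (HID_BPF_MAX_PROGS - 1)
 * behaves like a modulo: with the current size of 1024, NEXT(1023) == 0
 * and PREV(0) == 1023.
 */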

/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;
	struct hid_device *hdev;
	enum hid_bpf_prog_type type;
	u16 idx;
};

struct hid_bpf_jmp_table {
	struct bpf_map *map;
	struct bpf_map *prog_keys;
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};

#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))
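
/*
 * FOR_ENTRIES() walks the circular entries[] buffer from __start up to (but
 * not including) __end. CIRC_CNT() from <linux/circ_buf.h> is used as the
 * loop condition so the walk terminates correctly even when the head/tail
 * indices have wrapped around the buffer.
 */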

static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);		/* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);

BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
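
/*
 * hid_bpf_btf_ids[] is indexed by enum hid_bpf_prog_type (the comment on each
 * BTF_ID() entry names the matching type): hid_bpf_get_prog_attach_type()
 * below relies on this ordering to map a program's attach_btf_id back to a
 * HID-BPF program type, so the two lists must be kept in sync.
 */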

static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
	switch (type) {
	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
		return HID_BPF_MAX_PROGS_PER_DEV;
	default:
		return -EINVAL;
	}
}

static int hid_bpf_program_count(struct hid_device *hdev,
				 struct bpf_prog *prog,
				 enum hid_bpf_prog_type type)
{
	int i, n = 0;

	if (type >= HID_BPF_PROG_TYPE_MAX)
		return -EINVAL;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
			continue;

		if (hdev && entry->hdev != hdev)
			continue;

		if (prog && entry->prog != prog)
			continue;

		n++;
	}

	return n;
}

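/*
 * __hid_bpf_tail_call() is only a stub on the kernel side: the preloaded
 * "hid_tail_call" BPF program is attached to it (see hid_bpf_preload_skel()
 * below) and is expected to perform the actual bpf_tail_call() into the
 * hid_jmp_table prog_array, using ctx->index as the key. ALLOW_ERROR_INJECTION()
 * is what allows such a program to override the return value of this stub.
 */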
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}
ALLOW_ERROR_INJECTION(__hid_bpf_tail_call, ERRNO);

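/*
 * hid_bpf_prog_run() dispatches one event to every enabled program attached
 * to @hdev for the given @type, in list order: it sets ctx.index and lets the
 * tail-call trampoline above jump to the matching slot of the prog_array.
 * A non-zero return from a program stops the chain and is propagated to the
 * caller.
 */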
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
		     struct hid_bpf_ctx_kern *ctx_kern)
{
	struct hid_bpf_prog_list *prog_list;
	int i, idx, err = 0;

	rcu_read_lock();
	prog_list = rcu_dereference(hdev->bpf.progs[type]);

	if (!prog_list)
		goto out_unlock;

	for (i = 0; i < prog_list->prog_cnt; i++) {
		idx = prog_list->prog_idx[i];

		if (!test_bit(idx, jmp_table.enabled))
			continue;

		ctx_kern->ctx.index = idx;
		err = __hid_bpf_tail_call(&ctx_kern->ctx);
		if (err)
			break;
	}

 out_unlock:
	rcu_read_unlock();

	return err;
}

/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
				     enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *old_list;

	spin_lock(&hdev->bpf.progs_lock);
	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
					     lockdep_is_held(&hdev->bpf.progs_lock));
	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
	spin_unlock(&hdev->bpf.progs_lock);
	synchronize_rcu();

	kfree(old_list);
}

/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called with hid_bpf_attach_lock held.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *new_list;
	int i;

	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
		return -EINVAL;

	if (hdev->bpf.destroyed)
		return 0;

	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
	if (!new_list)
		return -ENOMEM;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (entry->type == type && entry->hdev == hdev &&
		    test_bit(entry->idx, jmp_table.enabled))
			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
	}

	__hid_bpf_set_hdev_progs(hdev, new_list, type);

	return 0;
}

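/*
 * Remove one program from the prog_array through the given map fd and drop
 * our shadow pointer for that slot; deleting the element also makes the
 * prog_array release its reference on the BPF program.
 */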
static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
	skel_map_delete_elem(map_fd, &idx);
	jmp_table.progs[idx] = NULL;
}

static void hid_bpf_release_progs(struct work_struct *work)
{
	int i, j, n, map_fd = -1;

	if (!jmp_table.map)
		return;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

	/* detach unused progs from HID devices */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
		enum hid_bpf_prog_type type;
		struct hid_device *hdev;

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		/* we have an attached prog */
		if (entry->hdev) {
			hdev = entry->hdev;
			type = entry->type;

			hid_bpf_populate_hdev(hdev, type);

			/* mark all other disabled progs from hdev of the given type as detached */
			FOR_ENTRIES(j, i, jmp_table.head) {
				struct hid_bpf_prog_entry *next;

				next = &jmp_table.entries[j];

				if (test_bit(next->idx, jmp_table.enabled))
					continue;

				if (next->hdev == hdev && next->type == type)
					next->hdev = NULL;
			}
		}
	}

	/* remove all unused progs from the jump table */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		if (entry->prog)
			__hid_bpf_do_release_prog(map_fd, entry->idx);
	}

	/* compact the entry list */
	n = jmp_table.tail;
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (!test_bit(entry->idx, jmp_table.enabled))
			continue;

		jmp_table.entries[n] = jmp_table.entries[i];
		n = NEXT(n);
	}

	jmp_table.head = n;

	mutex_unlock(&hid_bpf_attach_lock);

	if (map_fd >= 0)
		close_fd(map_fd);
}

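/*
 * Synchronously release a single slot of the jump table, used when attaching
 * a program fails halfway through and we need to undo hid_bpf_insert_prog().
 */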
static void hid_bpf_release_prog_at(int idx)
{
	int map_fd = -1;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	__hid_bpf_do_release_prog(map_fd, idx);

	close_fd(map_fd);
}

/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
	int i, cnt, index = -1, map_fd = -1, progs_map_fd = -1, err = -EINVAL;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	/* take an fd for the table of progs we monitor with SEC("fexit/bpf_prog_release") */
	progs_map_fd = skel_map_get_fd_by_id(jmp_table.prog_keys->id);

	if (map_fd < 0 || progs_map_fd < 0) {
		err = -EINVAL;
		goto out;
	}

	cnt = 0;
	/* find the first available index in the jmp_table
	 * and count how many times this program has been inserted
	 */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (!jmp_table.progs[i] && index < 0) {
			/* mark the index as used */
			jmp_table.progs[i] = prog;
			index = i;
			__set_bit(i, jmp_table.enabled);
			cnt++;
		} else {
			if (jmp_table.progs[i] == prog)
				cnt++;
		}
	}
	if (index < 0) {
		err = -ENOMEM;
		goto out;
	}

	/* insert the program in the jump table */
	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
	if (err)
		goto out;

	/* insert the program in the prog list table */
	err = skel_map_update_elem(progs_map_fd, &prog, &cnt, 0);
	if (err)
		goto out;

	/* return the index */
	err = index;

 out:
	/* only release the slot if we actually claimed one above */
	if (err < 0 && index >= 0)
		__hid_bpf_do_release_prog(map_fd, index);
	if (map_fd >= 0)
		close_fd(map_fd);
	if (progs_map_fd >= 0)
		close_fd(progs_map_fd);
	return err;
}

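/*
 * Map the attach_btf_id of the program behind @prog_fd to a HID-BPF program
 * type, or return HID_BPF_PROG_TYPE_UNDEF if the program does not target one
 * of the functions listed in hid_bpf_btf_ids[].
 */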
int hid_bpf_get_prog_attach_type(int prog_fd)
{
	struct bpf_prog *prog = NULL;
	int i;
	int prog_type = HID_BPF_PROG_TYPE_UNDEF;

	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
			prog_type = i;
			break;
		}
	}

	bpf_prog_put(prog);

	return prog_type;
}

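/*
 * Attach flow, as implemented below: take a reference on the program, refuse
 * the attach if the device already has hid_bpf_max_programs() programs of
 * that type, reserve a slot in the jump table, append the entry to the
 * circular list (or prepend it when HID_BPF_FLAG_INSERT_HEAD is set so it
 * runs first), then rebuild the per-device RCU list so hid_bpf_prog_run()
 * picks it up.
 */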
/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
		      int prog_fd, __u32 flags)
{
	struct bpf_prog *prog = NULL;
	struct hid_bpf_prog_entry *prog_entry;
	int cnt, err = -EINVAL, prog_idx = -1;

	/* take a ref on the prog itself */
	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	mutex_lock(&hid_bpf_attach_lock);

	/* do not attach too many programs to a given HID device */
	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
	if (cnt < 0) {
		err = cnt;
		goto out_unlock;
	}

	if (cnt >= hid_bpf_max_programs(prog_type)) {
		err = -E2BIG;
		goto out_unlock;
	}

	prog_idx = hid_bpf_insert_prog(prog_fd, prog);
	/* if the jmp table is full, abort */
	if (prog_idx < 0) {
		err = prog_idx;
		goto out_unlock;
	}

	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
		/* take the previous prog_entry slot */
		jmp_table.tail = PREV(jmp_table.tail);
		prog_entry = &jmp_table.entries[jmp_table.tail];
	} else {
		/* take the next prog_entry slot */
		prog_entry = &jmp_table.entries[jmp_table.head];
		jmp_table.head = NEXT(jmp_table.head);
	}

	/* we steal the ref here */
	prog_entry->prog = prog;
	prog_entry->idx = prog_idx;
	prog_entry->hdev = hdev;
	prog_entry->type = prog_type;

	/* finally store the index in the device list */
	err = hid_bpf_populate_hdev(hdev, prog_type);
	if (err)
		hid_bpf_release_prog_at(prog_idx);

 out_unlock:
	mutex_unlock(&hid_bpf_attach_lock);

	/* we only use prog as a key in the various tables, so we don't need to actually
	 * increment the ref count.
	 */
	bpf_prog_put(prog);

	return err;
}

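/*
 * Called when the HID device goes away: mark every program attached to it as
 * disabled, clear the per-type RCU lists, and let the release worker clean up
 * the jump table entries asynchronously.
 */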
void __hid_bpf_destroy_device(struct hid_device *hdev)
{
	int type, i;
	struct hid_bpf_prog_list *prog_list;

	rcu_read_lock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
		prog_list = rcu_dereference(hdev->bpf.progs[type]);

		if (!prog_list)
			continue;

		for (i = 0; i < prog_list->prog_cnt; i++)
			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
	}

	rcu_read_unlock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
		__hid_bpf_set_hdev_progs(hdev, NULL, type);

	/* schedule release of all detached progs */
	schedule_work(&release_work);
}

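/*
 * This is meant to be invoked from the preloaded BPF side once one of the
 * monitored program files is released (see the SEC("fexit/bpf_prog_release")
 * comment in hid_bpf_insert_prog()). @table_cnt is the use count stored in
 * prog_keys at insert time: when it matches the program's remaining refcount,
 * only our own tables still hold the program, so the slot is disabled and the
 * release worker is scheduled to drop it. Returns true when a slot was found
 * and disabled.
 */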
noinline bool
call_hid_bpf_prog_release(u64 prog_key, int table_cnt)
{
	/* compare with how many refs are left in the bpf program */
	struct bpf_prog *prog = (struct bpf_prog *)prog_key;
	int idx;

	if (!prog)
		return false;

	if (atomic64_read(&prog->aux->refcnt) != table_cnt)
		return false;

	/* we don't need locking here because the entries in the progs table
	 * are stable:
	 * if there were other users (and the progs entries might change), we
	 * would have returned in the test above.
	 */
	for (idx = 0; idx < HID_BPF_MAX_PROGS; idx++) {
		if (jmp_table.progs[idx] == prog) {
			__clear_bit(idx, jmp_table.enabled);
			break;
		}
	}
	if (idx >= HID_BPF_MAX_PROGS) {
		/* should never happen if we get our refcount right */
		idx = -1;
	}

	/* schedule release of all detached progs */
	schedule_work(&release_work);
	return idx >= 0;
}

#define HID_BPF_PROGS_COUNT 3

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;

void hid_bpf_free_links_and_skel(void)
{
	int i;

	/* the following is enough to release all programs attached to hid;
	 * on the preload error path these fields may still hold ERR_PTR values
	 */
	if (!IS_ERR_OR_NULL(jmp_table.prog_keys))
		bpf_map_put_with_uref(jmp_table.prog_keys);

	if (!IS_ERR_OR_NULL(jmp_table.map))
		bpf_map_put_with_uref(jmp_table.map);

	for (i = 0; i < ARRAY_SIZE(links); i++) {
		if (!IS_ERR_OR_NULL(links[i]))
			bpf_link_put(links[i]);
	}
	entrypoints_bpf__destroy(skel);
}

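/*
 * Attach one entry program from the light skeleton and take a kernel
 * reference on the resulting bpf_link so that closing the link fd below does
 * not tear the attachment down. links[] is sized by HID_BPF_PROGS_COUNT,
 * which must match the number of ATTACH_AND_STORE_LINK() calls in
 * hid_bpf_preload_skel().
 */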
#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)

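/*
 * Load the light skeleton at init time: open and load the embedded BPF
 * object, grab uref-counted references on the two maps shared with the BPF
 * side (the prog_array jump table and the prog_keys tracking map), then
 * attach the three entry programs. Any failure unwinds through
 * hid_bpf_free_links_and_skel().
 */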
int hid_bpf_preload_skel(void)
{
	int err, idx = 0;

	skel = entrypoints_bpf__open();
	if (!skel)
		return -ENOMEM;

	err = entrypoints_bpf__load(skel);
	if (err)
		goto out;

	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
	if (IS_ERR(jmp_table.map)) {
		err = PTR_ERR(jmp_table.map);
		goto out;
	}

	jmp_table.prog_keys = bpf_map_get_with_uref(skel->maps.progs_map.map_fd);
	if (IS_ERR(jmp_table.prog_keys)) {
		err = PTR_ERR(jmp_table.prog_keys);
		goto out;
	}

	ATTACH_AND_STORE_LINK(hid_tail_call);
	ATTACH_AND_STORE_LINK(hid_prog_release);
	ATTACH_AND_STORE_LINK(hid_free_inode);

	return 0;
out:
	hid_bpf_free_links_and_skel();
	return err;
}