// SPDX-License-Identifier: GPL-2.0-only

/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */

#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"

#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
				* needs to be a power of 2 as we use it as
				* a circular buffer
				*/

#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))

/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;
	struct hid_device *hdev;
	enum hid_bpf_prog_type type;
	u16 idx;
};

struct hid_bpf_jmp_table {
	struct bpf_map *map;
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};

#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))

static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);		/* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);

BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)			/* HID_BPF_PROG_TYPE_RDESC_FIXUP */

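/* returns the maximum number of programs of @type allowed on one HID device */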
static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
	switch (type) {
	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
		return HID_BPF_MAX_PROGS_PER_DEV;
	case HID_BPF_PROG_TYPE_RDESC_FIXUP:
		return 1;
	default:
		return -EINVAL;
	}
}

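/*
 * count the attached entries matching @hdev, @prog and @type;
 * a NULL @hdev or @prog, or HID_BPF_PROG_TYPE_UNDEF, matches anything
 */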
static int hid_bpf_program_count(struct hid_device *hdev,
				 struct bpf_prog *prog,
				 enum hid_bpf_prog_type type)
{
	int i, n = 0;

	if (type >= HID_BPF_PROG_TYPE_MAX)
		return -EINVAL;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
			continue;

		if (hdev && entry->hdev != hdev)
			continue;

		if (prog && entry->prog != prog)
			continue;

		n++;
	}

	return n;
}

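/*
 * default implementation of the tail-call entry point; the preloaded
 * "hid_tail_call" BPF program attaches here and dispatches to the
 * program stored at ctx->index in the jump table
 */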
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}

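/*
 * run every enabled program of @type attached to @hdev, in attach order.
 * A negative return value from a program stops the chain, any other
 * non-zero value is stored as the new ctx retval.
 */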
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
		     struct hid_bpf_ctx_kern *ctx_kern)
{
	struct hid_bpf_prog_list *prog_list;
	int i, idx, err = 0;

	rcu_read_lock();
	prog_list = rcu_dereference(hdev->bpf.progs[type]);

	if (!prog_list)
		goto out_unlock;

	for (i = 0; i < prog_list->prog_cnt; i++) {
		idx = prog_list->prog_idx[i];

		if (!test_bit(idx, jmp_table.enabled))
			continue;

		ctx_kern->ctx.index = idx;
		err = __hid_bpf_tail_call(&ctx_kern->ctx);
		if (err < 0)
			break;
		if (err)
			ctx_kern->ctx.retval = err;
	}

 out_unlock:
	rcu_read_unlock();

	return err;
}

/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
				     enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *old_list;

	spin_lock(&hdev->bpf.progs_lock);
	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
					     lockdep_is_held(&hdev->bpf.progs_lock));
	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
	spin_unlock(&hdev->bpf.progs_lock);
	synchronize_rcu();

	kfree(old_list);
}

/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called with hid_bpf_attach_lock held.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *new_list;
	int i;

	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
		return -EINVAL;

	if (hdev->bpf.destroyed)
		return 0;

	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
	if (!new_list)
		return -ENOMEM;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (entry->type == type && entry->hdev == hdev &&
		    test_bit(entry->idx, jmp_table.enabled))
			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
	}

	__hid_bpf_set_hdev_progs(hdev, new_list, type);

	return 0;
}

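/* drop one program from the BPF prog_array map and free its jump table slot */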
static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
	skel_map_delete_elem(map_fd, &idx);
	jmp_table.progs[idx] = NULL;
}

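/*
 * deferred work: detach every disabled program from its HID device,
 * remove it from the jump table and compact the entry list
 */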
static void hid_bpf_release_progs(struct work_struct *work)
{
	int i, j, n, map_fd = -1;
	bool hdev_destroyed;

	if (!jmp_table.map)
		return;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

	/* detach unused progs from HID devices */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
		enum hid_bpf_prog_type type;
		struct hid_device *hdev;

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		/* we have an attached prog */
		if (entry->hdev) {
			hdev = entry->hdev;
			type = entry->type;
			/*
			 * hdev is still valid, even if we are called after hid_destroy_device():
			 * when hid_bpf_attach() gets called, it takes a ref on the dev through
			 * bus_find_device()
			 */
			hdev_destroyed = hdev->bpf.destroyed;

			hid_bpf_populate_hdev(hdev, type);

			/* mark all other disabled progs from hdev of the given type as detached */
			FOR_ENTRIES(j, i, jmp_table.head) {
				struct hid_bpf_prog_entry *next;

				next = &jmp_table.entries[j];

				if (test_bit(next->idx, jmp_table.enabled))
					continue;

				if (next->hdev == hdev && next->type == type) {
					/*
					 * clear the hdev reference and decrement the device ref
					 * that was taken during bus_find_device() while calling
					 * hid_bpf_attach()
					 */
					next->hdev = NULL;
					put_device(&hdev->dev);
				}
			}

			/* if type was rdesc fixup and the device is not gone, reconnect device */
			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
				hid_bpf_reconnect(hdev);
		}
	}

	/* remove all unused progs from the jump table */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		if (entry->prog)
			__hid_bpf_do_release_prog(map_fd, entry->idx);
	}

	/* compact the entry list */
	n = jmp_table.tail;
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (!test_bit(entry->idx, jmp_table.enabled))
			continue;

		jmp_table.entries[n] = jmp_table.entries[i];
		n = NEXT(n);
	}

	jmp_table.head = n;

	mutex_unlock(&hid_bpf_attach_lock);

	if (map_fd >= 0)
		close_fd(map_fd);
}

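/* remove a single program from the jump table, given its index */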
static void hid_bpf_release_prog_at(int idx)
{
	int map_fd = -1;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	__hid_bpf_do_release_prog(map_fd, idx);

	close_fd(map_fd);
}

/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
	int i, index = -1, map_fd = -1, err = -EINVAL;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);

	if (map_fd < 0) {
		err = -EINVAL;
		goto out;
	}

	/* find the first available index in the jmp_table */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (!jmp_table.progs[i] && index < 0) {
			/* mark the index as used */
			jmp_table.progs[i] = prog;
			index = i;
			__set_bit(i, jmp_table.enabled);
		}
	}
	if (index < 0) {
		err = -ENOMEM;
		goto out;
	}

	/* insert the program in the jump table */
	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
	if (err)
		goto out;

	/* return the index */
	err = index;

 out:
	/* only release a jump table slot we actually reserved */
	if (err < 0 && index >= 0)
		__hid_bpf_do_release_prog(map_fd, index);
	if (map_fd >= 0)
		close_fd(map_fd);
	return err;
}

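/*
 * map the attach BTF ID of @prog to a HID-BPF program type,
 * or return HID_BPF_PROG_TYPE_UNDEF if it does not match any
 */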
int hid_bpf_get_prog_attach_type(struct bpf_prog *prog)
{
	int prog_type = HID_BPF_PROG_TYPE_UNDEF;
	int i;

	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
			prog_type = i;
			break;
		}
	}

	return prog_type;
}

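/* bpf_link release: mark the program as disabled and defer the actual detach */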
static void hid_bpf_link_release(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	__clear_bit(hid_link->hid_table_index, jmp_table.enabled);
	schedule_work(&release_work);
}

static void hid_bpf_link_dealloc(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	kfree(hid_link);
}

static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
				     struct seq_file *seq)
{
	seq_printf(seq,
		   "attach_type:\tHID-BPF\n");
}

static const struct bpf_link_ops hid_bpf_link_lops = {
	.release = hid_bpf_link_release,
	.dealloc = hid_bpf_link_dealloc,
	.show_fdinfo = hid_bpf_link_show_fdinfo,
};

/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
		      int prog_fd, struct bpf_prog *prog, __u32 flags)
{
	struct bpf_link_primer link_primer;
	struct hid_bpf_link *link;
	struct hid_bpf_prog_entry *prog_entry;
	int cnt, err = -EINVAL, prog_table_idx = -1;

	mutex_lock(&hid_bpf_attach_lock);

	link = kzalloc(sizeof(*link), GFP_USER);
	if (!link) {
		err = -ENOMEM;
		goto err_unlock;
	}

	bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
		      &hid_bpf_link_lops, prog);

	/* do not attach too many programs to a given HID device */
	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
	if (cnt < 0) {
		err = cnt;
		goto err_unlock;
	}

	if (cnt >= hid_bpf_max_programs(prog_type)) {
		err = -E2BIG;
		goto err_unlock;
	}

	prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
	/* if the jmp table is full, abort */
	if (prog_table_idx < 0) {
		err = prog_table_idx;
		goto err_unlock;
	}

	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
		/* take the previous prog_entry slot */
		jmp_table.tail = PREV(jmp_table.tail);
		prog_entry = &jmp_table.entries[jmp_table.tail];
	} else {
		/* take the next prog_entry slot */
		prog_entry = &jmp_table.entries[jmp_table.head];
		jmp_table.head = NEXT(jmp_table.head);
	}

	/* we steal the ref here */
	prog_entry->prog = prog;
	prog_entry->idx = prog_table_idx;
	prog_entry->hdev = hdev;
	prog_entry->type = prog_type;

	/* finally store the index in the device list */
	err = hid_bpf_populate_hdev(hdev, prog_type);
	if (err) {
		hid_bpf_release_prog_at(prog_table_idx);
		goto err_unlock;
	}

	link->hid_table_index = prog_table_idx;

	err = bpf_link_prime(&link->link, &link_primer);
	if (err)
		goto err_unlock;

	mutex_unlock(&hid_bpf_attach_lock);

	return bpf_link_settle(&link_primer);

 err_unlock:
	mutex_unlock(&hid_bpf_attach_lock);

	kfree(link);

	return err;
}

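/*
 * disable every program attached to @hdev, drop the per-device program
 * lists and schedule the deferred release of the detached programs
 */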
void __hid_bpf_destroy_device(struct hid_device *hdev)
{
	int type, i;
	struct hid_bpf_prog_list *prog_list;

	rcu_read_lock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
		prog_list = rcu_dereference(hdev->bpf.progs[type]);

		if (!prog_list)
			continue;

		for (i = 0; i < prog_list->prog_cnt; i++)
			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
	}

	rcu_read_unlock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
		__hid_bpf_set_hdev_progs(hdev, NULL, type);

	/* schedule release of all detached progs */
	schedule_work(&release_work);
}

#define HID_BPF_PROGS_COUNT 1

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;

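/* release the jump table map, the stored links and the preloaded skeleton */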
void hid_bpf_free_links_and_skel(void)
{
	int i;

	/* the following is enough to release all programs attached to hid */
	if (jmp_table.map)
		bpf_map_put_with_uref(jmp_table.map);

	for (i = 0; i < ARRAY_SIZE(links); i++) {
		if (!IS_ERR_OR_NULL(links[i]))
			bpf_link_put(links[i]);
	}
	entrypoints_bpf__destroy(skel);
}

#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)

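/*
 * open and load the preloaded entrypoints skeleton, take a reference on
 * the hid_jmp_table map and attach the tail-call entry point
 */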
int hid_bpf_preload_skel(void)
{
	int err, idx = 0;

	skel = entrypoints_bpf__open();
	if (!skel)
		return -ENOMEM;

	err = entrypoints_bpf__load(skel);
	if (err)
		goto out;

	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
	if (IS_ERR(jmp_table.map)) {
		err = PTR_ERR(jmp_table.map);
		goto out;
	}

	ATTACH_AND_STORE_LINK(hid_tail_call);

	return 0;
out:
	hid_bpf_free_links_and_skel();
	return err;
}