1 // SPDX-License-Identifier: GPL-2.0-only
2
3 /*
4 * HID-BPF support for Linux
5 *
6 * Copyright (c) 2022 Benjamin Tissoires
7 */
8
9 #include <linux/bitops.h>
10 #include <linux/btf.h>
11 #include <linux/btf_ids.h>
12 #include <linux/circ_buf.h>
13 #include <linux/filter.h>
14 #include <linux/hid.h>
15 #include <linux/hid_bpf.h>
16 #include <linux/init.h>
17 #include <linux/module.h>
18 #include <linux/workqueue.h>
19 #include "hid_bpf_dispatch.h"
20 #include "entrypoints/entrypoints.lskel.h"
21
/*
 * Size of the jump table. Keep this in sync with the preloaded bpf. It needs
 * to be a power of 2 because it is used as a circular buffer: NEXT() and
 * PREV() wrap around with a mask instead of a modulo.
 */
#define HID_BPF_MAX_PROGS 1024

static_assert((HID_BPF_MAX_PROGS & (HID_BPF_MAX_PROGS - 1)) == 0,
	      "HID_BPF_MAX_PROGS must be a power of 2");

/* ring-buffer successor / predecessor of idx, wrapping at HID_BPF_MAX_PROGS */
#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
29
/*
 * Represents one attached program stored in the hid jump table.
 *
 * Entries live in jmp_table.entries. An entry whose idx bit is cleared in
 * jmp_table.enabled is detached and compacted away by hid_bpf_release_progs().
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;		/* attached program; its slot is released only if non-NULL */
	struct hid_device *hdev;	/* target device; reset to NULL once its device ref is dropped */
	enum hid_bpf_prog_type type;	/* which hook the program implements */
	u16 idx;			/* slot in jmp_table.progs, the prog_array map and jmp_table.enabled */
};
39
/*
 * Global jump table shared by every HID device.
 *
 * entries[] is a ring buffer whose live part is [tail, head). head == tail
 * means "empty" (FOR_ENTRIES() does not iterate), so at most
 * HID_BPF_MAX_PROGS - 1 entries can be stored at once.
 */
struct hid_bpf_jmp_table {
	struct bpf_map *map;	/* prog_array map of the preloaded entrypoints skeleton */
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;		/* ring bounds of entries[], see above */
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping, NULL marks a free slot */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)]; /* set on insert, cleared on detach */
};
47
/*
 * Iterate __i over ring positions from __start up to, but not including,
 * __end, wrapping around at HID_BPF_MAX_PROGS. Does nothing if __start == __end.
 */
#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))
50
/* the single jump table shared by all HID devices */
static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock); /* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

/* deferred reaping of disabled programs; queued on link release and device destroy */
static DECLARE_WORK(release_work, hid_bpf_release_progs);
58
/*
 * BTF ids of the functions HID-BPF programs attach to.
 * hid_bpf_get_prog_attach_type() indexes this list with
 * enum hid_bpf_prog_type values, so the order must match the enum.
 */
BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)			/* HID_BPF_PROG_TYPE_RDESC_FIXUP */
62
63 static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
64 {
65 switch (type) {
66 case HID_BPF_PROG_TYPE_DEVICE_EVENT:
67 return HID_BPF_MAX_PROGS_PER_DEV;
68 case HID_BPF_PROG_TYPE_RDESC_FIXUP:
69 return 1;
70 default:
71 return -EINVAL;
72 }
73 }
74
hid_bpf_program_count(struct hid_device * hdev,struct bpf_prog * prog,enum hid_bpf_prog_type type)75 static int hid_bpf_program_count(struct hid_device *hdev,
76 struct bpf_prog *prog,
77 enum hid_bpf_prog_type type)
78 {
79 int i, n = 0;
80
81 if (type >= HID_BPF_PROG_TYPE_MAX)
82 return -EINVAL;
83
84 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
85 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
86
87 if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
88 continue;
89
90 if (hdev && entry->hdev != hdev)
91 continue;
92
93 if (prog && entry->prog != prog)
94 continue;
95
96 n++;
97 }
98
99 return n;
100 }
101
/*
 * Default, empty implementation of the tail-call hook: does nothing and
 * returns 0. hid_bpf_prog_run() calls it once per enabled program, with
 * ctx->index set to that program's jump table index.
 *
 * NOTE(review): presumably the preloaded entrypoints program (hid_tail_call,
 * attached in hid_bpf_preload_skel()) hooks this function and tail-calls
 * into jmp_table.map at ctx->index. Confirm against the entrypoints BPF
 * source. __weak noinline keeps it a distinct, non-inlined symbol, so keep
 * its signature and body as they are.
 */
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}
106
hid_bpf_prog_run(struct hid_device * hdev,enum hid_bpf_prog_type type,struct hid_bpf_ctx_kern * ctx_kern)107 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
108 struct hid_bpf_ctx_kern *ctx_kern)
109 {
110 struct hid_bpf_prog_list *prog_list;
111 int i, idx, err = 0;
112
113 rcu_read_lock();
114 prog_list = rcu_dereference(hdev->bpf.progs[type]);
115
116 if (!prog_list)
117 goto out_unlock;
118
119 for (i = 0; i < prog_list->prog_cnt; i++) {
120 idx = prog_list->prog_idx[i];
121
122 if (!test_bit(idx, jmp_table.enabled))
123 continue;
124
125 ctx_kern->ctx.index = idx;
126 err = __hid_bpf_tail_call(&ctx_kern->ctx);
127 if (err < 0)
128 break;
129 if (err)
130 ctx_kern->ctx.retval = err;
131 }
132
133 out_unlock:
134 rcu_read_unlock();
135
136 return err;
137 }
138
139 /*
140 * assign the list of programs attached to a given hid device.
141 */
__hid_bpf_set_hdev_progs(struct hid_device * hdev,struct hid_bpf_prog_list * new_list,enum hid_bpf_prog_type type)142 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
143 enum hid_bpf_prog_type type)
144 {
145 struct hid_bpf_prog_list *old_list;
146
147 spin_lock(&hdev->bpf.progs_lock);
148 old_list = rcu_dereference_protected(hdev->bpf.progs[type],
149 lockdep_is_held(&hdev->bpf.progs_lock));
150 rcu_assign_pointer(hdev->bpf.progs[type], new_list);
151 spin_unlock(&hdev->bpf.progs_lock);
152 synchronize_rcu();
153
154 kfree(old_list);
155 }
156
157 /*
158 * allocate and populate the list of programs attached to a given hid device.
159 *
160 * Must be called under lock.
161 */
hid_bpf_populate_hdev(struct hid_device * hdev,enum hid_bpf_prog_type type)162 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
163 {
164 struct hid_bpf_prog_list *new_list;
165 int i;
166
167 if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
168 return -EINVAL;
169
170 if (hdev->bpf.destroyed)
171 return 0;
172
173 new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
174 if (!new_list)
175 return -ENOMEM;
176
177 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
178 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
179
180 if (entry->type == type && entry->hdev == hdev &&
181 test_bit(entry->idx, jmp_table.enabled))
182 new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
183 }
184
185 __hid_bpf_set_hdev_progs(hdev, new_list, type);
186
187 return 0;
188 }
189
__hid_bpf_do_release_prog(int map_fd,unsigned int idx)190 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
191 {
192 skel_map_delete_elem(map_fd, &idx);
193 jmp_table.progs[idx] = NULL;
194 }
195
hid_bpf_release_progs(struct work_struct * work)196 static void hid_bpf_release_progs(struct work_struct *work)
197 {
198 int i, j, n, map_fd = -1;
199 bool hdev_destroyed;
200
201 if (!jmp_table.map)
202 return;
203
204 /* retrieve a fd of our prog_array map in BPF */
205 map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
206 if (map_fd < 0)
207 return;
208
209 mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */
210
211 /* detach unused progs from HID devices */
212 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
213 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
214 enum hid_bpf_prog_type type;
215 struct hid_device *hdev;
216
217 if (test_bit(entry->idx, jmp_table.enabled))
218 continue;
219
220 /* we have an attached prog */
221 if (entry->hdev) {
222 hdev = entry->hdev;
223 type = entry->type;
224 /*
225 * hdev is still valid, even if we are called after hid_destroy_device():
226 * when hid_bpf_attach() gets called, it takes a ref on the dev through
227 * bus_find_device()
228 */
229 hdev_destroyed = hdev->bpf.destroyed;
230
231 hid_bpf_populate_hdev(hdev, type);
232
233 /* mark all other disabled progs from hdev of the given type as detached */
234 FOR_ENTRIES(j, i, jmp_table.head) {
235 struct hid_bpf_prog_entry *next;
236
237 next = &jmp_table.entries[j];
238
239 if (test_bit(next->idx, jmp_table.enabled))
240 continue;
241
242 if (next->hdev == hdev && next->type == type) {
243 /*
244 * clear the hdev reference and decrement the device ref
245 * that was taken during bus_find_device() while calling
246 * hid_bpf_attach()
247 */
248 next->hdev = NULL;
249 put_device(&hdev->dev);
250 }
251 }
252
253 /* if type was rdesc fixup and the device is not gone, reconnect device */
254 if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
255 hid_bpf_reconnect(hdev);
256 }
257 }
258
259 /* remove all unused progs from the jump table */
260 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
261 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
262
263 if (test_bit(entry->idx, jmp_table.enabled))
264 continue;
265
266 if (entry->prog)
267 __hid_bpf_do_release_prog(map_fd, entry->idx);
268 }
269
270 /* compact the entry list */
271 n = jmp_table.tail;
272 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
273 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
274
275 if (!test_bit(entry->idx, jmp_table.enabled))
276 continue;
277
278 jmp_table.entries[n] = jmp_table.entries[i];
279 n = NEXT(n);
280 }
281
282 jmp_table.head = n;
283
284 mutex_unlock(&hid_bpf_attach_lock);
285
286 if (map_fd >= 0)
287 close_fd(map_fd);
288 }
289
hid_bpf_release_prog_at(int idx)290 static void hid_bpf_release_prog_at(int idx)
291 {
292 int map_fd = -1;
293
294 /* retrieve a fd of our prog_array map in BPF */
295 map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
296 if (map_fd < 0)
297 return;
298
299 __hid_bpf_do_release_prog(map_fd, idx);
300
301 close(map_fd);
302 }
303
304 /*
305 * Insert the given BPF program represented by its fd in the jmp table.
306 * Returns the index in the jump table or a negative error.
307 */
hid_bpf_insert_prog(int prog_fd,struct bpf_prog * prog)308 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
309 {
310 int i, index = -1, map_fd = -1, err = -EINVAL;
311
312 /* retrieve a fd of our prog_array map in BPF */
313 map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
314
315 if (map_fd < 0) {
316 err = -EINVAL;
317 goto out;
318 }
319
320 /* find the first available index in the jmp_table */
321 for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
322 if (!jmp_table.progs[i] && index < 0) {
323 /* mark the index as used */
324 jmp_table.progs[i] = prog;
325 index = i;
326 __set_bit(i, jmp_table.enabled);
327 }
328 }
329 if (index < 0) {
330 err = -ENOMEM;
331 goto out;
332 }
333
334 /* insert the program in the jump table */
335 err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
336 if (err)
337 goto out;
338
339 /* return the index */
340 err = index;
341
342 out:
343 if (err < 0)
344 __hid_bpf_do_release_prog(map_fd, index);
345 if (map_fd >= 0)
346 close_fd(map_fd);
347 return err;
348 }
349
hid_bpf_get_prog_attach_type(struct bpf_prog * prog)350 int hid_bpf_get_prog_attach_type(struct bpf_prog *prog)
351 {
352 int prog_type = HID_BPF_PROG_TYPE_UNDEF;
353 int i;
354
355 for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
356 if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
357 prog_type = i;
358 break;
359 }
360 }
361
362 return prog_type;
363 }
364
hid_bpf_link_release(struct bpf_link * link)365 static void hid_bpf_link_release(struct bpf_link *link)
366 {
367 struct hid_bpf_link *hid_link =
368 container_of(link, struct hid_bpf_link, link);
369
370 __clear_bit(hid_link->hid_table_index, jmp_table.enabled);
371 schedule_work(&release_work);
372 }
373
hid_bpf_link_dealloc(struct bpf_link * link)374 static void hid_bpf_link_dealloc(struct bpf_link *link)
375 {
376 struct hid_bpf_link *hid_link =
377 container_of(link, struct hid_bpf_link, link);
378
379 kfree(hid_link);
380 }
381
/* bpf_link .show_fdinfo callback: report the HID-BPF attach type. */
static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
				     struct seq_file *seq)
{
	/* constant string, no formatting needed */
	seq_puts(seq, "attach_type:\tHID-BPF\n");
}
388
389 static const struct bpf_link_ops hid_bpf_link_lops = {
390 .release = hid_bpf_link_release,
391 .dealloc = hid_bpf_link_dealloc,
392 .show_fdinfo = hid_bpf_link_show_fdinfo,
393 };
394
395 /* called from syscall */
396 noinline int
__hid_bpf_attach_prog(struct hid_device * hdev,enum hid_bpf_prog_type prog_type,int prog_fd,struct bpf_prog * prog,__u32 flags)397 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
398 int prog_fd, struct bpf_prog *prog, __u32 flags)
399 {
400 struct bpf_link_primer link_primer;
401 struct hid_bpf_link *link;
402 struct hid_bpf_prog_entry *prog_entry;
403 int cnt, err = -EINVAL, prog_table_idx = -1;
404
405 mutex_lock(&hid_bpf_attach_lock);
406
407 link = kzalloc(sizeof(*link), GFP_USER);
408 if (!link) {
409 err = -ENOMEM;
410 goto err_unlock;
411 }
412
413 bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
414 &hid_bpf_link_lops, prog);
415
416 /* do not attach too many programs to a given HID device */
417 cnt = hid_bpf_program_count(hdev, NULL, prog_type);
418 if (cnt < 0) {
419 err = cnt;
420 goto err_unlock;
421 }
422
423 if (cnt >= hid_bpf_max_programs(prog_type)) {
424 err = -E2BIG;
425 goto err_unlock;
426 }
427
428 prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
429 /* if the jmp table is full, abort */
430 if (prog_table_idx < 0) {
431 err = prog_table_idx;
432 goto err_unlock;
433 }
434
435 if (flags & HID_BPF_FLAG_INSERT_HEAD) {
436 /* take the previous prog_entry slot */
437 jmp_table.tail = PREV(jmp_table.tail);
438 prog_entry = &jmp_table.entries[jmp_table.tail];
439 } else {
440 /* take the next prog_entry slot */
441 prog_entry = &jmp_table.entries[jmp_table.head];
442 jmp_table.head = NEXT(jmp_table.head);
443 }
444
445 /* we steal the ref here */
446 prog_entry->prog = prog;
447 prog_entry->idx = prog_table_idx;
448 prog_entry->hdev = hdev;
449 prog_entry->type = prog_type;
450
451 /* finally store the index in the device list */
452 err = hid_bpf_populate_hdev(hdev, prog_type);
453 if (err) {
454 hid_bpf_release_prog_at(prog_table_idx);
455 goto err_unlock;
456 }
457
458 link->hid_table_index = prog_table_idx;
459
460 err = bpf_link_prime(&link->link, &link_primer);
461 if (err)
462 goto err_unlock;
463
464 mutex_unlock(&hid_bpf_attach_lock);
465
466 return bpf_link_settle(&link_primer);
467
468 err_unlock:
469 mutex_unlock(&hid_bpf_attach_lock);
470
471 kfree(link);
472
473 return err;
474 }
475
__hid_bpf_destroy_device(struct hid_device * hdev)476 void __hid_bpf_destroy_device(struct hid_device *hdev)
477 {
478 int type, i;
479 struct hid_bpf_prog_list *prog_list;
480
481 rcu_read_lock();
482
483 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
484 prog_list = rcu_dereference(hdev->bpf.progs[type]);
485
486 if (!prog_list)
487 continue;
488
489 for (i = 0; i < prog_list->prog_cnt; i++)
490 __clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
491 }
492
493 rcu_read_unlock();
494
495 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
496 __hid_bpf_set_hdev_progs(hdev, NULL, type);
497
498 /* schedule release of all detached progs */
499 schedule_work(&release_work);
500 }
501
/* number of entrypoint links hid_bpf_preload_skel() stores in links[] */
#define HID_BPF_PROGS_COUNT 1

static struct bpf_link *links[HID_BPF_PROGS_COUNT];	/* kernel refs on the entrypoint links */
static struct entrypoints_bpf *skel;			/* loaded entrypoints skeleton */
506
hid_bpf_free_links_and_skel(void)507 void hid_bpf_free_links_and_skel(void)
508 {
509 int i;
510
511 /* the following is enough to release all programs attached to hid */
512 if (jmp_table.map)
513 bpf_map_put_with_uref(jmp_table.map);
514
515 for (i = 0; i < ARRAY_SIZE(links); i++) {
516 if (!IS_ERR_OR_NULL(links[i]))
517 bpf_link_put(links[i]);
518 }
519 entrypoints_bpf__destroy(skel);
520 }
521
/*
 * Attach entrypoint program __name of the skeleton, keep a kernel reference
 * to its link in links[idx], and close the fd the skeleton opened for it.
 * Expects err, idx, skel and an out: label in the calling function, and
 * jumps to out: on failure.
 */
#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)
540
hid_bpf_preload_skel(void)541 int hid_bpf_preload_skel(void)
542 {
543 int err, idx = 0;
544
545 skel = entrypoints_bpf__open();
546 if (!skel)
547 return -ENOMEM;
548
549 err = entrypoints_bpf__load(skel);
550 if (err)
551 goto out;
552
553 jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
554 if (IS_ERR(jmp_table.map)) {
555 err = PTR_ERR(jmp_table.map);
556 goto out;
557 }
558
559 ATTACH_AND_STORE_LINK(hid_tail_call);
560
561 return 0;
562 out:
563 hid_bpf_free_links_and_skel();
564 return err;
565 }
566