xref: /openbmc/linux/drivers/hid/bpf/hid_bpf_jmp_table.c (revision 2b91c4a870c9830eaf95e744454c9c218cccb736)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 /*
4  *  HID-BPF support for Linux
5  *
6  *  Copyright (c) 2022 Benjamin Tissoires
7  */
8 
9 #include <linux/bitops.h>
10 #include <linux/btf.h>
11 #include <linux/btf_ids.h>
12 #include <linux/circ_buf.h>
13 #include <linux/filter.h>
14 #include <linux/hid.h>
15 #include <linux/hid_bpf.h>
16 #include <linux/init.h>
17 #include <linux/module.h>
18 #include <linux/workqueue.h>
19 #include "hid_bpf_dispatch.h"
20 #include "entrypoints/entrypoints.lskel.h"
21 
/*
 * Number of slots of the BPF jump table, and size of the entries[] ring.
 * Keep this in sync with the preloaded BPF program.
 *
 * It must be a power of 2: NEXT()/PREV() and the CIRC_*() helpers wrap
 * indices with a mask instead of a modulo.
 */
#define HID_BPF_MAX_PROGS 1024

static_assert(HID_BPF_MAX_PROGS > 0 &&
	      (HID_BPF_MAX_PROGS & (HID_BPF_MAX_PROGS - 1)) == 0,
	      "HID_BPF_MAX_PROGS must be a power of 2");

/* next / previous slot of the entries[] ring, wrapping around */
#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
29 
30 /*
31  * represents one attached program stored in the hid jump table
32  */
33 struct hid_bpf_prog_entry {
34 	struct bpf_prog *prog;
35 	struct hid_device *hdev;
36 	enum hid_bpf_prog_type type;
37 	u16 idx;
38 };
39 
40 struct hid_bpf_jmp_table {
41 	struct bpf_map *map;
42 	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
43 	int tail, head;
44 	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
45 	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
46 };
47 
/*
 * Walk __i over the ring slots starting at __start and stopping right
 * before __end (wrapping around the end of the ring).
 */
#define FOR_ENTRIES(__i, __start, __end)				\
	for ((__i) = (__start);						\
	     CIRC_CNT((__end), (__i), HID_BPF_MAX_PROGS);		\
	     (__i) = NEXT(__i))
50 
51 static struct hid_bpf_jmp_table jmp_table;
52 
53 static DEFINE_MUTEX(hid_bpf_attach_lock);		/* held when attaching/detaching programs */
54 
55 static void hid_bpf_release_progs(struct work_struct *work);
56 
57 static DECLARE_WORK(release_work, hid_bpf_release_progs);
58 
59 BTF_ID_LIST(hid_bpf_btf_ids)
60 BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
61 BTF_ID(func, hid_bpf_rdesc_fixup)			/* HID_BPF_PROG_TYPE_RDESC_FIXUP */
62 
63 static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
64 {
65 	switch (type) {
66 	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
67 		return HID_BPF_MAX_PROGS_PER_DEV;
68 	case HID_BPF_PROG_TYPE_RDESC_FIXUP:
69 		return 1;
70 	default:
71 		return -EINVAL;
72 	}
73 }
74 
75 static int hid_bpf_program_count(struct hid_device *hdev,
76 				 struct bpf_prog *prog,
77 				 enum hid_bpf_prog_type type)
78 {
79 	int i, n = 0;
80 
81 	if (type >= HID_BPF_PROG_TYPE_MAX)
82 		return -EINVAL;
83 
84 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
85 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
86 
87 		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
88 			continue;
89 
90 		if (hdev && entry->hdev != hdev)
91 			continue;
92 
93 		if (prog && entry->prog != prog)
94 			continue;
95 
96 		n++;
97 	}
98 
99 	return n;
100 }
101 
/*
 * Default implementation: does nothing and returns 0.
 *
 * hid_bpf_prog_run() calls it once per enabled program, after storing the
 * program's jump table index in ctx->index. A negative return stops the
 * chain of programs; a positive one is stored in ctx->retval.
 * It is __weak and noinline, presumably so the call stays a real call
 * that BPF can hook.
 * NOTE(review): presumably the preloaded "hid_tail_call" entrypoint attaches
 * here and tail-calls jmp_table[ctx->index]; confirm against
 * entrypoints/entrypoints.bpf.c.
 */
102 __weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
103 {
104 	return 0;
105 }
106 
107 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
108 		     struct hid_bpf_ctx_kern *ctx_kern)
109 {
110 	struct hid_bpf_prog_list *prog_list;
111 	int i, idx, err = 0;
112 
113 	rcu_read_lock();
114 	prog_list = rcu_dereference(hdev->bpf.progs[type]);
115 
116 	if (!prog_list)
117 		goto out_unlock;
118 
119 	for (i = 0; i < prog_list->prog_cnt; i++) {
120 		idx = prog_list->prog_idx[i];
121 
122 		if (!test_bit(idx, jmp_table.enabled))
123 			continue;
124 
125 		ctx_kern->ctx.index = idx;
126 		err = __hid_bpf_tail_call(&ctx_kern->ctx);
127 		if (err < 0)
128 			break;
129 		if (err)
130 			ctx_kern->ctx.retval = err;
131 	}
132 
133  out_unlock:
134 	rcu_read_unlock();
135 
136 	return err;
137 }
138 
139 /*
140  * assign the list of programs attached to a given hid device.
141  */
142 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
143 				     enum hid_bpf_prog_type type)
144 {
145 	struct hid_bpf_prog_list *old_list;
146 
147 	spin_lock(&hdev->bpf.progs_lock);
148 	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
149 					     lockdep_is_held(&hdev->bpf.progs_lock));
150 	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
151 	spin_unlock(&hdev->bpf.progs_lock);
152 	synchronize_rcu();
153 
154 	kfree(old_list);
155 }
156 
157 /*
158  * allocate and populate the list of programs attached to a given hid device.
159  *
160  * Must be called under lock.
161  */
162 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
163 {
164 	struct hid_bpf_prog_list *new_list;
165 	int i;
166 
167 	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
168 		return -EINVAL;
169 
170 	if (hdev->bpf.destroyed)
171 		return 0;
172 
173 	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
174 	if (!new_list)
175 		return -ENOMEM;
176 
177 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
178 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
179 
180 		if (entry->type == type && entry->hdev == hdev &&
181 		    test_bit(entry->idx, jmp_table.enabled))
182 			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
183 	}
184 
185 	__hid_bpf_set_hdev_progs(hdev, new_list, type);
186 
187 	return 0;
188 }
189 
190 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
191 {
192 	skel_map_delete_elem(map_fd, &idx);
193 	jmp_table.progs[idx] = NULL;
194 }
195 
196 static void hid_bpf_release_progs(struct work_struct *work)
197 {
198 	int i, j, n, map_fd = -1;
199 
200 	if (!jmp_table.map)
201 		return;
202 
203 	/* retrieve a fd of our prog_array map in BPF */
204 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
205 	if (map_fd < 0)
206 		return;
207 
208 	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */
209 
210 	/* detach unused progs from HID devices */
211 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
212 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
213 		enum hid_bpf_prog_type type;
214 		struct hid_device *hdev;
215 
216 		if (test_bit(entry->idx, jmp_table.enabled))
217 			continue;
218 
219 		/* we have an attached prog */
220 		if (entry->hdev) {
221 			hdev = entry->hdev;
222 			type = entry->type;
223 
224 			hid_bpf_populate_hdev(hdev, type);
225 
226 			/* mark all other disabled progs from hdev of the given type as detached */
227 			FOR_ENTRIES(j, i, jmp_table.head) {
228 				struct hid_bpf_prog_entry *next;
229 
230 				next = &jmp_table.entries[j];
231 
232 				if (test_bit(next->idx, jmp_table.enabled))
233 					continue;
234 
235 				if (next->hdev == hdev && next->type == type)
236 					next->hdev = NULL;
237 			}
238 
239 			/* if type was rdesc fixup, reconnect device */
240 			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
241 				hid_bpf_reconnect(hdev);
242 		}
243 	}
244 
245 	/* remove all unused progs from the jump table */
246 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
247 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
248 
249 		if (test_bit(entry->idx, jmp_table.enabled))
250 			continue;
251 
252 		if (entry->prog)
253 			__hid_bpf_do_release_prog(map_fd, entry->idx);
254 	}
255 
256 	/* compact the entry list */
257 	n = jmp_table.tail;
258 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
259 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
260 
261 		if (!test_bit(entry->idx, jmp_table.enabled))
262 			continue;
263 
264 		jmp_table.entries[n] = jmp_table.entries[i];
265 		n = NEXT(n);
266 	}
267 
268 	jmp_table.head = n;
269 
270 	mutex_unlock(&hid_bpf_attach_lock);
271 
272 	if (map_fd >= 0)
273 		close_fd(map_fd);
274 }
275 
276 static void hid_bpf_release_prog_at(int idx)
277 {
278 	int map_fd = -1;
279 
280 	/* retrieve a fd of our prog_array map in BPF */
281 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
282 	if (map_fd < 0)
283 		return;
284 
285 	__hid_bpf_do_release_prog(map_fd, idx);
286 
287 	close(map_fd);
288 }
289 
290 /*
291  * Insert the given BPF program represented by its fd in the jmp table.
292  * Returns the index in the jump table or a negative error.
293  */
294 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
295 {
296 	int i, index = -1, map_fd = -1, err = -EINVAL;
297 
298 	/* retrieve a fd of our prog_array map in BPF */
299 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
300 
301 	if (map_fd < 0) {
302 		err = -EINVAL;
303 		goto out;
304 	}
305 
306 	/* find the first available index in the jmp_table */
307 	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
308 		if (!jmp_table.progs[i] && index < 0) {
309 			/* mark the index as used */
310 			jmp_table.progs[i] = prog;
311 			index = i;
312 			__set_bit(i, jmp_table.enabled);
313 		}
314 	}
315 	if (index < 0) {
316 		err = -ENOMEM;
317 		goto out;
318 	}
319 
320 	/* insert the program in the jump table */
321 	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
322 	if (err)
323 		goto out;
324 
325 	/* return the index */
326 	err = index;
327 
328  out:
329 	if (err < 0)
330 		__hid_bpf_do_release_prog(map_fd, index);
331 	if (map_fd >= 0)
332 		close_fd(map_fd);
333 	return err;
334 }
335 
336 int hid_bpf_get_prog_attach_type(int prog_fd)
337 {
338 	struct bpf_prog *prog = NULL;
339 	int i;
340 	int prog_type = HID_BPF_PROG_TYPE_UNDEF;
341 
342 	prog = bpf_prog_get(prog_fd);
343 	if (IS_ERR(prog))
344 		return PTR_ERR(prog);
345 
346 	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
347 		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
348 			prog_type = i;
349 			break;
350 		}
351 	}
352 
353 	bpf_prog_put(prog);
354 
355 	return prog_type;
356 }
357 
358 static void hid_bpf_link_release(struct bpf_link *link)
359 {
360 	struct hid_bpf_link *hid_link =
361 		container_of(link, struct hid_bpf_link, link);
362 
363 	__clear_bit(hid_link->hid_table_index, jmp_table.enabled);
364 	schedule_work(&release_work);
365 }
366 
367 static void hid_bpf_link_dealloc(struct bpf_link *link)
368 {
369 	struct hid_bpf_link *hid_link =
370 		container_of(link, struct hid_bpf_link, link);
371 
372 	kfree(hid_link);
373 }
374 
/* bpf_link show_fdinfo callback: report the link as a HID-BPF one */
static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
				     struct seq_file *seq)
{
	/* constant string, nothing to format */
	seq_puts(seq, "attach_type:\tHID-BPF\n");
}
381 
382 static const struct bpf_link_ops hid_bpf_link_lops = {
383 	.release = hid_bpf_link_release,
384 	.dealloc = hid_bpf_link_dealloc,
385 	.show_fdinfo = hid_bpf_link_show_fdinfo,
386 };
387 
388 /* called from syscall */
389 noinline int
390 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
391 		      int prog_fd, __u32 flags)
392 {
393 	struct bpf_link_primer link_primer;
394 	struct hid_bpf_link *link;
395 	struct bpf_prog *prog = NULL;
396 	struct hid_bpf_prog_entry *prog_entry;
397 	int cnt, err = -EINVAL, prog_table_idx = -1;
398 
399 	/* take a ref on the prog itself */
400 	prog = bpf_prog_get(prog_fd);
401 	if (IS_ERR(prog))
402 		return PTR_ERR(prog);
403 
404 	mutex_lock(&hid_bpf_attach_lock);
405 
406 	link = kzalloc(sizeof(*link), GFP_USER);
407 	if (!link) {
408 		err = -ENOMEM;
409 		goto err_unlock;
410 	}
411 
412 	bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
413 		      &hid_bpf_link_lops, prog);
414 
415 	/* do not attach too many programs to a given HID device */
416 	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
417 	if (cnt < 0) {
418 		err = cnt;
419 		goto err_unlock;
420 	}
421 
422 	if (cnt >= hid_bpf_max_programs(prog_type)) {
423 		err = -E2BIG;
424 		goto err_unlock;
425 	}
426 
427 	prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
428 	/* if the jmp table is full, abort */
429 	if (prog_table_idx < 0) {
430 		err = prog_table_idx;
431 		goto err_unlock;
432 	}
433 
434 	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
435 		/* take the previous prog_entry slot */
436 		jmp_table.tail = PREV(jmp_table.tail);
437 		prog_entry = &jmp_table.entries[jmp_table.tail];
438 	} else {
439 		/* take the next prog_entry slot */
440 		prog_entry = &jmp_table.entries[jmp_table.head];
441 		jmp_table.head = NEXT(jmp_table.head);
442 	}
443 
444 	/* we steal the ref here */
445 	prog_entry->prog = prog;
446 	prog_entry->idx = prog_table_idx;
447 	prog_entry->hdev = hdev;
448 	prog_entry->type = prog_type;
449 
450 	/* finally store the index in the device list */
451 	err = hid_bpf_populate_hdev(hdev, prog_type);
452 	if (err) {
453 		hid_bpf_release_prog_at(prog_table_idx);
454 		goto err_unlock;
455 	}
456 
457 	link->hid_table_index = prog_table_idx;
458 
459 	err = bpf_link_prime(&link->link, &link_primer);
460 	if (err)
461 		goto err_unlock;
462 
463 	mutex_unlock(&hid_bpf_attach_lock);
464 
465 	return bpf_link_settle(&link_primer);
466 
467  err_unlock:
468 	mutex_unlock(&hid_bpf_attach_lock);
469 
470 	bpf_prog_put(prog);
471 	kfree(link);
472 
473 	return err;
474 }
475 
476 void __hid_bpf_destroy_device(struct hid_device *hdev)
477 {
478 	int type, i;
479 	struct hid_bpf_prog_list *prog_list;
480 
481 	rcu_read_lock();
482 
483 	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
484 		prog_list = rcu_dereference(hdev->bpf.progs[type]);
485 
486 		if (!prog_list)
487 			continue;
488 
489 		for (i = 0; i < prog_list->prog_cnt; i++)
490 			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
491 	}
492 
493 	rcu_read_unlock();
494 
495 	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
496 		__hid_bpf_set_hdev_progs(hdev, NULL, type);
497 
498 	/* schedule release of all detached progs */
499 	schedule_work(&release_work);
500 }
501 
502 #define HID_BPF_PROGS_COUNT 1
503 
504 static struct bpf_link *links[HID_BPF_PROGS_COUNT];
505 static struct entrypoints_bpf *skel;
506 
507 void hid_bpf_free_links_and_skel(void)
508 {
509 	int i;
510 
511 	/* the following is enough to release all programs attached to hid */
512 	if (jmp_table.map)
513 		bpf_map_put_with_uref(jmp_table.map);
514 
515 	for (i = 0; i < ARRAY_SIZE(links); i++) {
516 		if (!IS_ERR_OR_NULL(links[i]))
517 			bpf_link_put(links[i]);
518 	}
519 	entrypoints_bpf__destroy(skel);
520 }
521 
/*
 * Attach program __name of the preloaded skeleton and keep a reference on its
 * link in links[idx]. Expects `err`, `idx`, `skel` and an `out` error label
 * in the calling function.
 */
#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)
540 
541 int hid_bpf_preload_skel(void)
542 {
543 	int err, idx = 0;
544 
545 	skel = entrypoints_bpf__open();
546 	if (!skel)
547 		return -ENOMEM;
548 
549 	err = entrypoints_bpf__load(skel);
550 	if (err)
551 		goto out;
552 
553 	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
554 	if (IS_ERR(jmp_table.map)) {
555 		err = PTR_ERR(jmp_table.map);
556 		goto out;
557 	}
558 
559 	ATTACH_AND_STORE_LINK(hid_tail_call);
560 
561 	return 0;
562 out:
563 	hid_bpf_free_links_and_skel();
564 	return err;
565 }
566