// SPDX-License-Identifier: GPL-2.0-only

/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */

#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"

#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
				* needs to be a power of 2 as we use it as
				* a circular buffer
				*/

#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
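
/*
 * Example of the wrap-around behaviour: with HID_BPF_MAX_PROGS == 1024 the
 * mask is 0x3ff, so
 *
 *	NEXT(1023) == 0
 *	PREV(0)    == 1023
 *
 * which is what lets tail and head chase each other around the circular
 * buffer of entries below.
 */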

/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;
	struct hid_device *hdev;
	enum hid_bpf_prog_type type;
	u16 idx;
};

struct hid_bpf_jmp_table {
	struct bpf_map *map;
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};

#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))

static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);		/* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);

BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)			/* HID_BPF_PROG_TYPE_RDESC_FIXUP */

static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
	switch (type) {
	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
		return HID_BPF_MAX_PROGS_PER_DEV;
	case HID_BPF_PROG_TYPE_RDESC_FIXUP:
		return 1;
	default:
		return -EINVAL;
	}
}

static int hid_bpf_program_count(struct hid_device *hdev,
				 struct bpf_prog *prog,
				 enum hid_bpf_prog_type type)
{
	int i, n = 0;

	if (type >= HID_BPF_PROG_TYPE_MAX)
		return -EINVAL;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
			continue;

		if (hdev && entry->hdev != hdev)
			continue;

		if (prog && entry->prog != prog)
			continue;

		n++;
	}

	return n;
}

__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}
ALLOW_ERROR_INJECTION(__hid_bpf_tail_call, ERRNO);
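
/*
 * Note on the mechanism (a sketch, based on the preloaded entrypoints
 * skeleton attached further below): __hid_bpf_tail_call() is only a stub
 * here. The hid_tail_call BPF program attached in hid_bpf_preload_skel()
 * is expected to hook it and perform a bpf_tail_call() into the
 * hid_jmp_table prog_array using ctx->index, which is why
 * hid_bpf_prog_run() stores the program index in ctx_kern->ctx.index
 * before calling the stub.
 */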

int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
		     struct hid_bpf_ctx_kern *ctx_kern)
{
	struct hid_bpf_prog_list *prog_list;
	int i, idx, err = 0;

	rcu_read_lock();
	prog_list = rcu_dereference(hdev->bpf.progs[type]);

	if (!prog_list)
		goto out_unlock;

	for (i = 0; i < prog_list->prog_cnt; i++) {
		idx = prog_list->prog_idx[i];

		if (!test_bit(idx, jmp_table.enabled))
			continue;

		ctx_kern->ctx.index = idx;
		err = __hid_bpf_tail_call(&ctx_kern->ctx);
		if (err < 0)
			break;
		if (err)
			ctx_kern->ctx.retval = err;
	}

 out_unlock:
	rcu_read_unlock();

	return err;
}

/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
				     enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *old_list;

	spin_lock(&hdev->bpf.progs_lock);
	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
					     lockdep_is_held(&hdev->bpf.progs_lock));
	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
	spin_unlock(&hdev->bpf.progs_lock);
	synchronize_rcu();

	kfree(old_list);
}

/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called under lock.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *new_list;
	int i;

	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
		return -EINVAL;

	if (hdev->bpf.destroyed)
		return 0;

	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
	if (!new_list)
		return -ENOMEM;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (entry->type == type && entry->hdev == hdev &&
		    test_bit(entry->idx, jmp_table.enabled))
			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
	}

	__hid_bpf_set_hdev_progs(hdev, new_list, type);

	return 0;
}

static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
	skel_map_delete_elem(map_fd, &idx);
	jmp_table.progs[idx] = NULL;
}

static void hid_bpf_release_progs(struct work_struct *work)
{
	int i, j, n, map_fd = -1;

	if (!jmp_table.map)
		return;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

	/* detach unused progs from HID devices */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
		enum hid_bpf_prog_type type;
		struct hid_device *hdev;

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		/* we have an attached prog */
		if (entry->hdev) {
			hdev = entry->hdev;
			type = entry->type;

			hid_bpf_populate_hdev(hdev, type);

			/* mark all other disabled progs from hdev of the given type as detached */
			FOR_ENTRIES(j, i, jmp_table.head) {
				struct hid_bpf_prog_entry *next;

				next = &jmp_table.entries[j];

				if (test_bit(next->idx, jmp_table.enabled))
					continue;

				if (next->hdev == hdev && next->type == type)
					next->hdev = NULL;
			}

			/* if type was rdesc fixup, reconnect device */
			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
				hid_bpf_reconnect(hdev);
		}
	}

	/* remove all unused progs from the jump table */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		if (entry->prog)
			__hid_bpf_do_release_prog(map_fd, entry->idx);
	}

	/* compact the entry list */
	n = jmp_table.tail;
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (!test_bit(entry->idx, jmp_table.enabled))
			continue;

		jmp_table.entries[n] = jmp_table.entries[i];
		n = NEXT(n);
	}

	jmp_table.head = n;

	mutex_unlock(&hid_bpf_attach_lock);

	if (map_fd >= 0)
		close_fd(map_fd);
}

static void hid_bpf_release_prog_at(int idx)
{
	int map_fd = -1;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	__hid_bpf_do_release_prog(map_fd, idx);

	close_fd(map_fd);
}

/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
	int i, index = -1, map_fd = -1, err = -EINVAL;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);

	if (map_fd < 0) {
		err = -EINVAL;
		goto out;
	}

	/* find the first available index in the jmp_table */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (!jmp_table.progs[i] && index < 0) {
			/* mark the index as used */
			jmp_table.progs[i] = prog;
			index = i;
			__set_bit(i, jmp_table.enabled);
		}
	}
	if (index < 0) {
		err = -ENOMEM;
		goto out;
	}

	/* insert the program in the jump table */
	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
	if (err)
		goto out;

	/*
	 * The program has been safely inserted, decrement the reference count
	 * so it doesn't interfere with the number of actual user handles.
	 * This is safe to do because:
	 * - we overwrite the put_ptr in the prog fd map
	 * - we also have a cleanup function that monitors when a program gets
	 *   released and we manually do the cleanup in the prog fd map
	 */
	bpf_prog_sub(prog, 1);

	/* return the index */
	err = index;

 out:
	if (err < 0)
		__hid_bpf_do_release_prog(map_fd, index);
	if (map_fd >= 0)
		close_fd(map_fd);
	return err;
}

/*
 * Return the HID-BPF program type of the program behind @prog_fd by matching
 * its attach BTF ID, HID_BPF_PROG_TYPE_UNDEF if it is not a known HID-BPF
 * entrypoint, or a negative error if @prog_fd is not a valid BPF program fd.
 */
int hid_bpf_get_prog_attach_type(int prog_fd)
{
	struct bpf_prog *prog = NULL;
	int i;
	int prog_type = HID_BPF_PROG_TYPE_UNDEF;

	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
			prog_type = i;
			break;
		}
	}

	bpf_prog_put(prog);

	return prog_type;
}

/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
		      int prog_fd, __u32 flags)
{
	struct bpf_prog *prog = NULL;
	struct hid_bpf_prog_entry *prog_entry;
	int cnt, err = -EINVAL, prog_idx = -1;

	/* take a ref on the prog itself */
	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	mutex_lock(&hid_bpf_attach_lock);

	/* do not attach too many programs to a given HID device */
	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
	if (cnt < 0) {
		err = cnt;
		goto out_unlock;
	}

	if (cnt >= hid_bpf_max_programs(prog_type)) {
		err = -E2BIG;
		goto out_unlock;
	}

	prog_idx = hid_bpf_insert_prog(prog_fd, prog);
	/* if the jmp table is full, abort */
	if (prog_idx < 0) {
		err = prog_idx;
		goto out_unlock;
	}

	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
		/* take the previous prog_entry slot */
		jmp_table.tail = PREV(jmp_table.tail);
		prog_entry = &jmp_table.entries[jmp_table.tail];
	} else {
		/* take the next prog_entry slot */
		prog_entry = &jmp_table.entries[jmp_table.head];
		jmp_table.head = NEXT(jmp_table.head);
	}

	/* we steal the ref here */
	prog_entry->prog = prog;
	prog_entry->idx = prog_idx;
	prog_entry->hdev = hdev;
	prog_entry->type = prog_type;

	/* finally store the index in the device list */
	err = hid_bpf_populate_hdev(hdev, prog_type);
	if (err)
		hid_bpf_release_prog_at(prog_idx);

 out_unlock:
	mutex_unlock(&hid_bpf_attach_lock);

	/* we only use prog as a key in the various tables, so we don't need to actually
	 * increment the ref count.
	 */
	bpf_prog_put(prog);

	return err;
}

void __hid_bpf_destroy_device(struct hid_device *hdev)
{
	int type, i;
	struct hid_bpf_prog_list *prog_list;

	rcu_read_lock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
		prog_list = rcu_dereference(hdev->bpf.progs[type]);

		if (!prog_list)
			continue;

		for (i = 0; i < prog_list->prog_cnt; i++)
			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
	}

	rcu_read_unlock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
		__hid_bpf_set_hdev_progs(hdev, NULL, type);

	/* schedule release of all detached progs */
	schedule_work(&release_work);
}

void call_hid_bpf_prog_put_deferred(struct work_struct *work)
{
	struct bpf_prog_aux *aux;
	struct bpf_prog *prog;
	bool found = false;
	int i;

	aux = container_of(work, struct bpf_prog_aux, work);
	prog = aux->prog;

	/* we don't need locking here because the entries in the progs table
	 * are stable:
	 * if there are other users (and the progs entries might change), we
	 * would simply not have been called.
	 */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (jmp_table.progs[i] == prog) {
			__clear_bit(i, jmp_table.enabled);
			found = true;
		}
	}

	if (found)
		/* schedule release of all detached progs */
		schedule_work(&release_work);
}
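
/*
 * How this is expected to be reached (a sketch, based on the entrypoints
 * skeleton attached below): the hid_bpf_prog_put_deferred BPF program set up
 * in hid_bpf_preload_skel() hooks the deferred release of a bpf_prog (the
 * work item embedded in struct bpf_prog_aux, as the container_of() above
 * shows) and calls back into this function, so the jump table can disable
 * the slots of a program whose last user reference went away and schedule
 * their cleanup.
 */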

/*
 * The jump table steals the references of the programs it stores (see
 * hid_bpf_insert_prog() above), so removing an element from the prog_array
 * must not drop an extra reference: this put_ptr is intentionally a no-op.
 */
static void hid_bpf_prog_fd_array_put_ptr(void *ptr)
{
}

#define HID_BPF_PROGS_COUNT 2

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;

void hid_bpf_free_links_and_skel(void)
{
	int i;

	/* the following is enough to release all programs attached to hid */
	if (jmp_table.map)
		bpf_map_put_with_uref(jmp_table.map);

	for (i = 0; i < ARRAY_SIZE(links); i++) {
		if (!IS_ERR_OR_NULL(links[i]))
			bpf_link_put(links[i]);
	}
	entrypoints_bpf__destroy(skel);
}

#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)

static struct bpf_map_ops hid_bpf_prog_fd_maps_ops;

int hid_bpf_preload_skel(void)
{
	int err, idx = 0;

	skel = entrypoints_bpf__open();
	if (!skel)
		return -ENOMEM;

	err = entrypoints_bpf__load(skel);
	if (err)
		goto out;

	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
	if (IS_ERR(jmp_table.map)) {
		err = PTR_ERR(jmp_table.map);
		goto out;
	}

	/* our jump table is stealing refs, so we should not decrement on removal of elements */
	hid_bpf_prog_fd_maps_ops = *jmp_table.map->ops;
	hid_bpf_prog_fd_maps_ops.map_fd_put_ptr = hid_bpf_prog_fd_array_put_ptr;

	jmp_table.map->ops = &hid_bpf_prog_fd_maps_ops;

	ATTACH_AND_STORE_LINK(hid_tail_call);
	ATTACH_AND_STORE_LINK(hid_bpf_prog_put_deferred);

	return 0;
out:
	hid_bpf_free_links_and_skel();
	return err;
}
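
/*
 * For context beyond this file (a hedged sketch of typical usage in this
 * revision): HID-BPF programs are tracing programs attached to the
 * entrypoints listed in hid_bpf_btf_ids above (e.g.
 * SEC("fmod_ret/hid_bpf_device_event")), and user space attaches them
 * through the hid_bpf_attach_prog() kfunc, which eventually calls
 * __hid_bpf_attach_prog() with the program fd and flags such as
 * HID_BPF_FLAG_INSERT_HEAD.
 */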
566