xref: /openbmc/linux/drivers/hid/bpf/hid_bpf_jmp_table.c (revision 658ee5a64fcfbbf758447fa3af425729eaabb0dc)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 /*
4  *  HID-BPF support for Linux
5  *
6  *  Copyright (c) 2022 Benjamin Tissoires
7  */
8 
9 #include <linux/bitops.h>
10 #include <linux/btf.h>
11 #include <linux/btf_ids.h>
12 #include <linux/circ_buf.h>
13 #include <linux/filter.h>
14 #include <linux/hid.h>
15 #include <linux/hid_bpf.h>
16 #include <linux/init.h>
17 #include <linux/module.h>
18 #include <linux/workqueue.h>
19 #include "hid_bpf_dispatch.h"
20 #include "entrypoints/entrypoints.lskel.h"
21 
22 #define HID_BPF_MAX_PROGS 1024 /* keep this in sync with the preloaded bpf,
23 				* must be a power of 2 as it is used as the
24 				* size of a circular buffer
25 				*/
26 
27 #define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
28 #define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
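
/*
 * With HID_BPF_MAX_PROGS == 1024 the masks above make the entries[]
 * array behave as a circular buffer:
 *
 *	NEXT(1023) == (1023 + 1) & 1023 == 0
 *	PREV(0)    == (   0 - 1) & 1023 == 1023
 *
 * which is why HID_BPF_MAX_PROGS has to be a power of two.
 */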
29 
30 /*
31  * represents one attached program stored in the hid jump table
32  */
33 struct hid_bpf_prog_entry {
34 	struct bpf_prog *prog;
35 	struct hid_device *hdev;
36 	enum hid_bpf_prog_type type;
37 	u16 idx;
38 };
39 
40 struct hid_bpf_jmp_table {
41 	struct bpf_map *map;
42 	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
43 	int tail, head;
44 	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
45 	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
46 };
47 
48 #define FOR_ENTRIES(__i, __start, __end) \
49 	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))
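
/*
 * FOR_ENTRIES() walks the compacted circular list from __start up to,
 * but not including, __end (it does nothing when __start == __end).
 * For example, with tail == 1022 and head == 2 it visits the indices
 * 1022, 1023, 0 and 1: CIRC_CNT(2, 1022, 1024) starts at 4 and drops
 * to 0 once __i wraps around to 2.
 */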
50 
51 static struct hid_bpf_jmp_table jmp_table;
52 
53 static DEFINE_MUTEX(hid_bpf_attach_lock);		/* held when attaching/detaching programs */
54 
55 static void hid_bpf_release_progs(struct work_struct *work);
56 
57 static DECLARE_WORK(release_work, hid_bpf_release_progs);
58 
59 BTF_ID_LIST(hid_bpf_btf_ids)
60 BTF_ID(func, hid_bpf_device_event)			/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
61 
62 static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
63 {
64 	switch (type) {
65 	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
66 		return HID_BPF_MAX_PROGS_PER_DEV;
67 	default:
68 		return -EINVAL;
69 	}
70 }
71 
72 static int hid_bpf_program_count(struct hid_device *hdev,
73 				 struct bpf_prog *prog,
74 				 enum hid_bpf_prog_type type)
75 {
76 	int i, n = 0;
77 
78 	if (type >= HID_BPF_PROG_TYPE_MAX)
79 		return -EINVAL;
80 
81 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
82 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
83 
84 		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
85 			continue;
86 
87 		if (hdev && entry->hdev != hdev)
88 			continue;
89 
90 		if (prog && entry->prog != prog)
91 			continue;
92 
93 		n++;
94 	}
95 
96 	return n;
97 }
98 
99 __weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
100 {
101 	return 0;
102 }
103 ALLOW_ERROR_INJECTION(__hid_bpf_tail_call, ERRNO);
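
/*
 * __hid_bpf_tail_call() is only a hook point: the preloaded skeleton
 * (entrypoints/entrypoints.lskel.h) attaches an fmod_ret program to it
 * that performs the actual tail call into the hid_jmp_table prog array,
 * keyed by ctx->index.  Roughly (a sketch, the real source lives in
 * entrypoints/entrypoints.bpf.c):
 *
 *	SEC("fmod_ret/__hid_bpf_tail_call")
 *	int BPF_PROG(hid_tail_call, struct hid_bpf_ctx *hctx)
 *	{
 *		bpf_tail_call(ctx, &hid_jmp_table, hctx->index);
 *		return 0;
 *	}
 *
 * which is why hid_bpf_prog_run() below sets ctx_kern->ctx.index right
 * before calling it.
 */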
104 
105 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
106 		     struct hid_bpf_ctx_kern *ctx_kern)
107 {
108 	struct hid_bpf_prog_list *prog_list;
109 	int i, idx, err = 0;
110 
111 	rcu_read_lock();
112 	prog_list = rcu_dereference(hdev->bpf.progs[type]);
113 
114 	if (!prog_list)
115 		goto out_unlock;
116 
117 	for (i = 0; i < prog_list->prog_cnt; i++) {
118 		idx = prog_list->prog_idx[i];
119 
120 		if (!test_bit(idx, jmp_table.enabled))
121 			continue;
122 
123 		ctx_kern->ctx.index = idx;
124 		err = __hid_bpf_tail_call(&ctx_kern->ctx);
125 		if (err < 0)
126 			break;
127 		if (err)
128 			ctx_kern->ctx.retval = err;
129 	}
130 
131  out_unlock:
132 	rcu_read_unlock();
133 
134 	return err;
135 }
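
/*
 * The programs dispatched above are regular fmod_ret tracing programs
 * attached to hid_bpf_device_event().  A minimal sketch of one (the
 * program name is made up; hid_bpf_get_data() is assumed to be the
 * data-access kfunc exported by hid_bpf_dispatch.c):
 *
 *	SEC("fmod_ret/hid_bpf_device_event")
 *	int BPF_PROG(filter_event, struct hid_bpf_ctx *hctx)
 *	{
 *		__u8 *data = hid_bpf_get_data(hctx, 0, 8);	// offset 0, 8 bytes
 *
 *		if (!data)
 *			return 0;
 *
 *		data[1] = 0;	// e.g. zero one byte of the report
 *		return 0;
 *	}
 *
 * As the loop above shows, a positive return value is stored in
 * ctx.retval and a negative one aborts the remaining programs in the
 * list.
 */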
136 
137 /*
138  * assign the list of programs attached to a given hid device.
139  */
140 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
141 				     enum hid_bpf_prog_type type)
142 {
143 	struct hid_bpf_prog_list *old_list;
144 
145 	spin_lock(&hdev->bpf.progs_lock);
146 	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
147 					     lockdep_is_held(&hdev->bpf.progs_lock));
148 	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
149 	spin_unlock(&hdev->bpf.progs_lock);
150 	synchronize_rcu();
151 
152 	kfree(old_list);
153 }
154 
155 /*
156  * allocate and populate the list of programs attached to a given hid device.
157  *
158  * Must be called with hid_bpf_attach_lock held.
159  */
160 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
161 {
162 	struct hid_bpf_prog_list *new_list;
163 	int i;
164 
165 	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
166 		return -EINVAL;
167 
168 	if (hdev->bpf.destroyed)
169 		return 0;
170 
171 	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
172 	if (!new_list)
173 		return -ENOMEM;
174 
175 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
176 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
177 
178 		if (entry->type == type && entry->hdev == hdev &&
179 		    test_bit(entry->idx, jmp_table.enabled))
180 			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
181 	}
182 
183 	__hid_bpf_set_hdev_progs(hdev, new_list, type);
184 
185 	return 0;
186 }
187 
188 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
189 {
190 	skel_map_delete_elem(map_fd, &idx);
191 	jmp_table.progs[idx] = NULL;
192 }
193 
194 static void hid_bpf_release_progs(struct work_struct *work)
195 {
196 	int i, j, n, map_fd = -1;
197 
198 	if (!jmp_table.map)
199 		return;
200 
201 	/* retrieve a fd of our prog_array map in BPF */
202 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
203 	if (map_fd < 0)
204 		return;
205 
206 	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */
207 
208 	/* detach unused progs from HID devices */
209 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
210 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
211 		enum hid_bpf_prog_type type;
212 		struct hid_device *hdev;
213 
214 		if (test_bit(entry->idx, jmp_table.enabled))
215 			continue;
216 
217 		/* we have an attached prog */
218 		if (entry->hdev) {
219 			hdev = entry->hdev;
220 			type = entry->type;
221 
222 			hid_bpf_populate_hdev(hdev, type);
223 
224 			/* mark all other disabled progs from hdev of the given type as detached */
225 			FOR_ENTRIES(j, i, jmp_table.head) {
226 				struct hid_bpf_prog_entry *next;
227 
228 				next = &jmp_table.entries[j];
229 
230 				if (test_bit(next->idx, jmp_table.enabled))
231 					continue;
232 
233 				if (next->hdev == hdev && next->type == type)
234 					next->hdev = NULL;
235 			}
236 		}
237 	}
238 
239 	/* remove all unused progs from the jump table */
240 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
241 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
242 
243 		if (test_bit(entry->idx, jmp_table.enabled))
244 			continue;
245 
246 		if (entry->prog)
247 			__hid_bpf_do_release_prog(map_fd, entry->idx);
248 	}
249 
250 	/* compact the entry list */
251 	n = jmp_table.tail;
252 	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
253 		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
254 
255 		if (!test_bit(entry->idx, jmp_table.enabled))
256 			continue;
257 
258 		jmp_table.entries[n] = jmp_table.entries[i];
259 		n = NEXT(n);
260 	}
261 
262 	jmp_table.head = n;
263 
264 	mutex_unlock(&hid_bpf_attach_lock);
265 
266 	if (map_fd >= 0)
267 		close_fd(map_fd);
268 }
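
/*
 * Worked example of the compaction pass in hid_bpf_release_progs():
 * with tail == 0, head == 4 and entries referring to jump-table indices
 * {3, 7, 1, 9}, where index 7 has been disabled, the surviving entries
 * are copied forward so that entries[0..2] refer to {3, 1, 9} and head
 * becomes 3.  The enabled bitmap and the progs[] array are not touched
 * by that last pass; only the compacted list shrinks.
 */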
269 
270 static void hid_bpf_release_prog_at(int idx)
271 {
272 	int map_fd = -1;
273 
274 	/* retrieve a fd of our prog_array map in BPF */
275 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
276 	if (map_fd < 0)
277 		return;
278 
279 	__hid_bpf_do_release_prog(map_fd, idx);
280 
281 	close_fd(map_fd);
282 }
283 
284 /*
285  * Insert the given BPF program represented by its fd in the jmp table.
286  * Returns the index in the jump table or a negative error.
287  */
288 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
289 {
290 	int i, index = -1, map_fd = -1, err = -EINVAL;
291 
292 	/* retrieve a fd of our prog_array map in BPF */
293 	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
294 
295 	if (map_fd < 0) {
296 		err = -EINVAL;
297 		goto out;
298 	}
299 
300 	/* find the first available index in the jmp_table */
301 	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
302 		if (!jmp_table.progs[i] && index < 0) {
303 			/* mark the index as used */
304 			jmp_table.progs[i] = prog;
305 			index = i;
306 			__set_bit(i, jmp_table.enabled);
307 		}
308 	}
309 	if (index < 0) {
310 		err = -ENOMEM;
311 		goto out;
312 	}
313 
314 	/* insert the program in the jump table */
315 	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
316 	if (err)
317 		goto out;
318 
319 	/*
320 	 * The program has been safely inserted, decrement the reference count
321 	 * so it doesn't interfere with the number of actual user handles.
322 	 * This is safe to do because:
323 	 * - we overwrite the put_ptr in the prog fd map
324 	 * - we also have a cleanup function that monitors when a program gets
325 	 *   released and we manually do the cleanup in the prog fd map
326 	 */
327 	bpf_prog_sub(prog, 1);
328 
329 	/* return the index */
330 	err = index;
331 
332  out:
333 	if (err < 0 && index >= 0)
334 		__hid_bpf_do_release_prog(map_fd, index);
335 	if (map_fd >= 0)
336 		close_fd(map_fd);
337 	return err;
338 }
339 
340 int hid_bpf_get_prog_attach_type(int prog_fd)
341 {
342 	struct bpf_prog *prog = NULL;
343 	int i;
344 	int prog_type = HID_BPF_PROG_TYPE_UNDEF;
345 
346 	prog = bpf_prog_get(prog_fd);
347 	if (IS_ERR(prog))
348 		return PTR_ERR(prog);
349 
350 	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
351 		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
352 			prog_type = i;
353 			break;
354 		}
355 	}
356 
357 	bpf_prog_put(prog);
358 
359 	return prog_type;
360 }
361 
362 /* called from syscall */
363 noinline int
364 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
365 		      int prog_fd, __u32 flags)
366 {
367 	struct bpf_prog *prog = NULL;
368 	struct hid_bpf_prog_entry *prog_entry;
369 	int cnt, err = -EINVAL, prog_idx = -1;
370 
371 	/* take a ref on the prog itself */
372 	prog = bpf_prog_get(prog_fd);
373 	if (IS_ERR(prog))
374 		return PTR_ERR(prog);
375 
376 	mutex_lock(&hid_bpf_attach_lock);
377 
378 	/* do not attach too many programs to a given HID device */
379 	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
380 	if (cnt < 0) {
381 		err = cnt;
382 		goto out_unlock;
383 	}
384 
385 	if (cnt >= hid_bpf_max_programs(prog_type)) {
386 		err = -E2BIG;
387 		goto out_unlock;
388 	}
389 
390 	prog_idx = hid_bpf_insert_prog(prog_fd, prog);
391 	/* if the jmp table is full, abort */
392 	if (prog_idx < 0) {
393 		err = prog_idx;
394 		goto out_unlock;
395 	}
396 
397 	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
398 		/* take the previous prog_entry slot */
399 		jmp_table.tail = PREV(jmp_table.tail);
400 		prog_entry = &jmp_table.entries[jmp_table.tail];
401 	} else {
402 		/* take the next prog_entry slot */
403 		prog_entry = &jmp_table.entries[jmp_table.head];
404 		jmp_table.head = NEXT(jmp_table.head);
405 	}
406 
407 	/* we steal the ref here */
408 	prog_entry->prog = prog;
409 	prog_entry->idx = prog_idx;
410 	prog_entry->hdev = hdev;
411 	prog_entry->type = prog_type;
412 
413 	/* finally store the index in the device list */
414 	err = hid_bpf_populate_hdev(hdev, prog_type);
415 	if (err)
416 		hid_bpf_release_prog_at(prog_idx);
417 
418  out_unlock:
419 	mutex_unlock(&hid_bpf_attach_lock);
420 
421 	/* we only use prog as a key in the various tables, so we don't need to actually
422 	 * increment the ref count.
423 	 */
424 	bpf_prog_put(prog);
425 
426 	return err;
427 }
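
/*
 * __hid_bpf_attach_prog() is not reachable through a regular syscall:
 * userspace runs a SEC("syscall") BPF program (via BPF_PROG_RUN) that
 * calls the hid_bpf_attach_prog() kfunc on the dispatch side, which in
 * turn lands here.  A minimal sketch of such a loader program (struct
 * and program names are illustrative, in the spirit of the samples/hid
 * code):
 *
 *	struct attach_prog_args {
 *		int prog_fd;
 *		unsigned int hid;
 *		int retval;
 *	};
 *
 *	SEC("syscall")
 *	int attach_prog(struct attach_prog_args *ctx)
 *	{
 *		ctx->retval = hid_bpf_attach_prog(ctx->hid, ctx->prog_fd, 0);
 *		return 0;
 *	}
 */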
428 
429 void __hid_bpf_destroy_device(struct hid_device *hdev)
430 {
431 	int type, i;
432 	struct hid_bpf_prog_list *prog_list;
433 
434 	rcu_read_lock();
435 
436 	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
437 		prog_list = rcu_dereference(hdev->bpf.progs[type]);
438 
439 		if (!prog_list)
440 			continue;
441 
442 		for (i = 0; i < prog_list->prog_cnt; i++)
443 			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
444 	}
445 
446 	rcu_read_unlock();
447 
448 	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
449 		__hid_bpf_set_hdev_progs(hdev, NULL, type);
450 
451 	/* schedule release of all detached progs */
452 	schedule_work(&release_work);
453 }
454 
455 void call_hid_bpf_prog_put_deferred(struct work_struct *work)
456 {
457 	struct bpf_prog_aux *aux;
458 	struct bpf_prog *prog;
459 	bool found = false;
460 	int i;
461 
462 	aux = container_of(work, struct bpf_prog_aux, work);
463 	prog = aux->prog;
464 
465 	/* we don't need locking here: this callback only runs once the last
466 	 * reference to the program is gone, so no other user can still be
467 	 * changing the progs table entries that point to this prog. If
468 	 * there were other users, we would simply not have been called.
469 	 */
470 	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
471 		if (jmp_table.progs[i] == prog) {
472 			__clear_bit(i, jmp_table.enabled);
473 			found = true;
474 		}
475 	}
476 
477 	if (found)
478 		/* schedule release of all detached progs */
479 		schedule_work(&release_work);
480 }
481 
/*
 * Intentionally a no-op: the jump table steals the program references
 * (see hid_bpf_insert_prog()), so removing an element from the prog fd
 * array must not drop a reference the map never really owned.
 */
482 static void hid_bpf_prog_fd_array_put_ptr(void *ptr)
483 {
484 }
485 
486 #define HID_BPF_PROGS_COUNT 2
487 
488 static struct bpf_link *links[HID_BPF_PROGS_COUNT];
489 static struct entrypoints_bpf *skel;
490 
491 void hid_bpf_free_links_and_skel(void)
492 {
493 	int i;
494 
495 	/* the following is enough to release all programs attached to hid */
496 	if (jmp_table.map)
497 		bpf_map_put_with_uref(jmp_table.map);
498 
499 	for (i = 0; i < ARRAY_SIZE(links); i++) {
500 		if (!IS_ERR_OR_NULL(links[i]))
501 			bpf_link_put(links[i]);
502 	}
503 	entrypoints_bpf__destroy(skel);
504 }
505 
506 #define ATTACH_AND_STORE_LINK(__name) do {					\
507 	err = entrypoints_bpf__##__name##__attach(skel);			\
508 	if (err)								\
509 		goto out;							\
510 										\
511 	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
512 	if (IS_ERR(links[idx])) {						\
513 		err = PTR_ERR(links[idx]);					\
514 		goto out;							\
515 	}									\
516 										\
517 	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
518 	 * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().	\
519 	 */									\
520 	close_fd(skel->links.__name##_fd);					\
521 	skel->links.__name##_fd = 0;						\
522 	idx++;									\
523 } while (0)
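
/*
 * For instance, ATTACH_AND_STORE_LINK(hid_tail_call) expands (roughly) to:
 *
 *	err = entrypoints_bpf__hid_tail_call__attach(skel);
 *	if (err)
 *		goto out;
 *	links[idx] = bpf_link_get_from_fd(skel->links.hid_tail_call_fd);
 *	...
 *	close_fd(skel->links.hid_tail_call_fd);
 *	skel->links.hid_tail_call_fd = 0;
 *	idx++;
 *
 * so each attached entrypoint stays pinned by a struct bpf_link
 * reference in links[] instead of by an open file descriptor.
 */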
524 
525 static struct bpf_map_ops hid_bpf_prog_fd_maps_ops;
526 
527 int hid_bpf_preload_skel(void)
528 {
529 	int err, idx = 0;
530 
531 	skel = entrypoints_bpf__open();
532 	if (!skel)
533 		return -ENOMEM;
534 
535 	err = entrypoints_bpf__load(skel);
536 	if (err)
537 		goto out;
538 
539 	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
540 	if (IS_ERR(jmp_table.map)) {
541 		err = PTR_ERR(jmp_table.map);
542 		goto out;
543 	}
544 
545 	/* our jump table is stealing refs, so we should not decrement on removal of elements */
546 	hid_bpf_prog_fd_maps_ops = *jmp_table.map->ops;
547 	hid_bpf_prog_fd_maps_ops.map_fd_put_ptr = hid_bpf_prog_fd_array_put_ptr;
548 
549 	jmp_table.map->ops = &hid_bpf_prog_fd_maps_ops;
550 
551 	ATTACH_AND_STORE_LINK(hid_tail_call);
552 	ATTACH_AND_STORE_LINK(hid_bpf_prog_put_deferred);
553 
554 	return 0;
555 out:
556 	hid_bpf_free_links_and_skel();
557 	return err;
558 }
559