1 // SPDX-License-Identifier: GPL-2.0-only 2 3 /* 4 * HID-BPF support for Linux 5 * 6 * Copyright (c) 2022 Benjamin Tissoires 7 */ 8 9 #include <linux/bitops.h> 10 #include <linux/btf.h> 11 #include <linux/btf_ids.h> 12 #include <linux/circ_buf.h> 13 #include <linux/filter.h> 14 #include <linux/hid.h> 15 #include <linux/hid_bpf.h> 16 #include <linux/init.h> 17 #include <linux/module.h> 18 #include <linux/workqueue.h> 19 #include "hid_bpf_dispatch.h" 20 #include "entrypoints/entrypoints.lskel.h" 21 22 #define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf, 23 * needs to be a power of 2 as we use it as 24 * a circular buffer 25 */ 26 27 #define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1)) 28 #define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1)) 29 30 /* 31 * represents one attached program stored in the hid jump table 32 */ 33 struct hid_bpf_prog_entry { 34 struct bpf_prog *prog; 35 struct hid_device *hdev; 36 enum hid_bpf_prog_type type; 37 u16 idx; 38 }; 39 40 struct hid_bpf_jmp_table { 41 struct bpf_map *map; 42 struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */ 43 int tail, head; 44 struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */ 45 unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)]; 46 }; 47 48 #define FOR_ENTRIES(__i, __start, __end) \ 49 for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i)) 50 51 static struct hid_bpf_jmp_table jmp_table; 52 53 static DEFINE_MUTEX(hid_bpf_attach_lock); /* held when attaching/detaching programs */ 54 55 static void hid_bpf_release_progs(struct work_struct *work); 56 57 static DECLARE_WORK(release_work, hid_bpf_release_progs); 58 59 BTF_ID_LIST(hid_bpf_btf_ids) 60 BTF_ID(func, hid_bpf_device_event) /* HID_BPF_PROG_TYPE_DEVICE_EVENT */ 61 BTF_ID(func, hid_bpf_rdesc_fixup) /* HID_BPF_PROG_TYPE_RDESC_FIXUP */ 62 63 static int hid_bpf_max_programs(enum hid_bpf_prog_type type) 64 { 65 switch (type) { 66 case HID_BPF_PROG_TYPE_DEVICE_EVENT: 67 return HID_BPF_MAX_PROGS_PER_DEV; 68 case HID_BPF_PROG_TYPE_RDESC_FIXUP: 69 return 1; 70 default: 71 return -EINVAL; 72 } 73 } 74 75 static int hid_bpf_program_count(struct hid_device *hdev, 76 struct bpf_prog *prog, 77 enum hid_bpf_prog_type type) 78 { 79 int i, n = 0; 80 81 if (type >= HID_BPF_PROG_TYPE_MAX) 82 return -EINVAL; 83 84 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 85 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 86 87 if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type) 88 continue; 89 90 if (hdev && entry->hdev != hdev) 91 continue; 92 93 if (prog && entry->prog != prog) 94 continue; 95 96 n++; 97 } 98 99 return n; 100 } 101 102 __weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx) 103 { 104 return 0; 105 } 106 ALLOW_ERROR_INJECTION(__hid_bpf_tail_call, ERRNO); 107 108 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, 109 struct hid_bpf_ctx_kern *ctx_kern) 110 { 111 struct hid_bpf_prog_list *prog_list; 112 int i, idx, err = 0; 113 114 rcu_read_lock(); 115 prog_list = rcu_dereference(hdev->bpf.progs[type]); 116 117 if (!prog_list) 118 goto out_unlock; 119 120 for (i = 0; i < prog_list->prog_cnt; i++) { 121 idx = prog_list->prog_idx[i]; 122 123 if (!test_bit(idx, jmp_table.enabled)) 124 continue; 125 126 ctx_kern->ctx.index = idx; 127 err = __hid_bpf_tail_call(&ctx_kern->ctx); 128 if (err < 0) 129 break; 130 if (err) 131 ctx_kern->ctx.retval = err; 132 } 133 134 out_unlock: 135 rcu_read_unlock(); 136 137 return err; 138 } 139 140 /* 141 * assign the list of programs attached to a given hid device. 142 */ 143 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list, 144 enum hid_bpf_prog_type type) 145 { 146 struct hid_bpf_prog_list *old_list; 147 148 spin_lock(&hdev->bpf.progs_lock); 149 old_list = rcu_dereference_protected(hdev->bpf.progs[type], 150 lockdep_is_held(&hdev->bpf.progs_lock)); 151 rcu_assign_pointer(hdev->bpf.progs[type], new_list); 152 spin_unlock(&hdev->bpf.progs_lock); 153 synchronize_rcu(); 154 155 kfree(old_list); 156 } 157 158 /* 159 * allocate and populate the list of programs attached to a given hid device. 160 * 161 * Must be called under lock. 162 */ 163 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type) 164 { 165 struct hid_bpf_prog_list *new_list; 166 int i; 167 168 if (type >= HID_BPF_PROG_TYPE_MAX || !hdev) 169 return -EINVAL; 170 171 if (hdev->bpf.destroyed) 172 return 0; 173 174 new_list = kzalloc(sizeof(*new_list), GFP_KERNEL); 175 if (!new_list) 176 return -ENOMEM; 177 178 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 179 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 180 181 if (entry->type == type && entry->hdev == hdev && 182 test_bit(entry->idx, jmp_table.enabled)) 183 new_list->prog_idx[new_list->prog_cnt++] = entry->idx; 184 } 185 186 __hid_bpf_set_hdev_progs(hdev, new_list, type); 187 188 return 0; 189 } 190 191 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx) 192 { 193 skel_map_delete_elem(map_fd, &idx); 194 jmp_table.progs[idx] = NULL; 195 } 196 197 static void hid_bpf_release_progs(struct work_struct *work) 198 { 199 int i, j, n, map_fd = -1; 200 201 if (!jmp_table.map) 202 return; 203 204 /* retrieve a fd of our prog_array map in BPF */ 205 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 206 if (map_fd < 0) 207 return; 208 209 mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */ 210 211 /* detach unused progs from HID devices */ 212 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 213 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 214 enum hid_bpf_prog_type type; 215 struct hid_device *hdev; 216 217 if (test_bit(entry->idx, jmp_table.enabled)) 218 continue; 219 220 /* we have an attached prog */ 221 if (entry->hdev) { 222 hdev = entry->hdev; 223 type = entry->type; 224 225 hid_bpf_populate_hdev(hdev, type); 226 227 /* mark all other disabled progs from hdev of the given type as detached */ 228 FOR_ENTRIES(j, i, jmp_table.head) { 229 struct hid_bpf_prog_entry *next; 230 231 next = &jmp_table.entries[j]; 232 233 if (test_bit(next->idx, jmp_table.enabled)) 234 continue; 235 236 if (next->hdev == hdev && next->type == type) 237 next->hdev = NULL; 238 } 239 240 /* if type was rdesc fixup, reconnect device */ 241 if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP) 242 hid_bpf_reconnect(hdev); 243 } 244 } 245 246 /* remove all unused progs from the jump table */ 247 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 248 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 249 250 if (test_bit(entry->idx, jmp_table.enabled)) 251 continue; 252 253 if (entry->prog) 254 __hid_bpf_do_release_prog(map_fd, entry->idx); 255 } 256 257 /* compact the entry list */ 258 n = jmp_table.tail; 259 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 260 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 261 262 if (!test_bit(entry->idx, jmp_table.enabled)) 263 continue; 264 265 jmp_table.entries[n] = jmp_table.entries[i]; 266 n = NEXT(n); 267 } 268 269 jmp_table.head = n; 270 271 mutex_unlock(&hid_bpf_attach_lock); 272 273 if (map_fd >= 0) 274 close_fd(map_fd); 275 } 276 277 static void hid_bpf_release_prog_at(int idx) 278 { 279 int map_fd = -1; 280 281 /* retrieve a fd of our prog_array map in BPF */ 282 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 283 if (map_fd < 0) 284 return; 285 286 __hid_bpf_do_release_prog(map_fd, idx); 287 288 close(map_fd); 289 } 290 291 /* 292 * Insert the given BPF program represented by its fd in the jmp table. 293 * Returns the index in the jump table or a negative error. 294 */ 295 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog) 296 { 297 int i, index = -1, map_fd = -1, err = -EINVAL; 298 299 /* retrieve a fd of our prog_array map in BPF */ 300 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 301 302 if (map_fd < 0) { 303 err = -EINVAL; 304 goto out; 305 } 306 307 /* find the first available index in the jmp_table */ 308 for (i = 0; i < HID_BPF_MAX_PROGS; i++) { 309 if (!jmp_table.progs[i] && index < 0) { 310 /* mark the index as used */ 311 jmp_table.progs[i] = prog; 312 index = i; 313 __set_bit(i, jmp_table.enabled); 314 } 315 } 316 if (index < 0) { 317 err = -ENOMEM; 318 goto out; 319 } 320 321 /* insert the program in the jump table */ 322 err = skel_map_update_elem(map_fd, &index, &prog_fd, 0); 323 if (err) 324 goto out; 325 326 /* 327 * The program has been safely inserted, decrement the reference count 328 * so it doesn't interfere with the number of actual user handles. 329 * This is safe to do because: 330 * - we overrite the put_ptr in the prog fd map 331 * - we also have a cleanup function that monitors when a program gets 332 * released and we manually do the cleanup in the prog fd map 333 */ 334 bpf_prog_sub(prog, 1); 335 336 /* return the index */ 337 err = index; 338 339 out: 340 if (err < 0) 341 __hid_bpf_do_release_prog(map_fd, index); 342 if (map_fd >= 0) 343 close_fd(map_fd); 344 return err; 345 } 346 347 int hid_bpf_get_prog_attach_type(int prog_fd) 348 { 349 struct bpf_prog *prog = NULL; 350 int i; 351 int prog_type = HID_BPF_PROG_TYPE_UNDEF; 352 353 prog = bpf_prog_get(prog_fd); 354 if (IS_ERR(prog)) 355 return PTR_ERR(prog); 356 357 for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) { 358 if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) { 359 prog_type = i; 360 break; 361 } 362 } 363 364 bpf_prog_put(prog); 365 366 return prog_type; 367 } 368 369 /* called from syscall */ 370 noinline int 371 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, 372 int prog_fd, __u32 flags) 373 { 374 struct bpf_prog *prog = NULL; 375 struct hid_bpf_prog_entry *prog_entry; 376 int cnt, err = -EINVAL, prog_idx = -1; 377 378 /* take a ref on the prog itself */ 379 prog = bpf_prog_get(prog_fd); 380 if (IS_ERR(prog)) 381 return PTR_ERR(prog); 382 383 mutex_lock(&hid_bpf_attach_lock); 384 385 /* do not attach too many programs to a given HID device */ 386 cnt = hid_bpf_program_count(hdev, NULL, prog_type); 387 if (cnt < 0) { 388 err = cnt; 389 goto out_unlock; 390 } 391 392 if (cnt >= hid_bpf_max_programs(prog_type)) { 393 err = -E2BIG; 394 goto out_unlock; 395 } 396 397 prog_idx = hid_bpf_insert_prog(prog_fd, prog); 398 /* if the jmp table is full, abort */ 399 if (prog_idx < 0) { 400 err = prog_idx; 401 goto out_unlock; 402 } 403 404 if (flags & HID_BPF_FLAG_INSERT_HEAD) { 405 /* take the previous prog_entry slot */ 406 jmp_table.tail = PREV(jmp_table.tail); 407 prog_entry = &jmp_table.entries[jmp_table.tail]; 408 } else { 409 /* take the next prog_entry slot */ 410 prog_entry = &jmp_table.entries[jmp_table.head]; 411 jmp_table.head = NEXT(jmp_table.head); 412 } 413 414 /* we steal the ref here */ 415 prog_entry->prog = prog; 416 prog_entry->idx = prog_idx; 417 prog_entry->hdev = hdev; 418 prog_entry->type = prog_type; 419 420 /* finally store the index in the device list */ 421 err = hid_bpf_populate_hdev(hdev, prog_type); 422 if (err) 423 hid_bpf_release_prog_at(prog_idx); 424 425 out_unlock: 426 mutex_unlock(&hid_bpf_attach_lock); 427 428 /* we only use prog as a key in the various tables, so we don't need to actually 429 * increment the ref count. 430 */ 431 bpf_prog_put(prog); 432 433 return err; 434 } 435 436 void __hid_bpf_destroy_device(struct hid_device *hdev) 437 { 438 int type, i; 439 struct hid_bpf_prog_list *prog_list; 440 441 rcu_read_lock(); 442 443 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) { 444 prog_list = rcu_dereference(hdev->bpf.progs[type]); 445 446 if (!prog_list) 447 continue; 448 449 for (i = 0; i < prog_list->prog_cnt; i++) 450 __clear_bit(prog_list->prog_idx[i], jmp_table.enabled); 451 } 452 453 rcu_read_unlock(); 454 455 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) 456 __hid_bpf_set_hdev_progs(hdev, NULL, type); 457 458 /* schedule release of all detached progs */ 459 schedule_work(&release_work); 460 } 461 462 void call_hid_bpf_prog_put_deferred(struct work_struct *work) 463 { 464 struct bpf_prog_aux *aux; 465 struct bpf_prog *prog; 466 bool found = false; 467 int i; 468 469 aux = container_of(work, struct bpf_prog_aux, work); 470 prog = aux->prog; 471 472 /* we don't need locking here because the entries in the progs table 473 * are stable: 474 * if there are other users (and the progs entries might change), we 475 * would simply not have been called. 476 */ 477 for (i = 0; i < HID_BPF_MAX_PROGS; i++) { 478 if (jmp_table.progs[i] == prog) { 479 __clear_bit(i, jmp_table.enabled); 480 found = true; 481 } 482 } 483 484 if (found) 485 /* schedule release of all detached progs */ 486 schedule_work(&release_work); 487 } 488 489 static void hid_bpf_prog_fd_array_put_ptr(void *ptr) 490 { 491 } 492 493 #define HID_BPF_PROGS_COUNT 2 494 495 static struct bpf_link *links[HID_BPF_PROGS_COUNT]; 496 static struct entrypoints_bpf *skel; 497 498 void hid_bpf_free_links_and_skel(void) 499 { 500 int i; 501 502 /* the following is enough to release all programs attached to hid */ 503 if (jmp_table.map) 504 bpf_map_put_with_uref(jmp_table.map); 505 506 for (i = 0; i < ARRAY_SIZE(links); i++) { 507 if (!IS_ERR_OR_NULL(links[i])) 508 bpf_link_put(links[i]); 509 } 510 entrypoints_bpf__destroy(skel); 511 } 512 513 #define ATTACH_AND_STORE_LINK(__name) do { \ 514 err = entrypoints_bpf__##__name##__attach(skel); \ 515 if (err) \ 516 goto out; \ 517 \ 518 links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd); \ 519 if (IS_ERR(links[idx])) { \ 520 err = PTR_ERR(links[idx]); \ 521 goto out; \ 522 } \ 523 \ 524 /* Avoid taking over stdin/stdout/stderr of init process. Zeroing out \ 525 * makes skel_closenz() a no-op later in iterators_bpf__destroy(). \ 526 */ \ 527 close_fd(skel->links.__name##_fd); \ 528 skel->links.__name##_fd = 0; \ 529 idx++; \ 530 } while (0) 531 532 static struct bpf_map_ops hid_bpf_prog_fd_maps_ops; 533 534 int hid_bpf_preload_skel(void) 535 { 536 int err, idx = 0; 537 538 skel = entrypoints_bpf__open(); 539 if (!skel) 540 return -ENOMEM; 541 542 err = entrypoints_bpf__load(skel); 543 if (err) 544 goto out; 545 546 jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd); 547 if (IS_ERR(jmp_table.map)) { 548 err = PTR_ERR(jmp_table.map); 549 goto out; 550 } 551 552 /* our jump table is stealing refs, so we should not decrement on removal of elements */ 553 hid_bpf_prog_fd_maps_ops = *jmp_table.map->ops; 554 hid_bpf_prog_fd_maps_ops.map_fd_put_ptr = hid_bpf_prog_fd_array_put_ptr; 555 556 jmp_table.map->ops = &hid_bpf_prog_fd_maps_ops; 557 558 ATTACH_AND_STORE_LINK(hid_tail_call); 559 ATTACH_AND_STORE_LINK(hid_bpf_prog_put_deferred); 560 561 return 0; 562 out: 563 hid_bpf_free_links_and_skel(); 564 return err; 565 } 566