1 // SPDX-License-Identifier: GPL-2.0-only 2 3 /* 4 * HID-BPF support for Linux 5 * 6 * Copyright (c) 2022 Benjamin Tissoires 7 */ 8 9 #include <linux/bitops.h> 10 #include <linux/btf.h> 11 #include <linux/btf_ids.h> 12 #include <linux/circ_buf.h> 13 #include <linux/filter.h> 14 #include <linux/hid.h> 15 #include <linux/hid_bpf.h> 16 #include <linux/init.h> 17 #include <linux/module.h> 18 #include <linux/workqueue.h> 19 #include "hid_bpf_dispatch.h" 20 #include "entrypoints/entrypoints.lskel.h" 21 22 #define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf, 23 * needs to be a power of 2 as we use it as 24 * a circular buffer 25 */ 26 27 #define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1)) 28 #define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1)) 29 30 /* 31 * represents one attached program stored in the hid jump table 32 */ 33 struct hid_bpf_prog_entry { 34 struct bpf_prog *prog; 35 struct hid_device *hdev; 36 enum hid_bpf_prog_type type; 37 u16 idx; 38 }; 39 40 struct hid_bpf_jmp_table { 41 struct bpf_map *map; 42 struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */ 43 int tail, head; 44 struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */ 45 unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)]; 46 }; 47 48 #define FOR_ENTRIES(__i, __start, __end) \ 49 for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i)) 50 51 static struct hid_bpf_jmp_table jmp_table; 52 53 static DEFINE_MUTEX(hid_bpf_attach_lock); /* held when attaching/detaching programs */ 54 55 static void hid_bpf_release_progs(struct work_struct *work); 56 57 static DECLARE_WORK(release_work, hid_bpf_release_progs); 58 59 BTF_ID_LIST(hid_bpf_btf_ids) 60 BTF_ID(func, hid_bpf_device_event) /* HID_BPF_PROG_TYPE_DEVICE_EVENT */ 61 BTF_ID(func, hid_bpf_rdesc_fixup) /* HID_BPF_PROG_TYPE_RDESC_FIXUP */ 62 63 static int hid_bpf_max_programs(enum hid_bpf_prog_type type) 64 { 65 switch (type) { 66 case HID_BPF_PROG_TYPE_DEVICE_EVENT: 67 return HID_BPF_MAX_PROGS_PER_DEV; 68 case HID_BPF_PROG_TYPE_RDESC_FIXUP: 69 return 1; 70 default: 71 return -EINVAL; 72 } 73 } 74 75 static int hid_bpf_program_count(struct hid_device *hdev, 76 struct bpf_prog *prog, 77 enum hid_bpf_prog_type type) 78 { 79 int i, n = 0; 80 81 if (type >= HID_BPF_PROG_TYPE_MAX) 82 return -EINVAL; 83 84 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 85 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 86 87 if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type) 88 continue; 89 90 if (hdev && entry->hdev != hdev) 91 continue; 92 93 if (prog && entry->prog != prog) 94 continue; 95 96 n++; 97 } 98 99 return n; 100 } 101 102 __weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx) 103 { 104 return 0; 105 } 106 107 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, 108 struct hid_bpf_ctx_kern *ctx_kern) 109 { 110 struct hid_bpf_prog_list *prog_list; 111 int i, idx, err = 0; 112 113 rcu_read_lock(); 114 prog_list = rcu_dereference(hdev->bpf.progs[type]); 115 116 if (!prog_list) 117 goto out_unlock; 118 119 for (i = 0; i < prog_list->prog_cnt; i++) { 120 idx = prog_list->prog_idx[i]; 121 122 if (!test_bit(idx, jmp_table.enabled)) 123 continue; 124 125 ctx_kern->ctx.index = idx; 126 err = __hid_bpf_tail_call(&ctx_kern->ctx); 127 if (err < 0) 128 break; 129 if (err) 130 ctx_kern->ctx.retval = err; 131 } 132 133 out_unlock: 134 rcu_read_unlock(); 135 136 return err; 137 } 138 139 /* 140 * assign the list of programs attached to a given hid device. 141 */ 142 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list, 143 enum hid_bpf_prog_type type) 144 { 145 struct hid_bpf_prog_list *old_list; 146 147 spin_lock(&hdev->bpf.progs_lock); 148 old_list = rcu_dereference_protected(hdev->bpf.progs[type], 149 lockdep_is_held(&hdev->bpf.progs_lock)); 150 rcu_assign_pointer(hdev->bpf.progs[type], new_list); 151 spin_unlock(&hdev->bpf.progs_lock); 152 synchronize_rcu(); 153 154 kfree(old_list); 155 } 156 157 /* 158 * allocate and populate the list of programs attached to a given hid device. 159 * 160 * Must be called under lock. 161 */ 162 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type) 163 { 164 struct hid_bpf_prog_list *new_list; 165 int i; 166 167 if (type >= HID_BPF_PROG_TYPE_MAX || !hdev) 168 return -EINVAL; 169 170 if (hdev->bpf.destroyed) 171 return 0; 172 173 new_list = kzalloc(sizeof(*new_list), GFP_KERNEL); 174 if (!new_list) 175 return -ENOMEM; 176 177 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 178 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 179 180 if (entry->type == type && entry->hdev == hdev && 181 test_bit(entry->idx, jmp_table.enabled)) 182 new_list->prog_idx[new_list->prog_cnt++] = entry->idx; 183 } 184 185 __hid_bpf_set_hdev_progs(hdev, new_list, type); 186 187 return 0; 188 } 189 190 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx) 191 { 192 skel_map_delete_elem(map_fd, &idx); 193 jmp_table.progs[idx] = NULL; 194 } 195 196 static void hid_bpf_release_progs(struct work_struct *work) 197 { 198 int i, j, n, map_fd = -1; 199 bool hdev_destroyed; 200 201 if (!jmp_table.map) 202 return; 203 204 /* retrieve a fd of our prog_array map in BPF */ 205 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 206 if (map_fd < 0) 207 return; 208 209 mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */ 210 211 /* detach unused progs from HID devices */ 212 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 213 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 214 enum hid_bpf_prog_type type; 215 struct hid_device *hdev; 216 217 if (test_bit(entry->idx, jmp_table.enabled)) 218 continue; 219 220 /* we have an attached prog */ 221 if (entry->hdev) { 222 hdev = entry->hdev; 223 type = entry->type; 224 /* 225 * hdev is still valid, even if we are called after hid_destroy_device(): 226 * when hid_bpf_attach() gets called, it takes a ref on the dev through 227 * bus_find_device() 228 */ 229 hdev_destroyed = hdev->bpf.destroyed; 230 231 hid_bpf_populate_hdev(hdev, type); 232 233 /* mark all other disabled progs from hdev of the given type as detached */ 234 FOR_ENTRIES(j, i, jmp_table.head) { 235 struct hid_bpf_prog_entry *next; 236 237 next = &jmp_table.entries[j]; 238 239 if (test_bit(next->idx, jmp_table.enabled)) 240 continue; 241 242 if (next->hdev == hdev && next->type == type) { 243 /* 244 * clear the hdev reference and decrement the device ref 245 * that was taken during bus_find_device() while calling 246 * hid_bpf_attach() 247 */ 248 next->hdev = NULL; 249 put_device(&hdev->dev); 250 } 251 } 252 253 /* if type was rdesc fixup and the device is not gone, reconnect device */ 254 if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed) 255 hid_bpf_reconnect(hdev); 256 } 257 } 258 259 /* remove all unused progs from the jump table */ 260 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 261 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 262 263 if (test_bit(entry->idx, jmp_table.enabled)) 264 continue; 265 266 if (entry->prog) 267 __hid_bpf_do_release_prog(map_fd, entry->idx); 268 } 269 270 /* compact the entry list */ 271 n = jmp_table.tail; 272 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 273 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 274 275 if (!test_bit(entry->idx, jmp_table.enabled)) 276 continue; 277 278 jmp_table.entries[n] = jmp_table.entries[i]; 279 n = NEXT(n); 280 } 281 282 jmp_table.head = n; 283 284 mutex_unlock(&hid_bpf_attach_lock); 285 286 if (map_fd >= 0) 287 close_fd(map_fd); 288 } 289 290 static void hid_bpf_release_prog_at(int idx) 291 { 292 int map_fd = -1; 293 294 /* retrieve a fd of our prog_array map in BPF */ 295 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 296 if (map_fd < 0) 297 return; 298 299 __hid_bpf_do_release_prog(map_fd, idx); 300 301 close(map_fd); 302 } 303 304 /* 305 * Insert the given BPF program represented by its fd in the jmp table. 306 * Returns the index in the jump table or a negative error. 307 */ 308 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog) 309 { 310 int i, index = -1, map_fd = -1, err = -EINVAL; 311 312 /* retrieve a fd of our prog_array map in BPF */ 313 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 314 315 if (map_fd < 0) { 316 err = -EINVAL; 317 goto out; 318 } 319 320 /* find the first available index in the jmp_table */ 321 for (i = 0; i < HID_BPF_MAX_PROGS; i++) { 322 if (!jmp_table.progs[i] && index < 0) { 323 /* mark the index as used */ 324 jmp_table.progs[i] = prog; 325 index = i; 326 __set_bit(i, jmp_table.enabled); 327 } 328 } 329 if (index < 0) { 330 err = -ENOMEM; 331 goto out; 332 } 333 334 /* insert the program in the jump table */ 335 err = skel_map_update_elem(map_fd, &index, &prog_fd, 0); 336 if (err) 337 goto out; 338 339 /* return the index */ 340 err = index; 341 342 out: 343 if (err < 0) 344 __hid_bpf_do_release_prog(map_fd, index); 345 if (map_fd >= 0) 346 close_fd(map_fd); 347 return err; 348 } 349 350 int hid_bpf_get_prog_attach_type(struct bpf_prog *prog) 351 { 352 int prog_type = HID_BPF_PROG_TYPE_UNDEF; 353 int i; 354 355 for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) { 356 if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) { 357 prog_type = i; 358 break; 359 } 360 } 361 362 return prog_type; 363 } 364 365 static void hid_bpf_link_release(struct bpf_link *link) 366 { 367 struct hid_bpf_link *hid_link = 368 container_of(link, struct hid_bpf_link, link); 369 370 __clear_bit(hid_link->hid_table_index, jmp_table.enabled); 371 schedule_work(&release_work); 372 } 373 374 static void hid_bpf_link_dealloc(struct bpf_link *link) 375 { 376 struct hid_bpf_link *hid_link = 377 container_of(link, struct hid_bpf_link, link); 378 379 kfree(hid_link); 380 } 381 382 static void hid_bpf_link_show_fdinfo(const struct bpf_link *link, 383 struct seq_file *seq) 384 { 385 seq_printf(seq, 386 "attach_type:\tHID-BPF\n"); 387 } 388 389 static const struct bpf_link_ops hid_bpf_link_lops = { 390 .release = hid_bpf_link_release, 391 .dealloc = hid_bpf_link_dealloc, 392 .show_fdinfo = hid_bpf_link_show_fdinfo, 393 }; 394 395 /* called from syscall */ 396 noinline int 397 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, 398 int prog_fd, struct bpf_prog *prog, __u32 flags) 399 { 400 struct bpf_link_primer link_primer; 401 struct hid_bpf_link *link; 402 struct hid_bpf_prog_entry *prog_entry; 403 int cnt, err = -EINVAL, prog_table_idx = -1; 404 405 mutex_lock(&hid_bpf_attach_lock); 406 407 link = kzalloc(sizeof(*link), GFP_USER); 408 if (!link) { 409 err = -ENOMEM; 410 goto err_unlock; 411 } 412 413 bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC, 414 &hid_bpf_link_lops, prog); 415 416 /* do not attach too many programs to a given HID device */ 417 cnt = hid_bpf_program_count(hdev, NULL, prog_type); 418 if (cnt < 0) { 419 err = cnt; 420 goto err_unlock; 421 } 422 423 if (cnt >= hid_bpf_max_programs(prog_type)) { 424 err = -E2BIG; 425 goto err_unlock; 426 } 427 428 prog_table_idx = hid_bpf_insert_prog(prog_fd, prog); 429 /* if the jmp table is full, abort */ 430 if (prog_table_idx < 0) { 431 err = prog_table_idx; 432 goto err_unlock; 433 } 434 435 if (flags & HID_BPF_FLAG_INSERT_HEAD) { 436 /* take the previous prog_entry slot */ 437 jmp_table.tail = PREV(jmp_table.tail); 438 prog_entry = &jmp_table.entries[jmp_table.tail]; 439 } else { 440 /* take the next prog_entry slot */ 441 prog_entry = &jmp_table.entries[jmp_table.head]; 442 jmp_table.head = NEXT(jmp_table.head); 443 } 444 445 /* we steal the ref here */ 446 prog_entry->prog = prog; 447 prog_entry->idx = prog_table_idx; 448 prog_entry->hdev = hdev; 449 prog_entry->type = prog_type; 450 451 /* finally store the index in the device list */ 452 err = hid_bpf_populate_hdev(hdev, prog_type); 453 if (err) { 454 hid_bpf_release_prog_at(prog_table_idx); 455 goto err_unlock; 456 } 457 458 link->hid_table_index = prog_table_idx; 459 460 err = bpf_link_prime(&link->link, &link_primer); 461 if (err) 462 goto err_unlock; 463 464 mutex_unlock(&hid_bpf_attach_lock); 465 466 return bpf_link_settle(&link_primer); 467 468 err_unlock: 469 mutex_unlock(&hid_bpf_attach_lock); 470 471 kfree(link); 472 473 return err; 474 } 475 476 void __hid_bpf_destroy_device(struct hid_device *hdev) 477 { 478 int type, i; 479 struct hid_bpf_prog_list *prog_list; 480 481 rcu_read_lock(); 482 483 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) { 484 prog_list = rcu_dereference(hdev->bpf.progs[type]); 485 486 if (!prog_list) 487 continue; 488 489 for (i = 0; i < prog_list->prog_cnt; i++) 490 __clear_bit(prog_list->prog_idx[i], jmp_table.enabled); 491 } 492 493 rcu_read_unlock(); 494 495 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) 496 __hid_bpf_set_hdev_progs(hdev, NULL, type); 497 498 /* schedule release of all detached progs */ 499 schedule_work(&release_work); 500 } 501 502 #define HID_BPF_PROGS_COUNT 1 503 504 static struct bpf_link *links[HID_BPF_PROGS_COUNT]; 505 static struct entrypoints_bpf *skel; 506 507 void hid_bpf_free_links_and_skel(void) 508 { 509 int i; 510 511 /* the following is enough to release all programs attached to hid */ 512 if (jmp_table.map) 513 bpf_map_put_with_uref(jmp_table.map); 514 515 for (i = 0; i < ARRAY_SIZE(links); i++) { 516 if (!IS_ERR_OR_NULL(links[i])) 517 bpf_link_put(links[i]); 518 } 519 entrypoints_bpf__destroy(skel); 520 } 521 522 #define ATTACH_AND_STORE_LINK(__name) do { \ 523 err = entrypoints_bpf__##__name##__attach(skel); \ 524 if (err) \ 525 goto out; \ 526 \ 527 links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd); \ 528 if (IS_ERR(links[idx])) { \ 529 err = PTR_ERR(links[idx]); \ 530 goto out; \ 531 } \ 532 \ 533 /* Avoid taking over stdin/stdout/stderr of init process. Zeroing out \ 534 * makes skel_closenz() a no-op later in iterators_bpf__destroy(). \ 535 */ \ 536 close_fd(skel->links.__name##_fd); \ 537 skel->links.__name##_fd = 0; \ 538 idx++; \ 539 } while (0) 540 541 int hid_bpf_preload_skel(void) 542 { 543 int err, idx = 0; 544 545 skel = entrypoints_bpf__open(); 546 if (!skel) 547 return -ENOMEM; 548 549 err = entrypoints_bpf__load(skel); 550 if (err) 551 goto out; 552 553 jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd); 554 if (IS_ERR(jmp_table.map)) { 555 err = PTR_ERR(jmp_table.map); 556 goto out; 557 } 558 559 ATTACH_AND_STORE_LINK(hid_tail_call); 560 561 return 0; 562 out: 563 hid_bpf_free_links_and_skel(); 564 return err; 565 } 566