1 // SPDX-License-Identifier: GPL-2.0-only 2 3 /* 4 * HID-BPF support for Linux 5 * 6 * Copyright (c) 2022 Benjamin Tissoires 7 */ 8 9 #include <linux/bitops.h> 10 #include <linux/btf.h> 11 #include <linux/btf_ids.h> 12 #include <linux/circ_buf.h> 13 #include <linux/filter.h> 14 #include <linux/hid.h> 15 #include <linux/hid_bpf.h> 16 #include <linux/init.h> 17 #include <linux/module.h> 18 #include <linux/workqueue.h> 19 #include "hid_bpf_dispatch.h" 20 #include "entrypoints/entrypoints.lskel.h" 21 22 #define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf, 23 * needs to be a power of 2 as we use it as 24 * a circular buffer 25 */ 26 27 #define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1)) 28 #define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1)) 29 30 /* 31 * represents one attached program stored in the hid jump table 32 */ 33 struct hid_bpf_prog_entry { 34 struct bpf_prog *prog; 35 struct hid_device *hdev; 36 enum hid_bpf_prog_type type; 37 u16 idx; 38 }; 39 40 struct hid_bpf_jmp_table { 41 struct bpf_map *map; 42 struct bpf_map *prog_keys; 43 struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */ 44 int tail, head; 45 struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */ 46 unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)]; 47 }; 48 49 #define FOR_ENTRIES(__i, __start, __end) \ 50 for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i)) 51 52 static struct hid_bpf_jmp_table jmp_table; 53 54 static DEFINE_MUTEX(hid_bpf_attach_lock); /* held when attaching/detaching programs */ 55 56 static void hid_bpf_release_progs(struct work_struct *work); 57 58 static DECLARE_WORK(release_work, hid_bpf_release_progs); 59 60 BTF_ID_LIST(hid_bpf_btf_ids) 61 BTF_ID(func, hid_bpf_device_event) /* HID_BPF_PROG_TYPE_DEVICE_EVENT */ 62 63 static int hid_bpf_max_programs(enum hid_bpf_prog_type type) 64 { 65 switch (type) { 66 case HID_BPF_PROG_TYPE_DEVICE_EVENT: 67 return HID_BPF_MAX_PROGS_PER_DEV; 68 default: 69 return -EINVAL; 70 } 71 } 72 73 static int hid_bpf_program_count(struct hid_device *hdev, 74 struct bpf_prog *prog, 75 enum hid_bpf_prog_type type) 76 { 77 int i, n = 0; 78 79 if (type >= HID_BPF_PROG_TYPE_MAX) 80 return -EINVAL; 81 82 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 83 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 84 85 if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type) 86 continue; 87 88 if (hdev && entry->hdev != hdev) 89 continue; 90 91 if (prog && entry->prog != prog) 92 continue; 93 94 n++; 95 } 96 97 return n; 98 } 99 100 __weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx) 101 { 102 return 0; 103 } 104 ALLOW_ERROR_INJECTION(__hid_bpf_tail_call, ERRNO); 105 106 int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, 107 struct hid_bpf_ctx_kern *ctx_kern) 108 { 109 struct hid_bpf_prog_list *prog_list; 110 int i, idx, err = 0; 111 112 rcu_read_lock(); 113 prog_list = rcu_dereference(hdev->bpf.progs[type]); 114 115 if (!prog_list) 116 goto out_unlock; 117 118 for (i = 0; i < prog_list->prog_cnt; i++) { 119 idx = prog_list->prog_idx[i]; 120 121 if (!test_bit(idx, jmp_table.enabled)) 122 continue; 123 124 ctx_kern->ctx.index = idx; 125 err = __hid_bpf_tail_call(&ctx_kern->ctx); 126 if (err) 127 break; 128 } 129 130 out_unlock: 131 rcu_read_unlock(); 132 133 return err; 134 } 135 136 /* 137 * assign the list of programs attached to a given hid device. 138 */ 139 static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list, 140 enum hid_bpf_prog_type type) 141 { 142 struct hid_bpf_prog_list *old_list; 143 144 spin_lock(&hdev->bpf.progs_lock); 145 old_list = rcu_dereference_protected(hdev->bpf.progs[type], 146 lockdep_is_held(&hdev->bpf.progs_lock)); 147 rcu_assign_pointer(hdev->bpf.progs[type], new_list); 148 spin_unlock(&hdev->bpf.progs_lock); 149 synchronize_rcu(); 150 151 kfree(old_list); 152 } 153 154 /* 155 * allocate and populate the list of programs attached to a given hid device. 156 * 157 * Must be called under lock. 158 */ 159 static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type) 160 { 161 struct hid_bpf_prog_list *new_list; 162 int i; 163 164 if (type >= HID_BPF_PROG_TYPE_MAX || !hdev) 165 return -EINVAL; 166 167 if (hdev->bpf.destroyed) 168 return 0; 169 170 new_list = kzalloc(sizeof(*new_list), GFP_KERNEL); 171 if (!new_list) 172 return -ENOMEM; 173 174 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 175 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 176 177 if (entry->type == type && entry->hdev == hdev && 178 test_bit(entry->idx, jmp_table.enabled)) 179 new_list->prog_idx[new_list->prog_cnt++] = entry->idx; 180 } 181 182 __hid_bpf_set_hdev_progs(hdev, new_list, type); 183 184 return 0; 185 } 186 187 static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx) 188 { 189 skel_map_delete_elem(map_fd, &idx); 190 jmp_table.progs[idx] = NULL; 191 } 192 193 static void hid_bpf_release_progs(struct work_struct *work) 194 { 195 int i, j, n, map_fd = -1; 196 197 if (!jmp_table.map) 198 return; 199 200 /* retrieve a fd of our prog_array map in BPF */ 201 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 202 if (map_fd < 0) 203 return; 204 205 mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */ 206 207 /* detach unused progs from HID devices */ 208 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 209 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 210 enum hid_bpf_prog_type type; 211 struct hid_device *hdev; 212 213 if (test_bit(entry->idx, jmp_table.enabled)) 214 continue; 215 216 /* we have an attached prog */ 217 if (entry->hdev) { 218 hdev = entry->hdev; 219 type = entry->type; 220 221 hid_bpf_populate_hdev(hdev, type); 222 223 /* mark all other disabled progs from hdev of the given type as detached */ 224 FOR_ENTRIES(j, i, jmp_table.head) { 225 struct hid_bpf_prog_entry *next; 226 227 next = &jmp_table.entries[j]; 228 229 if (test_bit(next->idx, jmp_table.enabled)) 230 continue; 231 232 if (next->hdev == hdev && next->type == type) 233 next->hdev = NULL; 234 } 235 } 236 } 237 238 /* remove all unused progs from the jump table */ 239 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 240 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 241 242 if (test_bit(entry->idx, jmp_table.enabled)) 243 continue; 244 245 if (entry->prog) 246 __hid_bpf_do_release_prog(map_fd, entry->idx); 247 } 248 249 /* compact the entry list */ 250 n = jmp_table.tail; 251 FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) { 252 struct hid_bpf_prog_entry *entry = &jmp_table.entries[i]; 253 254 if (!test_bit(entry->idx, jmp_table.enabled)) 255 continue; 256 257 jmp_table.entries[n] = jmp_table.entries[i]; 258 n = NEXT(n); 259 } 260 261 jmp_table.head = n; 262 263 mutex_unlock(&hid_bpf_attach_lock); 264 265 if (map_fd >= 0) 266 close_fd(map_fd); 267 } 268 269 static void hid_bpf_release_prog_at(int idx) 270 { 271 int map_fd = -1; 272 273 /* retrieve a fd of our prog_array map in BPF */ 274 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 275 if (map_fd < 0) 276 return; 277 278 __hid_bpf_do_release_prog(map_fd, idx); 279 280 close(map_fd); 281 } 282 283 /* 284 * Insert the given BPF program represented by its fd in the jmp table. 285 * Returns the index in the jump table or a negative error. 286 */ 287 static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog) 288 { 289 int i, cnt, index = -1, map_fd = -1, progs_map_fd = -1, err = -EINVAL; 290 291 /* retrieve a fd of our prog_array map in BPF */ 292 map_fd = skel_map_get_fd_by_id(jmp_table.map->id); 293 /* take an fd for the table of progs we monitor with SEC("fexit/bpf_prog_release") */ 294 progs_map_fd = skel_map_get_fd_by_id(jmp_table.prog_keys->id); 295 296 if (map_fd < 0 || progs_map_fd < 0) { 297 err = -EINVAL; 298 goto out; 299 } 300 301 cnt = 0; 302 /* find the first available index in the jmp_table 303 * and count how many time this program has been inserted 304 */ 305 for (i = 0; i < HID_BPF_MAX_PROGS; i++) { 306 if (!jmp_table.progs[i] && index < 0) { 307 /* mark the index as used */ 308 jmp_table.progs[i] = prog; 309 index = i; 310 __set_bit(i, jmp_table.enabled); 311 cnt++; 312 } else { 313 if (jmp_table.progs[i] == prog) 314 cnt++; 315 } 316 } 317 if (index < 0) { 318 err = -ENOMEM; 319 goto out; 320 } 321 322 /* insert the program in the jump table */ 323 err = skel_map_update_elem(map_fd, &index, &prog_fd, 0); 324 if (err) 325 goto out; 326 327 /* insert the program in the prog list table */ 328 err = skel_map_update_elem(progs_map_fd, &prog, &cnt, 0); 329 if (err) 330 goto out; 331 332 /* return the index */ 333 err = index; 334 335 out: 336 if (err < 0) 337 __hid_bpf_do_release_prog(map_fd, index); 338 if (map_fd >= 0) 339 close_fd(map_fd); 340 if (progs_map_fd >= 0) 341 close_fd(progs_map_fd); 342 return err; 343 } 344 345 int hid_bpf_get_prog_attach_type(int prog_fd) 346 { 347 struct bpf_prog *prog = NULL; 348 int i; 349 int prog_type = HID_BPF_PROG_TYPE_UNDEF; 350 351 prog = bpf_prog_get(prog_fd); 352 if (IS_ERR(prog)) 353 return PTR_ERR(prog); 354 355 for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) { 356 if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) { 357 prog_type = i; 358 break; 359 } 360 } 361 362 bpf_prog_put(prog); 363 364 return prog_type; 365 } 366 367 /* called from syscall */ 368 noinline int 369 __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, 370 int prog_fd, __u32 flags) 371 { 372 struct bpf_prog *prog = NULL; 373 struct hid_bpf_prog_entry *prog_entry; 374 int cnt, err = -EINVAL, prog_idx = -1; 375 376 /* take a ref on the prog itself */ 377 prog = bpf_prog_get(prog_fd); 378 if (IS_ERR(prog)) 379 return PTR_ERR(prog); 380 381 mutex_lock(&hid_bpf_attach_lock); 382 383 /* do not attach too many programs to a given HID device */ 384 cnt = hid_bpf_program_count(hdev, NULL, prog_type); 385 if (cnt < 0) { 386 err = cnt; 387 goto out_unlock; 388 } 389 390 if (cnt >= hid_bpf_max_programs(prog_type)) { 391 err = -E2BIG; 392 goto out_unlock; 393 } 394 395 prog_idx = hid_bpf_insert_prog(prog_fd, prog); 396 /* if the jmp table is full, abort */ 397 if (prog_idx < 0) { 398 err = prog_idx; 399 goto out_unlock; 400 } 401 402 if (flags & HID_BPF_FLAG_INSERT_HEAD) { 403 /* take the previous prog_entry slot */ 404 jmp_table.tail = PREV(jmp_table.tail); 405 prog_entry = &jmp_table.entries[jmp_table.tail]; 406 } else { 407 /* take the next prog_entry slot */ 408 prog_entry = &jmp_table.entries[jmp_table.head]; 409 jmp_table.head = NEXT(jmp_table.head); 410 } 411 412 /* we steal the ref here */ 413 prog_entry->prog = prog; 414 prog_entry->idx = prog_idx; 415 prog_entry->hdev = hdev; 416 prog_entry->type = prog_type; 417 418 /* finally store the index in the device list */ 419 err = hid_bpf_populate_hdev(hdev, prog_type); 420 if (err) 421 hid_bpf_release_prog_at(prog_idx); 422 423 out_unlock: 424 mutex_unlock(&hid_bpf_attach_lock); 425 426 /* we only use prog as a key in the various tables, so we don't need to actually 427 * increment the ref count. 428 */ 429 bpf_prog_put(prog); 430 431 return err; 432 } 433 434 void __hid_bpf_destroy_device(struct hid_device *hdev) 435 { 436 int type, i; 437 struct hid_bpf_prog_list *prog_list; 438 439 rcu_read_lock(); 440 441 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) { 442 prog_list = rcu_dereference(hdev->bpf.progs[type]); 443 444 if (!prog_list) 445 continue; 446 447 for (i = 0; i < prog_list->prog_cnt; i++) 448 __clear_bit(prog_list->prog_idx[i], jmp_table.enabled); 449 } 450 451 rcu_read_unlock(); 452 453 for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) 454 __hid_bpf_set_hdev_progs(hdev, NULL, type); 455 456 /* schedule release of all detached progs */ 457 schedule_work(&release_work); 458 } 459 460 noinline bool 461 call_hid_bpf_prog_release(u64 prog_key, int table_cnt) 462 { 463 /* compare with how many refs are left in the bpf program */ 464 struct bpf_prog *prog = (struct bpf_prog *)prog_key; 465 int idx; 466 467 if (!prog) 468 return false; 469 470 if (atomic64_read(&prog->aux->refcnt) != table_cnt) 471 return false; 472 473 /* we don't need locking here because the entries in the progs table 474 * are stable: 475 * if there are other users (and the progs entries might change), we 476 * would return in the statement above. 477 */ 478 for (idx = 0; idx < HID_BPF_MAX_PROGS; idx++) { 479 if (jmp_table.progs[idx] == prog) { 480 __clear_bit(idx, jmp_table.enabled); 481 break; 482 } 483 } 484 if (idx >= HID_BPF_MAX_PROGS) { 485 /* should never happen if we get our refcount right */ 486 idx = -1; 487 } 488 489 /* schedule release of all detached progs */ 490 schedule_work(&release_work); 491 return idx >= 0; 492 } 493 494 #define HID_BPF_PROGS_COUNT 3 495 496 static struct bpf_link *links[HID_BPF_PROGS_COUNT]; 497 static struct entrypoints_bpf *skel; 498 499 void hid_bpf_free_links_and_skel(void) 500 { 501 int i; 502 503 /* the following is enough to release all programs attached to hid */ 504 if (jmp_table.prog_keys) 505 bpf_map_put_with_uref(jmp_table.prog_keys); 506 507 if (jmp_table.map) 508 bpf_map_put_with_uref(jmp_table.map); 509 510 for (i = 0; i < ARRAY_SIZE(links); i++) { 511 if (!IS_ERR_OR_NULL(links[i])) 512 bpf_link_put(links[i]); 513 } 514 entrypoints_bpf__destroy(skel); 515 } 516 517 #define ATTACH_AND_STORE_LINK(__name) do { \ 518 err = entrypoints_bpf__##__name##__attach(skel); \ 519 if (err) \ 520 goto out; \ 521 \ 522 links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd); \ 523 if (IS_ERR(links[idx])) { \ 524 err = PTR_ERR(links[idx]); \ 525 goto out; \ 526 } \ 527 \ 528 /* Avoid taking over stdin/stdout/stderr of init process. Zeroing out \ 529 * makes skel_closenz() a no-op later in iterators_bpf__destroy(). \ 530 */ \ 531 close_fd(skel->links.__name##_fd); \ 532 skel->links.__name##_fd = 0; \ 533 idx++; \ 534 } while (0) 535 536 int hid_bpf_preload_skel(void) 537 { 538 int err, idx = 0; 539 540 skel = entrypoints_bpf__open(); 541 if (!skel) 542 return -ENOMEM; 543 544 err = entrypoints_bpf__load(skel); 545 if (err) 546 goto out; 547 548 jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd); 549 if (IS_ERR(jmp_table.map)) { 550 err = PTR_ERR(jmp_table.map); 551 goto out; 552 } 553 554 jmp_table.prog_keys = bpf_map_get_with_uref(skel->maps.progs_map.map_fd); 555 if (IS_ERR(jmp_table.prog_keys)) { 556 err = PTR_ERR(jmp_table.prog_keys); 557 goto out; 558 } 559 560 ATTACH_AND_STORE_LINK(hid_tail_call); 561 ATTACH_AND_STORE_LINK(hid_prog_release); 562 ATTACH_AND_STORE_LINK(hid_free_inode); 563 564 return 0; 565 out: 566 hid_bpf_free_links_and_skel(); 567 return err; 568 } 569