1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 39 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 40 { 41 return (struct static_call_key *) 42 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS); 43 } 44 45 /* These assume the key is word-aligned. */ 46 static inline bool static_call_is_init(struct static_call_site *site) 47 { 48 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT; 49 } 50 51 static inline bool static_call_is_tail(struct static_call_site *site) 52 { 53 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL; 54 } 55 56 static inline void static_call_set_init(struct static_call_site *site) 57 { 58 site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) - 59 (long)&site->key; 60 } 61 62 static int static_call_site_cmp(const void *_a, const void *_b) 63 { 64 const struct static_call_site *a = _a; 65 const struct static_call_site *b = _b; 66 const struct static_call_key *key_a = static_call_key(a); 67 const struct static_call_key *key_b = static_call_key(b); 68 69 if (key_a < key_b) 70 return -1; 71 72 if (key_a > key_b) 73 return 1; 74 75 return 0; 76 } 77 78 static void static_call_site_swap(void *_a, void *_b, int size) 79 { 80 long delta = (unsigned long)_a - (unsigned long)_b; 81 struct static_call_site *a = _a; 82 struct static_call_site *b = _b; 83 struct static_call_site tmp = *a; 84 85 a->addr = b->addr - delta; 86 a->key = b->key - delta; 87 88 b->addr = tmp.addr + delta; 89 b->key = tmp.key + delta; 90 } 91 92 static inline void static_call_sort_entries(struct static_call_site *start, 93 struct static_call_site *stop) 94 { 95 sort(start, stop - start, sizeof(struct static_call_site), 96 static_call_site_cmp, static_call_site_swap); 97 } 98 99 static inline bool static_call_key_has_mods(struct static_call_key *key) 100 { 101 return !(key->type & 1); 102 } 103 104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 105 { 106 if (!static_call_key_has_mods(key)) 107 return NULL; 108 109 return key->mods; 110 } 111 112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 113 { 114 if (static_call_key_has_mods(key)) 115 return NULL; 116 117 return (struct static_call_site *)(key->type & ~1); 118 } 119 120 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 121 { 122 struct static_call_site *site, *stop; 123 struct static_call_mod *site_mod, first; 124 125 cpus_read_lock(); 126 static_call_lock(); 127 128 if (key->func == func) 129 goto done; 130 131 key->func = func; 132 133 arch_static_call_transform(NULL, tramp, func, false); 134 135 /* 136 * If uninitialized, we'll not update the callsites, but they still 137 * point to the trampoline and we just patched that. 138 */ 139 if (WARN_ON_ONCE(!static_call_initialized)) 140 goto done; 141 142 first = (struct static_call_mod){ 143 .next = static_call_key_next(key), 144 .mod = NULL, 145 .sites = static_call_key_sites(key), 146 }; 147 148 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 149 struct module *mod = site_mod->mod; 150 151 if (!site_mod->sites) { 152 /* 153 * This can happen if the static call key is defined in 154 * a module which doesn't use it. 155 * 156 * It also happens in the has_mods case, where the 157 * 'first' entry has no sites associated with it. 158 */ 159 continue; 160 } 161 162 stop = __stop_static_call_sites; 163 164 #ifdef CONFIG_MODULES 165 if (mod) { 166 stop = mod->static_call_sites + 167 mod->num_static_call_sites; 168 } 169 #endif 170 171 for (site = site_mod->sites; 172 site < stop && static_call_key(site) == key; site++) { 173 void *site_addr = static_call_addr(site); 174 175 if (static_call_is_init(site)) { 176 /* 177 * Don't write to call sites which were in 178 * initmem and have since been freed. 179 */ 180 if (!mod && system_state >= SYSTEM_RUNNING) 181 continue; 182 if (mod && !within_module_init((unsigned long)site_addr, mod)) 183 continue; 184 } 185 186 if (!kernel_text_address((unsigned long)site_addr)) { 187 WARN_ONCE(1, "can't patch static call site at %pS", 188 site_addr); 189 continue; 190 } 191 192 arch_static_call_transform(site_addr, NULL, func, 193 static_call_is_tail(site)); 194 } 195 } 196 197 done: 198 static_call_unlock(); 199 cpus_read_unlock(); 200 } 201 EXPORT_SYMBOL_GPL(__static_call_update); 202 203 static int __static_call_init(struct module *mod, 204 struct static_call_site *start, 205 struct static_call_site *stop) 206 { 207 struct static_call_site *site; 208 struct static_call_key *key, *prev_key = NULL; 209 struct static_call_mod *site_mod; 210 211 if (start == stop) 212 return 0; 213 214 static_call_sort_entries(start, stop); 215 216 for (site = start; site < stop; site++) { 217 void *site_addr = static_call_addr(site); 218 219 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 220 (!mod && init_section_contains(site_addr, 1))) 221 static_call_set_init(site); 222 223 key = static_call_key(site); 224 if (key != prev_key) { 225 prev_key = key; 226 227 /* 228 * For vmlinux (!mod) avoid the allocation by storing 229 * the sites pointer in the key itself. Also see 230 * __static_call_update()'s @first. 231 * 232 * This allows architectures (eg. x86) to call 233 * static_call_init() before memory allocation works. 234 */ 235 if (!mod) { 236 key->sites = site; 237 key->type |= 1; 238 goto do_transform; 239 } 240 241 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 242 if (!site_mod) 243 return -ENOMEM; 244 245 /* 246 * When the key has a direct sites pointer, extract 247 * that into an explicit struct static_call_mod, so we 248 * can have a list of modules. 249 */ 250 if (static_call_key_sites(key)) { 251 site_mod->mod = NULL; 252 site_mod->next = NULL; 253 site_mod->sites = static_call_key_sites(key); 254 255 key->mods = site_mod; 256 257 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 258 if (!site_mod) 259 return -ENOMEM; 260 } 261 262 site_mod->mod = mod; 263 site_mod->sites = site; 264 site_mod->next = static_call_key_next(key); 265 key->mods = site_mod; 266 } 267 268 do_transform: 269 arch_static_call_transform(site_addr, NULL, key->func, 270 static_call_is_tail(site)); 271 } 272 273 return 0; 274 } 275 276 static int addr_conflict(struct static_call_site *site, void *start, void *end) 277 { 278 unsigned long addr = (unsigned long)static_call_addr(site); 279 280 if (addr <= (unsigned long)end && 281 addr + CALL_INSN_SIZE > (unsigned long)start) 282 return 1; 283 284 return 0; 285 } 286 287 static int __static_call_text_reserved(struct static_call_site *iter_start, 288 struct static_call_site *iter_stop, 289 void *start, void *end) 290 { 291 struct static_call_site *iter = iter_start; 292 293 while (iter < iter_stop) { 294 if (addr_conflict(iter, start, end)) 295 return 1; 296 iter++; 297 } 298 299 return 0; 300 } 301 302 #ifdef CONFIG_MODULES 303 304 static int __static_call_mod_text_reserved(void *start, void *end) 305 { 306 struct module *mod; 307 int ret; 308 309 preempt_disable(); 310 mod = __module_text_address((unsigned long)start); 311 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 312 if (!try_module_get(mod)) 313 mod = NULL; 314 preempt_enable(); 315 316 if (!mod) 317 return 0; 318 319 ret = __static_call_text_reserved(mod->static_call_sites, 320 mod->static_call_sites + mod->num_static_call_sites, 321 start, end); 322 323 module_put(mod); 324 325 return ret; 326 } 327 328 static unsigned long tramp_key_lookup(unsigned long addr) 329 { 330 struct static_call_tramp_key *start = __start_static_call_tramp_key; 331 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 332 struct static_call_tramp_key *tramp_key; 333 334 for (tramp_key = start; tramp_key != stop; tramp_key++) { 335 unsigned long tramp; 336 337 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 338 if (tramp == addr) 339 return (long)tramp_key->key + (long)&tramp_key->key; 340 } 341 342 return 0; 343 } 344 345 static int static_call_add_module(struct module *mod) 346 { 347 struct static_call_site *start = mod->static_call_sites; 348 struct static_call_site *stop = start + mod->num_static_call_sites; 349 struct static_call_site *site; 350 351 for (site = start; site != stop; site++) { 352 unsigned long s_key = (long)site->key + (long)&site->key; 353 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 354 unsigned long key; 355 356 /* 357 * Is the key is exported, 'addr' points to the key, which 358 * means modules are allowed to call static_call_update() on 359 * it. 360 * 361 * Otherwise, the key isn't exported, and 'addr' points to the 362 * trampoline so we need to lookup the key. 363 * 364 * We go through this dance to prevent crazy modules from 365 * abusing sensitive static calls. 366 */ 367 if (!kernel_text_address(addr)) 368 continue; 369 370 key = tramp_key_lookup(addr); 371 if (!key) { 372 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 373 static_call_addr(site)); 374 return -EINVAL; 375 } 376 377 key |= s_key & STATIC_CALL_SITE_FLAGS; 378 site->key = key - (long)&site->key; 379 } 380 381 return __static_call_init(mod, start, stop); 382 } 383 384 static void static_call_del_module(struct module *mod) 385 { 386 struct static_call_site *start = mod->static_call_sites; 387 struct static_call_site *stop = mod->static_call_sites + 388 mod->num_static_call_sites; 389 struct static_call_key *key, *prev_key = NULL; 390 struct static_call_mod *site_mod, **prev; 391 struct static_call_site *site; 392 393 for (site = start; site < stop; site++) { 394 key = static_call_key(site); 395 if (key == prev_key) 396 continue; 397 398 prev_key = key; 399 400 for (prev = &key->mods, site_mod = key->mods; 401 site_mod && site_mod->mod != mod; 402 prev = &site_mod->next, site_mod = site_mod->next) 403 ; 404 405 if (!site_mod) 406 continue; 407 408 *prev = site_mod->next; 409 kfree(site_mod); 410 } 411 } 412 413 static int static_call_module_notify(struct notifier_block *nb, 414 unsigned long val, void *data) 415 { 416 struct module *mod = data; 417 int ret = 0; 418 419 cpus_read_lock(); 420 static_call_lock(); 421 422 switch (val) { 423 case MODULE_STATE_COMING: 424 ret = static_call_add_module(mod); 425 if (ret) { 426 WARN(1, "Failed to allocate memory for static calls"); 427 static_call_del_module(mod); 428 } 429 break; 430 case MODULE_STATE_GOING: 431 static_call_del_module(mod); 432 break; 433 } 434 435 static_call_unlock(); 436 cpus_read_unlock(); 437 438 return notifier_from_errno(ret); 439 } 440 441 static struct notifier_block static_call_module_nb = { 442 .notifier_call = static_call_module_notify, 443 }; 444 445 #else 446 447 static inline int __static_call_mod_text_reserved(void *start, void *end) 448 { 449 return 0; 450 } 451 452 #endif /* CONFIG_MODULES */ 453 454 int static_call_text_reserved(void *start, void *end) 455 { 456 int ret = __static_call_text_reserved(__start_static_call_sites, 457 __stop_static_call_sites, start, end); 458 459 if (ret) 460 return ret; 461 462 return __static_call_mod_text_reserved(start, end); 463 } 464 465 int __init static_call_init(void) 466 { 467 int ret; 468 469 if (static_call_initialized) 470 return 0; 471 472 cpus_read_lock(); 473 static_call_lock(); 474 ret = __static_call_init(NULL, __start_static_call_sites, 475 __stop_static_call_sites); 476 static_call_unlock(); 477 cpus_read_unlock(); 478 479 if (ret) { 480 pr_err("Failed to allocate memory for static_call!\n"); 481 BUG(); 482 } 483 484 static_call_initialized = true; 485 486 #ifdef CONFIG_MODULES 487 register_module_notifier(&static_call_module_nb); 488 #endif 489 return 0; 490 } 491 early_initcall(static_call_init); 492 493 long __static_call_return0(void) 494 { 495 return 0; 496 } 497 498 #ifdef CONFIG_STATIC_CALL_SELFTEST 499 500 static int func_a(int x) 501 { 502 return x+1; 503 } 504 505 static int func_b(int x) 506 { 507 return x+2; 508 } 509 510 DEFINE_STATIC_CALL(sc_selftest, func_a); 511 512 static struct static_call_data { 513 int (*func)(int); 514 int val; 515 int expect; 516 } static_call_data [] __initdata = { 517 { NULL, 2, 3 }, 518 { func_b, 2, 4 }, 519 { func_a, 2, 3 } 520 }; 521 522 static int __init test_static_call_init(void) 523 { 524 int i; 525 526 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 527 struct static_call_data *scd = &static_call_data[i]; 528 529 if (scd->func) 530 static_call_update(sc_selftest, scd->func); 531 532 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 533 } 534 535 return 0; 536 } 537 early_initcall(test_static_call_init); 538 539 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 540