1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 39 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 40 { 41 return (struct static_call_key *) 42 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS); 43 } 44 45 /* These assume the key is word-aligned. */ 46 static inline bool static_call_is_init(struct static_call_site *site) 47 { 48 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT; 49 } 50 51 static inline bool static_call_is_tail(struct static_call_site *site) 52 { 53 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL; 54 } 55 56 static inline void static_call_set_init(struct static_call_site *site) 57 { 58 site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) - 59 (long)&site->key; 60 } 61 62 static int static_call_site_cmp(const void *_a, const void *_b) 63 { 64 const struct static_call_site *a = _a; 65 const struct static_call_site *b = _b; 66 const struct static_call_key *key_a = static_call_key(a); 67 const struct static_call_key *key_b = static_call_key(b); 68 69 if (key_a < key_b) 70 return -1; 71 72 if (key_a > key_b) 73 return 1; 74 75 return 0; 76 } 77 78 static void static_call_site_swap(void *_a, void *_b, int size) 79 { 80 long delta = (unsigned long)_a - (unsigned long)_b; 81 struct static_call_site *a = _a; 82 struct static_call_site *b = _b; 83 struct static_call_site tmp = *a; 84 85 a->addr = b->addr - delta; 86 a->key = b->key - delta; 87 88 b->addr = tmp.addr + delta; 89 b->key = tmp.key + delta; 90 } 91 92 static inline void static_call_sort_entries(struct static_call_site *start, 93 struct static_call_site *stop) 94 { 95 sort(start, stop - start, sizeof(struct static_call_site), 96 static_call_site_cmp, static_call_site_swap); 97 } 98 99 static inline bool static_call_key_has_mods(struct static_call_key *key) 100 { 101 return !(key->type & 1); 102 } 103 104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 105 { 106 if (!static_call_key_has_mods(key)) 107 return NULL; 108 109 return key->mods; 110 } 111 112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 113 { 114 if (static_call_key_has_mods(key)) 115 return NULL; 116 117 return (struct static_call_site *)(key->type & ~1); 118 } 119 120 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 121 { 122 struct static_call_site *site, *stop; 123 struct static_call_mod *site_mod, first; 124 125 cpus_read_lock(); 126 static_call_lock(); 127 128 if (key->func == func) 129 goto done; 130 131 key->func = func; 132 133 arch_static_call_transform(NULL, tramp, func, false); 134 135 /* 136 * If uninitialized, we'll not update the callsites, but they still 137 * point to the trampoline and we just patched that. 138 */ 139 if (WARN_ON_ONCE(!static_call_initialized)) 140 goto done; 141 142 first = (struct static_call_mod){ 143 .next = static_call_key_next(key), 144 .mod = NULL, 145 .sites = static_call_key_sites(key), 146 }; 147 148 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 149 struct module *mod = site_mod->mod; 150 151 if (!site_mod->sites) { 152 /* 153 * This can happen if the static call key is defined in 154 * a module which doesn't use it. 155 * 156 * It also happens in the has_mods case, where the 157 * 'first' entry has no sites associated with it. 158 */ 159 continue; 160 } 161 162 stop = __stop_static_call_sites; 163 164 #ifdef CONFIG_MODULES 165 if (mod) { 166 stop = mod->static_call_sites + 167 mod->num_static_call_sites; 168 } 169 #endif 170 171 for (site = site_mod->sites; 172 site < stop && static_call_key(site) == key; site++) { 173 void *site_addr = static_call_addr(site); 174 175 if (static_call_is_init(site)) { 176 /* 177 * Don't write to call sites which were in 178 * initmem and have since been freed. 179 */ 180 if (!mod && system_state >= SYSTEM_RUNNING) 181 continue; 182 if (mod && !within_module_init((unsigned long)site_addr, mod)) 183 continue; 184 } 185 186 if (!kernel_text_address((unsigned long)site_addr)) { 187 WARN_ONCE(1, "can't patch static call site at %pS", 188 site_addr); 189 continue; 190 } 191 192 arch_static_call_transform(site_addr, NULL, func, 193 static_call_is_tail(site)); 194 } 195 } 196 197 done: 198 static_call_unlock(); 199 cpus_read_unlock(); 200 } 201 EXPORT_SYMBOL_GPL(__static_call_update); 202 203 static int __static_call_init(struct module *mod, 204 struct static_call_site *start, 205 struct static_call_site *stop) 206 { 207 struct static_call_site *site; 208 struct static_call_key *key, *prev_key = NULL; 209 struct static_call_mod *site_mod; 210 211 if (start == stop) 212 return 0; 213 214 static_call_sort_entries(start, stop); 215 216 for (site = start; site < stop; site++) { 217 void *site_addr = static_call_addr(site); 218 219 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 220 (!mod && init_section_contains(site_addr, 1))) 221 static_call_set_init(site); 222 223 key = static_call_key(site); 224 if (key != prev_key) { 225 prev_key = key; 226 227 /* 228 * For vmlinux (!mod) avoid the allocation by storing 229 * the sites pointer in the key itself. Also see 230 * __static_call_update()'s @first. 231 * 232 * This allows architectures (eg. x86) to call 233 * static_call_init() before memory allocation works. 234 */ 235 if (!mod) { 236 key->sites = site; 237 key->type |= 1; 238 goto do_transform; 239 } 240 241 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 242 if (!site_mod) 243 return -ENOMEM; 244 245 /* 246 * When the key has a direct sites pointer, extract 247 * that into an explicit struct static_call_mod, so we 248 * can have a list of modules. 249 */ 250 if (static_call_key_sites(key)) { 251 site_mod->mod = NULL; 252 site_mod->next = NULL; 253 site_mod->sites = static_call_key_sites(key); 254 255 key->mods = site_mod; 256 257 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 258 if (!site_mod) 259 return -ENOMEM; 260 } 261 262 site_mod->mod = mod; 263 site_mod->sites = site; 264 site_mod->next = static_call_key_next(key); 265 key->mods = site_mod; 266 } 267 268 do_transform: 269 arch_static_call_transform(site_addr, NULL, key->func, 270 static_call_is_tail(site)); 271 } 272 273 return 0; 274 } 275 276 static int addr_conflict(struct static_call_site *site, void *start, void *end) 277 { 278 unsigned long addr = (unsigned long)static_call_addr(site); 279 280 if (addr <= (unsigned long)end && 281 addr + CALL_INSN_SIZE > (unsigned long)start) 282 return 1; 283 284 return 0; 285 } 286 287 static int __static_call_text_reserved(struct static_call_site *iter_start, 288 struct static_call_site *iter_stop, 289 void *start, void *end) 290 { 291 struct static_call_site *iter = iter_start; 292 293 while (iter < iter_stop) { 294 if (addr_conflict(iter, start, end)) 295 return 1; 296 iter++; 297 } 298 299 return 0; 300 } 301 302 #ifdef CONFIG_MODULES 303 304 static int __static_call_mod_text_reserved(void *start, void *end) 305 { 306 struct module *mod; 307 int ret; 308 309 preempt_disable(); 310 mod = __module_text_address((unsigned long)start); 311 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 312 if (!try_module_get(mod)) 313 mod = NULL; 314 preempt_enable(); 315 316 if (!mod) 317 return 0; 318 319 ret = __static_call_text_reserved(mod->static_call_sites, 320 mod->static_call_sites + mod->num_static_call_sites, 321 start, end); 322 323 module_put(mod); 324 325 return ret; 326 } 327 328 static unsigned long tramp_key_lookup(unsigned long addr) 329 { 330 struct static_call_tramp_key *start = __start_static_call_tramp_key; 331 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 332 struct static_call_tramp_key *tramp_key; 333 334 for (tramp_key = start; tramp_key != stop; tramp_key++) { 335 unsigned long tramp; 336 337 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 338 if (tramp == addr) 339 return (long)tramp_key->key + (long)&tramp_key->key; 340 } 341 342 return 0; 343 } 344 345 static int static_call_add_module(struct module *mod) 346 { 347 struct static_call_site *start = mod->static_call_sites; 348 struct static_call_site *stop = start + mod->num_static_call_sites; 349 struct static_call_site *site; 350 351 for (site = start; site != stop; site++) { 352 unsigned long addr = (unsigned long)static_call_key(site); 353 unsigned long key; 354 355 /* 356 * Is the key is exported, 'addr' points to the key, which 357 * means modules are allowed to call static_call_update() on 358 * it. 359 * 360 * Otherwise, the key isn't exported, and 'addr' points to the 361 * trampoline so we need to lookup the key. 362 * 363 * We go through this dance to prevent crazy modules from 364 * abusing sensitive static calls. 365 */ 366 if (!kernel_text_address(addr)) 367 continue; 368 369 key = tramp_key_lookup(addr); 370 if (!key) { 371 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 372 static_call_addr(site)); 373 return -EINVAL; 374 } 375 376 site->key = (key - (long)&site->key) | 377 (site->key & STATIC_CALL_SITE_FLAGS); 378 } 379 380 return __static_call_init(mod, start, stop); 381 } 382 383 static void static_call_del_module(struct module *mod) 384 { 385 struct static_call_site *start = mod->static_call_sites; 386 struct static_call_site *stop = mod->static_call_sites + 387 mod->num_static_call_sites; 388 struct static_call_key *key, *prev_key = NULL; 389 struct static_call_mod *site_mod, **prev; 390 struct static_call_site *site; 391 392 for (site = start; site < stop; site++) { 393 key = static_call_key(site); 394 if (key == prev_key) 395 continue; 396 397 prev_key = key; 398 399 for (prev = &key->mods, site_mod = key->mods; 400 site_mod && site_mod->mod != mod; 401 prev = &site_mod->next, site_mod = site_mod->next) 402 ; 403 404 if (!site_mod) 405 continue; 406 407 *prev = site_mod->next; 408 kfree(site_mod); 409 } 410 } 411 412 static int static_call_module_notify(struct notifier_block *nb, 413 unsigned long val, void *data) 414 { 415 struct module *mod = data; 416 int ret = 0; 417 418 cpus_read_lock(); 419 static_call_lock(); 420 421 switch (val) { 422 case MODULE_STATE_COMING: 423 ret = static_call_add_module(mod); 424 if (ret) { 425 WARN(1, "Failed to allocate memory for static calls"); 426 static_call_del_module(mod); 427 } 428 break; 429 case MODULE_STATE_GOING: 430 static_call_del_module(mod); 431 break; 432 } 433 434 static_call_unlock(); 435 cpus_read_unlock(); 436 437 return notifier_from_errno(ret); 438 } 439 440 static struct notifier_block static_call_module_nb = { 441 .notifier_call = static_call_module_notify, 442 }; 443 444 #else 445 446 static inline int __static_call_mod_text_reserved(void *start, void *end) 447 { 448 return 0; 449 } 450 451 #endif /* CONFIG_MODULES */ 452 453 int static_call_text_reserved(void *start, void *end) 454 { 455 int ret = __static_call_text_reserved(__start_static_call_sites, 456 __stop_static_call_sites, start, end); 457 458 if (ret) 459 return ret; 460 461 return __static_call_mod_text_reserved(start, end); 462 } 463 464 int __init static_call_init(void) 465 { 466 int ret; 467 468 if (static_call_initialized) 469 return 0; 470 471 cpus_read_lock(); 472 static_call_lock(); 473 ret = __static_call_init(NULL, __start_static_call_sites, 474 __stop_static_call_sites); 475 static_call_unlock(); 476 cpus_read_unlock(); 477 478 if (ret) { 479 pr_err("Failed to allocate memory for static_call!\n"); 480 BUG(); 481 } 482 483 static_call_initialized = true; 484 485 #ifdef CONFIG_MODULES 486 register_module_notifier(&static_call_module_nb); 487 #endif 488 return 0; 489 } 490 early_initcall(static_call_init); 491 492 long __static_call_return0(void) 493 { 494 return 0; 495 } 496 497 #ifdef CONFIG_STATIC_CALL_SELFTEST 498 499 static int func_a(int x) 500 { 501 return x+1; 502 } 503 504 static int func_b(int x) 505 { 506 return x+2; 507 } 508 509 DEFINE_STATIC_CALL(sc_selftest, func_a); 510 511 static struct static_call_data { 512 int (*func)(int); 513 int val; 514 int expect; 515 } static_call_data [] __initdata = { 516 { NULL, 2, 3 }, 517 { func_b, 2, 4 }, 518 { func_a, 2, 3 } 519 }; 520 521 static int __init test_static_call_init(void) 522 { 523 int i; 524 525 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 526 struct static_call_data *scd = &static_call_data[i]; 527 528 if (scd->func) 529 static_call_update(sc_selftest, scd->func); 530 531 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 532 } 533 534 return 0; 535 } 536 early_initcall(test_static_call_init); 537 538 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 539