xref: /openbmc/linux/kernel/static_call.c (revision 6417f03132a6952cd17ddd8eaddbac92b61b17e0)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12 
13 extern struct static_call_site __start_static_call_sites[],
14 			       __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 				    __stop_static_call_tramp_key[];
17 
18 static bool static_call_initialized;
19 
20 /* mutex to protect key modules/sites */
21 static DEFINE_MUTEX(static_call_mutex);
22 
23 static void static_call_lock(void)
24 {
25 	mutex_lock(&static_call_mutex);
26 }
27 
28 static void static_call_unlock(void)
29 {
30 	mutex_unlock(&static_call_mutex);
31 }
32 
33 static inline void *static_call_addr(struct static_call_site *site)
34 {
35 	return (void *)((long)site->addr + (long)&site->addr);
36 }
37 
38 
39 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40 {
41 	return (struct static_call_key *)
42 		(((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
43 }
44 
45 /* These assume the key is word-aligned. */
46 static inline bool static_call_is_init(struct static_call_site *site)
47 {
48 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49 }
50 
51 static inline bool static_call_is_tail(struct static_call_site *site)
52 {
53 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
54 }
55 
56 static inline void static_call_set_init(struct static_call_site *site)
57 {
58 	site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
59 		    (long)&site->key;
60 }
61 
/*
 * sort() comparator: order static call sites by the address of the key
 * they reference.  This groups all sites of one key together.
 * Returns -1, 0 or 1.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *lhs = static_call_key(_a);
	const struct static_call_key *rhs = static_call_key(_b);

	return (lhs > rhs) - (lhs < rhs);
}
77 
78 static void static_call_site_swap(void *_a, void *_b, int size)
79 {
80 	long delta = (unsigned long)_a - (unsigned long)_b;
81 	struct static_call_site *a = _a;
82 	struct static_call_site *b = _b;
83 	struct static_call_site tmp = *a;
84 
85 	a->addr = b->addr  - delta;
86 	a->key  = b->key   - delta;
87 
88 	b->addr = tmp.addr + delta;
89 	b->key  = tmp.key  + delta;
90 }
91 
92 static inline void static_call_sort_entries(struct static_call_site *start,
93 					    struct static_call_site *stop)
94 {
95 	sort(start, stop - start, sizeof(struct static_call_site),
96 	     static_call_site_cmp, static_call_site_swap);
97 }
98 
99 static inline bool static_call_key_has_mods(struct static_call_key *key)
100 {
101 	return !(key->type & 1);
102 }
103 
104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105 {
106 	if (!static_call_key_has_mods(key))
107 		return NULL;
108 
109 	return key->mods;
110 }
111 
112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113 {
114 	if (static_call_key_has_mods(key))
115 		return NULL;
116 
117 	return (struct static_call_site *)(key->type & ~1);
118 }
119 
120 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121 {
122 	struct static_call_site *site, *stop;
123 	struct static_call_mod *site_mod, first;
124 
125 	cpus_read_lock();
126 	static_call_lock();
127 
128 	if (key->func == func)
129 		goto done;
130 
131 	key->func = func;
132 
133 	arch_static_call_transform(NULL, tramp, func, false);
134 
135 	/*
136 	 * If uninitialized, we'll not update the callsites, but they still
137 	 * point to the trampoline and we just patched that.
138 	 */
139 	if (WARN_ON_ONCE(!static_call_initialized))
140 		goto done;
141 
142 	first = (struct static_call_mod){
143 		.next = static_call_key_next(key),
144 		.mod = NULL,
145 		.sites = static_call_key_sites(key),
146 	};
147 
148 	for (site_mod = &first; site_mod; site_mod = site_mod->next) {
149 		struct module *mod = site_mod->mod;
150 
151 		if (!site_mod->sites) {
152 			/*
153 			 * This can happen if the static call key is defined in
154 			 * a module which doesn't use it.
155 			 *
156 			 * It also happens in the has_mods case, where the
157 			 * 'first' entry has no sites associated with it.
158 			 */
159 			continue;
160 		}
161 
162 		stop = __stop_static_call_sites;
163 
164 #ifdef CONFIG_MODULES
165 		if (mod) {
166 			stop = mod->static_call_sites +
167 			       mod->num_static_call_sites;
168 		}
169 #endif
170 
171 		for (site = site_mod->sites;
172 		     site < stop && static_call_key(site) == key; site++) {
173 			void *site_addr = static_call_addr(site);
174 
175 			if (static_call_is_init(site)) {
176 				/*
177 				 * Don't write to call sites which were in
178 				 * initmem and have since been freed.
179 				 */
180 				if (!mod && system_state >= SYSTEM_RUNNING)
181 					continue;
182 				if (mod && !within_module_init((unsigned long)site_addr, mod))
183 					continue;
184 			}
185 
186 			if (!kernel_text_address((unsigned long)site_addr)) {
187 				WARN_ONCE(1, "can't patch static call site at %pS",
188 					  site_addr);
189 				continue;
190 			}
191 
192 			arch_static_call_transform(site_addr, NULL, func,
193 				static_call_is_tail(site));
194 		}
195 	}
196 
197 done:
198 	static_call_unlock();
199 	cpus_read_unlock();
200 }
201 EXPORT_SYMBOL_GPL(__static_call_update);
202 
203 static int __static_call_init(struct module *mod,
204 			      struct static_call_site *start,
205 			      struct static_call_site *stop)
206 {
207 	struct static_call_site *site;
208 	struct static_call_key *key, *prev_key = NULL;
209 	struct static_call_mod *site_mod;
210 
211 	if (start == stop)
212 		return 0;
213 
214 	static_call_sort_entries(start, stop);
215 
216 	for (site = start; site < stop; site++) {
217 		void *site_addr = static_call_addr(site);
218 
219 		if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
220 		    (!mod && init_section_contains(site_addr, 1)))
221 			static_call_set_init(site);
222 
223 		key = static_call_key(site);
224 		if (key != prev_key) {
225 			prev_key = key;
226 
227 			/*
228 			 * For vmlinux (!mod) avoid the allocation by storing
229 			 * the sites pointer in the key itself. Also see
230 			 * __static_call_update()'s @first.
231 			 *
232 			 * This allows architectures (eg. x86) to call
233 			 * static_call_init() before memory allocation works.
234 			 */
235 			if (!mod) {
236 				key->sites = site;
237 				key->type |= 1;
238 				goto do_transform;
239 			}
240 
241 			site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
242 			if (!site_mod)
243 				return -ENOMEM;
244 
245 			/*
246 			 * When the key has a direct sites pointer, extract
247 			 * that into an explicit struct static_call_mod, so we
248 			 * can have a list of modules.
249 			 */
250 			if (static_call_key_sites(key)) {
251 				site_mod->mod = NULL;
252 				site_mod->next = NULL;
253 				site_mod->sites = static_call_key_sites(key);
254 
255 				key->mods = site_mod;
256 
257 				site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
258 				if (!site_mod)
259 					return -ENOMEM;
260 			}
261 
262 			site_mod->mod = mod;
263 			site_mod->sites = site;
264 			site_mod->next = static_call_key_next(key);
265 			key->mods = site_mod;
266 		}
267 
268 do_transform:
269 		arch_static_call_transform(site_addr, NULL, key->func,
270 				static_call_is_tail(site));
271 	}
272 
273 	return 0;
274 }
275 
276 static int addr_conflict(struct static_call_site *site, void *start, void *end)
277 {
278 	unsigned long addr = (unsigned long)static_call_addr(site);
279 
280 	if (addr <= (unsigned long)end &&
281 	    addr + CALL_INSN_SIZE > (unsigned long)start)
282 		return 1;
283 
284 	return 0;
285 }
286 
287 static int __static_call_text_reserved(struct static_call_site *iter_start,
288 				       struct static_call_site *iter_stop,
289 				       void *start, void *end)
290 {
291 	struct static_call_site *iter = iter_start;
292 
293 	while (iter < iter_stop) {
294 		if (addr_conflict(iter, start, end))
295 			return 1;
296 		iter++;
297 	}
298 
299 	return 0;
300 }
301 
302 #ifdef CONFIG_MODULES
303 
304 static int __static_call_mod_text_reserved(void *start, void *end)
305 {
306 	struct module *mod;
307 	int ret;
308 
309 	preempt_disable();
310 	mod = __module_text_address((unsigned long)start);
311 	WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
312 	if (!try_module_get(mod))
313 		mod = NULL;
314 	preempt_enable();
315 
316 	if (!mod)
317 		return 0;
318 
319 	ret = __static_call_text_reserved(mod->static_call_sites,
320 			mod->static_call_sites + mod->num_static_call_sites,
321 			start, end);
322 
323 	module_put(mod);
324 
325 	return ret;
326 }
327 
328 static unsigned long tramp_key_lookup(unsigned long addr)
329 {
330 	struct static_call_tramp_key *start = __start_static_call_tramp_key;
331 	struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
332 	struct static_call_tramp_key *tramp_key;
333 
334 	for (tramp_key = start; tramp_key != stop; tramp_key++) {
335 		unsigned long tramp;
336 
337 		tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
338 		if (tramp == addr)
339 			return (long)tramp_key->key + (long)&tramp_key->key;
340 	}
341 
342 	return 0;
343 }
344 
345 static int static_call_add_module(struct module *mod)
346 {
347 	struct static_call_site *start = mod->static_call_sites;
348 	struct static_call_site *stop = start + mod->num_static_call_sites;
349 	struct static_call_site *site;
350 
351 	for (site = start; site != stop; site++) {
352 		unsigned long s_key = (long)site->key + (long)&site->key;
353 		unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
354 		unsigned long key;
355 
356 		/*
357 		 * Is the key is exported, 'addr' points to the key, which
358 		 * means modules are allowed to call static_call_update() on
359 		 * it.
360 		 *
361 		 * Otherwise, the key isn't exported, and 'addr' points to the
362 		 * trampoline so we need to lookup the key.
363 		 *
364 		 * We go through this dance to prevent crazy modules from
365 		 * abusing sensitive static calls.
366 		 */
367 		if (!kernel_text_address(addr))
368 			continue;
369 
370 		key = tramp_key_lookup(addr);
371 		if (!key) {
372 			pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
373 				static_call_addr(site));
374 			return -EINVAL;
375 		}
376 
377 		key |= s_key & STATIC_CALL_SITE_FLAGS;
378 		site->key = key - (long)&site->key;
379 	}
380 
381 	return __static_call_init(mod, start, stop);
382 }
383 
384 static void static_call_del_module(struct module *mod)
385 {
386 	struct static_call_site *start = mod->static_call_sites;
387 	struct static_call_site *stop = mod->static_call_sites +
388 					mod->num_static_call_sites;
389 	struct static_call_key *key, *prev_key = NULL;
390 	struct static_call_mod *site_mod, **prev;
391 	struct static_call_site *site;
392 
393 	for (site = start; site < stop; site++) {
394 		key = static_call_key(site);
395 		if (key == prev_key)
396 			continue;
397 
398 		prev_key = key;
399 
400 		for (prev = &key->mods, site_mod = key->mods;
401 		     site_mod && site_mod->mod != mod;
402 		     prev = &site_mod->next, site_mod = site_mod->next)
403 			;
404 
405 		if (!site_mod)
406 			continue;
407 
408 		*prev = site_mod->next;
409 		kfree(site_mod);
410 	}
411 }
412 
413 static int static_call_module_notify(struct notifier_block *nb,
414 				     unsigned long val, void *data)
415 {
416 	struct module *mod = data;
417 	int ret = 0;
418 
419 	cpus_read_lock();
420 	static_call_lock();
421 
422 	switch (val) {
423 	case MODULE_STATE_COMING:
424 		ret = static_call_add_module(mod);
425 		if (ret) {
426 			WARN(1, "Failed to allocate memory for static calls");
427 			static_call_del_module(mod);
428 		}
429 		break;
430 	case MODULE_STATE_GOING:
431 		static_call_del_module(mod);
432 		break;
433 	}
434 
435 	static_call_unlock();
436 	cpus_read_unlock();
437 
438 	return notifier_from_errno(ret);
439 }
440 
441 static struct notifier_block static_call_module_nb = {
442 	.notifier_call = static_call_module_notify,
443 };
444 
445 #else
446 
/* Without CONFIG_MODULES no module text can contain a static call site. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	(void)start;
	(void)end;
	return 0;
}
451 
452 #endif /* CONFIG_MODULES */
453 
454 int static_call_text_reserved(void *start, void *end)
455 {
456 	int ret = __static_call_text_reserved(__start_static_call_sites,
457 			__stop_static_call_sites, start, end);
458 
459 	if (ret)
460 		return ret;
461 
462 	return __static_call_mod_text_reserved(start, end);
463 }
464 
465 int __init static_call_init(void)
466 {
467 	int ret;
468 
469 	if (static_call_initialized)
470 		return 0;
471 
472 	cpus_read_lock();
473 	static_call_lock();
474 	ret = __static_call_init(NULL, __start_static_call_sites,
475 				 __stop_static_call_sites);
476 	static_call_unlock();
477 	cpus_read_unlock();
478 
479 	if (ret) {
480 		pr_err("Failed to allocate memory for static_call!\n");
481 		BUG();
482 	}
483 
484 	static_call_initialized = true;
485 
486 #ifdef CONFIG_MODULES
487 	register_module_notifier(&static_call_module_nb);
488 #endif
489 	return 0;
490 }
491 early_initcall(static_call_init);
492 
/* Trivial target that always returns 0 (presumably a stock static call target). */
long __static_call_return0(void)
{
	return 0L;
}
497 
#ifdef CONFIG_STATIC_CALL_SELFTEST

static int func_a(int x)
{
	return x + 1;
}

static int func_b(int x)
{
	return x + 2;
}

DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps: optionally retarget sc_selftest to ->func (NULL keeps the
 * current target), then call it with ->val and expect ->expect.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ NULL,   2, 3 },
	{ func_b, 2, 4 },
	{ func_a, 2, 3 }
};

static int __init test_static_call_init(void)
{
	struct static_call_data *step;

	for (step = static_call_data;
	     step < static_call_data + ARRAY_SIZE(static_call_data); step++) {
		if (step->func)
			static_call_update(sc_selftest, step->func);

		WARN_ON(static_call(sc_selftest)(step->val) != step->expect);
	}

	return 0;
}
early_initcall(test_static_call_init);

#endif /* CONFIG_STATIC_CALL_SELFTEST */
540