xref: /openbmc/linux/arch/x86/kernel/static_call.c (revision afba8b0a)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/static_call.h>
3 #include <linux/memory.h>
4 #include <linux/bug.h>
5 #include <asm/text-patching.h>
6 
/*
 * Kind of instruction to place at a static_call location.
 *
 * The numeric values are load-bearing: __sc_insn() derives them
 * arithmetically as 2*tail + null, so they must stay 0..3 in this order.
 */
enum insn_type {
	CALL	= 0,	/* site call */
	NOP	= 1,	/* site cond-call */
	JMP	= 2,	/* tramp / site tail-call */
	RET	= 3,	/* tramp / site cond-tail-call */
};
13 
14 static void __ref __static_call_transform(void *insn, enum insn_type type, void *func)
15 {
16 	int size = CALL_INSN_SIZE;
17 	const void *code;
18 
19 	switch (type) {
20 	case CALL:
21 		code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
22 		break;
23 
24 	case NOP:
25 		code = ideal_nops[NOP_ATOMIC5];
26 		break;
27 
28 	case JMP:
29 		code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
30 		break;
31 
32 	case RET:
33 		code = text_gen_insn(RET_INSN_OPCODE, insn, func);
34 		size = RET_INSN_SIZE;
35 		break;
36 	}
37 
38 	if (memcmp(insn, code, size) == 0)
39 		return;
40 
41 	if (unlikely(system_state == SYSTEM_BOOTING))
42 		return text_poke_early(insn, code, size);
43 
44 	text_poke_bp(insn, code, size, NULL);
45 }
46 
47 static void __static_call_validate(void *insn, bool tail)
48 {
49 	u8 opcode = *(u8 *)insn;
50 
51 	if (tail) {
52 		if (opcode == JMP32_INSN_OPCODE ||
53 		    opcode == RET_INSN_OPCODE)
54 			return;
55 	} else {
56 		if (opcode == CALL_INSN_OPCODE ||
57 		    !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
58 			return;
59 	}
60 
61 	/*
62 	 * If we ever trigger this, our text is corrupt, we'll probably not live long.
63 	 */
64 	WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
65 }
66 
67 static inline enum insn_type __sc_insn(bool null, bool tail)
68 {
69 	/*
70 	 * Encode the following table without branches:
71 	 *
72 	 *	tail	null	insn
73 	 *	-----+-------+------
74 	 *	  0  |   0   |  CALL
75 	 *	  0  |   1   |  NOP
76 	 *	  1  |   0   |  JMP
77 	 *	  1  |   1   |  RET
78 	 */
79 	return 2*tail + null;
80 }
81 
82 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
83 {
84 	mutex_lock(&text_mutex);
85 
86 	if (tramp) {
87 		__static_call_validate(tramp, true);
88 		__static_call_transform(tramp, __sc_insn(!func, true), func);
89 	}
90 
91 	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
92 		__static_call_validate(site, tail);
93 		__static_call_transform(site, __sc_insn(!func, tail), func);
94 	}
95 
96 	mutex_unlock(&text_mutex);
97 }
98 EXPORT_SYMBOL_GPL(arch_static_call_transform);
99