xref: /openbmc/linux/arch/arm64/net/bpf_jit_comp.c (revision 5fb859f7)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for ARM64
4  *
5  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6  */
7 
8 #define pr_fmt(fmt) "bpf_jit: " fmt
9 
10 #include <linux/bitfield.h>
11 #include <linux/bpf.h>
12 #include <linux/filter.h>
13 #include <linux/printk.h>
14 #include <linux/slab.h>
15 
16 #include <asm/asm-extable.h>
17 #include <asm/byteorder.h>
18 #include <asm/cacheflush.h>
19 #include <asm/debug-monitors.h>
20 #include <asm/insn.h>
21 #include <asm/set_memory.h>
22 
23 #include "bpf_jit.h"
24 
25 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
26 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
27 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
28 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
29 #define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
30 
31 #define check_imm(bits, imm) do {				\
32 	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
33 	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
34 		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
35 			i, imm, imm);				\
36 		return -EINVAL;					\
37 	}							\
38 } while (0)
39 #define check_imm19(imm) check_imm(19, imm)
40 #define check_imm26(imm) check_imm(26, imm)
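
/*
 * Roughly: check_imm(bits, imm) accepts values in the range
 * [-(1 << bits), (1 << bits) - 1] and makes the enclosing function return
 * -EINVAL otherwise; e.g. check_imm19() lets -524288..524287 through, while
 * check_imm19(524288) trips the positive test (524288 >> 19 != 0).
 */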
41 
42 /* Map BPF registers to A64 registers */
43 static const int bpf2a64[] = {
44 	/* return value from in-kernel function, and exit value from eBPF */
45 	[BPF_REG_0] = A64_R(7),
46 	/* arguments from eBPF program to in-kernel function */
47 	[BPF_REG_1] = A64_R(0),
48 	[BPF_REG_2] = A64_R(1),
49 	[BPF_REG_3] = A64_R(2),
50 	[BPF_REG_4] = A64_R(3),
51 	[BPF_REG_5] = A64_R(4),
52 	/* callee-saved registers that the in-kernel function will preserve */
53 	[BPF_REG_6] = A64_R(19),
54 	[BPF_REG_7] = A64_R(20),
55 	[BPF_REG_8] = A64_R(21),
56 	[BPF_REG_9] = A64_R(22),
57 	/* read-only frame pointer to access stack */
58 	[BPF_REG_FP] = A64_R(25),
59 	/* temporary registers for BPF JIT */
60 	[TMP_REG_1] = A64_R(10),
61 	[TMP_REG_2] = A64_R(11),
62 	[TMP_REG_3] = A64_R(12),
63 	/* tail_call_cnt */
64 	[TCALL_CNT] = A64_R(26),
65 	/* temporary register for blinding constants */
66 	[BPF_REG_AX] = A64_R(9),
67 	[FP_BOTTOM] = A64_R(27),
68 };
69 
70 struct jit_ctx {
71 	const struct bpf_prog *prog;
72 	int idx;
73 	int epilogue_offset;
74 	int *offset;
75 	int exentry_idx;
76 	__le32 *image;
77 	u32 stack_size;
78 	int fpb_offset;
79 };
80 
81 static inline void emit(const u32 insn, struct jit_ctx *ctx)
82 {
83 	if (ctx->image != NULL)
84 		ctx->image[ctx->idx] = cpu_to_le32(insn);
85 
86 	ctx->idx++;
87 }
88 
89 static inline void emit_a64_mov_i(const int is64, const int reg,
90 				  const s32 val, struct jit_ctx *ctx)
91 {
92 	u16 hi = val >> 16;
93 	u16 lo = val & 0xffff;
94 
95 	if (hi & 0x8000) {
96 		if (hi == 0xffff) {
97 			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
98 		} else {
99 			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
100 			if (lo != 0xffff)
101 				emit(A64_MOVK(is64, reg, lo, 0), ctx);
102 		}
103 	} else {
104 		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
105 		if (hi)
106 			emit(A64_MOVK(is64, reg, hi, 16), ctx);
107 	}
108 }
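
/*
 * For example, emit_a64_mov_i(0, reg, 0x00050006, ctx) emits
 *   movz wN, #0x0006            // hi16 has its top bit clear
 *   movk wN, #0x0005, lsl #16
 * while a negative value such as -5 (0xfffffffb) takes the movn path:
 *   movn wN, #0x0004            // wN = ~0x0004 = 0xfffffffb
 */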
109 
110 static int i64_i16_blocks(const u64 val, bool inverse)
111 {
112 	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
113 	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
114 	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
115 	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
116 }
117 
118 static inline void emit_a64_mov_i64(const int reg, const u64 val,
119 				    struct jit_ctx *ctx)
120 {
121 	u64 nrm_tmp = val, rev_tmp = ~val;
122 	bool inverse;
123 	int shift;
124 
125 	if (!(nrm_tmp >> 32))
126 		return emit_a64_mov_i(0, reg, (u32)val, ctx);
127 
128 	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
129 	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
130 					  (fls64(nrm_tmp) - 1)), 16), 0);
131 	if (inverse)
132 		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
133 	else
134 		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
135 	shift -= 16;
136 	while (shift >= 0) {
137 		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
138 			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
139 		shift -= 16;
140 	}
141 }
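
/*
 * Illustration: for val = 0x0000123400005678 there are two non-zero 16-bit
 * chunks and four non-0xffff ones, so the movz form wins:
 *   movz xN, #0x1234, lsl #32
 *   movk xN, #0x5678
 * (the all-zero chunk at bits 16-31 is skipped). A mostly-ones value such
 * as 0xffffffffffff0000 is instead built from a single movn.
 */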
142 
143 /*
144  * Kernel addresses in the vmalloc space use at most 48 bits, and the
145  * remaining (upper) bits are guaranteed to be all ones. So we can compose the address
146  * with a fixed length movn/movk/movk sequence.
147  */
148 static inline void emit_addr_mov_i64(const int reg, const u64 val,
149 				     struct jit_ctx *ctx)
150 {
151 	u64 tmp = val;
152 	int shift = 0;
153 
154 	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
155 	while (shift < 32) {
156 		tmp >>= 16;
157 		shift += 16;
158 		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
159 	}
160 }
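
/*
 * E.g. for a vmalloc-space address like 0xffff800010001234 this emits
 *   movn xN, #0xedcb            // xN = 0xffffffffffff1234
 *   movk xN, #0x1000, lsl #16
 *   movk xN, #0x8000, lsl #32
 * leaving bits 63-48 as the all-ones pattern produced by the movn, which is
 * exactly what such addresses require.
 */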
161 
162 static inline int bpf2a64_offset(int bpf_insn, int off,
163 				 const struct jit_ctx *ctx)
164 {
165 	/* BPF JMP offset is relative to the next instruction */
166 	bpf_insn++;
167 	/*
168 	 * arm64 branch instructions, on the other hand, encode the offset
169 	 * from the branch itself, so we must subtract 1 from the
170 	 * instruction offset.
171 	 */
172 	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
173 }
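
/*
 * Worked example: say BPF insn #5 is a conditional jump with off = 3,
 * ctx->offset[6] = 20 and ctx->offset[9] = 40. The A64 branch is the last
 * instruction emitted for insn #5, i.e. at index 19, and the target is the
 * first instruction of insn #9 at index 40, so this returns
 * 40 - (20 - 1) = 21 A64 instructions forward.
 */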
174 
175 static void jit_fill_hole(void *area, unsigned int size)
176 {
177 	__le32 *ptr;
178 	/* We are guaranteed to have aligned memory. */
179 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
180 		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
181 }
182 
183 static inline int epilogue_offset(const struct jit_ctx *ctx)
184 {
185 	int to = ctx->epilogue_offset;
186 	int from = ctx->idx;
187 
188 	return to - from;
189 }
190 
191 static bool is_addsub_imm(u32 imm)
192 {
193 	/* Either imm12 or shifted imm12. */
194 	return !(imm & ~0xfff) || !(imm & ~0xfff000);
195 }
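
/*
 * E.g. 0xabc and 0xabc000 can be encoded directly in an add/sub instruction
 * (imm12, optionally shifted left by 12), whereas 0xabc001 or 0x1000000
 * cannot, so callers fall back to loading the constant into a temporary
 * register first.
 */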
196 
197 /*
198  * There are 3 types of AArch64 LDR/STR (immediate) instruction:
199  * Post-index, Pre-index, Unsigned offset.
200  *
201  * For BPF ldr/str, the "unsigned offset" type is sufficient.
202  *
203  * "Unsigned offset" type LDR(immediate) format:
204  *
205  *    3                   2                   1                   0
206  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
207  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
208  * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
209  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
210  * scale
211  *
212  * "Unsigned offset" type STR(immediate) format:
213  *    3                   2                   1                   0
214  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
215  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
216  * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
217  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
218  * scale
219  *
220  * The offset is calculated from imm12 and scale in the following way:
221  *
222  * offset = (u64)imm12 << scale
223  */
224 static bool is_lsi_offset(int offset, int scale)
225 {
226 	if (offset < 0)
227 		return false;
228 
229 	if (offset > (0xFFF << scale))
230 		return false;
231 
232 	if (offset & ((1 << scale) - 1))
233 		return false;
234 
235 	return true;
236 }
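
/*
 * E.g. with scale = 3 (BPF_DW): the offset must be 0..32760 and 8-byte
 * aligned, so 24 is encodable (imm12 = 3), while -8 or 20 are not and the
 * JIT falls back to the register-offset form of the load/store.
 */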
237 
238 /* Tail call offset to jump into */
239 #if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) || \
240 	IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)
241 #define PROLOGUE_OFFSET 9
242 #else
243 #define PROLOGUE_OFFSET 8
244 #endif
245 
246 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
247 {
248 	const struct bpf_prog *prog = ctx->prog;
249 	const bool is_main_prog = prog->aux->func_idx == 0;
250 	const u8 r6 = bpf2a64[BPF_REG_6];
251 	const u8 r7 = bpf2a64[BPF_REG_7];
252 	const u8 r8 = bpf2a64[BPF_REG_8];
253 	const u8 r9 = bpf2a64[BPF_REG_9];
254 	const u8 fp = bpf2a64[BPF_REG_FP];
255 	const u8 tcc = bpf2a64[TCALL_CNT];
256 	const u8 fpb = bpf2a64[FP_BOTTOM];
257 	const int idx0 = ctx->idx;
258 	int cur_offset;
259 
260 	/*
261 	 * BPF prog stack layout
262 	 *
263 	 *                         high
264 	 * original A64_SP =>   0:+-----+ BPF prologue
265 	 *                        |FP/LR|
266 	 * current A64_FP =>  -16:+-----+
267 	 *                        | ... | callee saved registers
268 	 * BPF fp register => -80:+-----+ <= (BPF_FP)
269 	 *                        |     |
270 	 *                        | ... | BPF prog stack
271 	 *                        |     |
272 	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
273 	 *                        |RSVD | padding
274 	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
275 	 *                        |     |
276 	 *                        | ... | Function call stack
277 	 *                        |     |
278 	 *                        +-----+
279 	 *                          low
280 	 *
281 	 */
282 
283 	/* Sign lr */
284 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
285 		emit(A64_PACIASP, ctx);
286 	/* BTI landing pad */
287 	else if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
288 		emit(A64_BTI_C, ctx);
289 
290 	/* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
291 	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
292 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
293 
294 	/* Save callee-saved registers */
295 	emit(A64_PUSH(r6, r7, A64_SP), ctx);
296 	emit(A64_PUSH(r8, r9, A64_SP), ctx);
297 	emit(A64_PUSH(fp, tcc, A64_SP), ctx);
298 	emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
299 
300 	/* Set up BPF prog stack base register */
301 	emit(A64_MOV(1, fp, A64_SP), ctx);
302 
303 	if (!ebpf_from_cbpf && is_main_prog) {
304 		/* Initialize tail_call_cnt */
305 		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
306 
307 		cur_offset = ctx->idx - idx0;
308 		if (cur_offset != PROLOGUE_OFFSET) {
309 			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
310 				    cur_offset, PROLOGUE_OFFSET);
311 			return -1;
312 		}
313 
314 		/* BTI landing pad for the tail call, done with a BR */
315 		if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
316 			emit(A64_BTI_J, ctx);
317 	}
318 
319 	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
320 
321 	/* Stack size must be a multiple of 16 bytes */
322 	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
323 
324 	/* Set up function call stack */
325 	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
326 	return 0;
327 }
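
/*
 * For instance, a program with prog->aux->stack_depth == 20 gets
 * ctx->stack_size = 32: after the five 16-byte pushes above, BPF_FP (x25)
 * sits 80 bytes below the original SP, and SP is dropped a further 32 bytes
 * so the program's stack area is the 32 bytes just below BPF_FP.
 */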
328 
329 static int out_offset = -1; /* initialized on the first pass of build_body() */
330 static int emit_bpf_tail_call(struct jit_ctx *ctx)
331 {
332 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
333 	const u8 r2 = bpf2a64[BPF_REG_2];
334 	const u8 r3 = bpf2a64[BPF_REG_3];
335 
336 	const u8 tmp = bpf2a64[TMP_REG_1];
337 	const u8 prg = bpf2a64[TMP_REG_2];
338 	const u8 tcc = bpf2a64[TCALL_CNT];
339 	const int idx0 = ctx->idx;
340 #define cur_offset (ctx->idx - idx0)
341 #define jmp_offset (out_offset - (cur_offset))
342 	size_t off;
343 
344 	/* if (index >= array->map.max_entries)
345 	 *     goto out;
346 	 */
347 	off = offsetof(struct bpf_array, map.max_entries);
348 	emit_a64_mov_i64(tmp, off, ctx);
349 	emit(A64_LDR32(tmp, r2, tmp), ctx);
350 	emit(A64_MOV(0, r3, r3), ctx);
351 	emit(A64_CMP(0, r3, tmp), ctx);
352 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
353 
354 	/*
355 	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
356 	 *     goto out;
357 	 * tail_call_cnt++;
358 	 */
359 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
360 	emit(A64_CMP(1, tcc, tmp), ctx);
361 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
362 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
363 
364 	/* prog = array->ptrs[index];
365 	 * if (prog == NULL)
366 	 *     goto out;
367 	 */
368 	off = offsetof(struct bpf_array, ptrs);
369 	emit_a64_mov_i64(tmp, off, ctx);
370 	emit(A64_ADD(1, tmp, r2, tmp), ctx);
371 	emit(A64_LSL(1, prg, r3, 3), ctx);
372 	emit(A64_LDR64(prg, tmp, prg), ctx);
373 	emit(A64_CBZ(1, prg, jmp_offset), ctx);
374 
375 	/* goto *(prog->bpf_func + prologue_offset); */
376 	off = offsetof(struct bpf_prog, bpf_func);
377 	emit_a64_mov_i64(tmp, off, ctx);
378 	emit(A64_LDR64(tmp, prg, tmp), ctx);
379 	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
380 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
381 	emit(A64_BR(tmp), ctx);
382 
383 	/* out: */
384 	if (out_offset == -1)
385 		out_offset = cur_offset;
386 	if (cur_offset != out_offset) {
387 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
388 			    cur_offset, out_offset);
389 		return -1;
390 	}
391 	return 0;
392 #undef cur_offset
393 #undef jmp_offset
394 }
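
/*
 * Note on the two-pass trick above: on the first (image == NULL) pass
 * out_offset is still -1, so the early branches are emitted with a bogus
 * jmp_offset and only the instruction count matters; once out_offset has
 * been recorded, later passes emit the real forward branches to "out" and
 * the final check verifies that the layout did not move.
 */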
395 
396 #ifdef CONFIG_ARM64_LSE_ATOMICS
397 static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
398 {
399 	const u8 code = insn->code;
400 	const u8 dst = bpf2a64[insn->dst_reg];
401 	const u8 src = bpf2a64[insn->src_reg];
402 	const u8 tmp = bpf2a64[TMP_REG_1];
403 	const u8 tmp2 = bpf2a64[TMP_REG_2];
404 	const bool isdw = BPF_SIZE(code) == BPF_DW;
405 	const s16 off = insn->off;
406 	u8 reg;
407 
408 	if (!off) {
409 		reg = dst;
410 	} else {
411 		emit_a64_mov_i(1, tmp, off, ctx);
412 		emit(A64_ADD(1, tmp, tmp, dst), ctx);
413 		reg = tmp;
414 	}
415 
416 	switch (insn->imm) {
417 	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
418 	case BPF_ADD:
419 		emit(A64_STADD(isdw, reg, src), ctx);
420 		break;
421 	case BPF_AND:
422 		emit(A64_MVN(isdw, tmp2, src), ctx);
423 		emit(A64_STCLR(isdw, reg, tmp2), ctx);
424 		break;
425 	case BPF_OR:
426 		emit(A64_STSET(isdw, reg, src), ctx);
427 		break;
428 	case BPF_XOR:
429 		emit(A64_STEOR(isdw, reg, src), ctx);
430 		break;
431 	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
432 	case BPF_ADD | BPF_FETCH:
433 		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
434 		break;
435 	case BPF_AND | BPF_FETCH:
436 		emit(A64_MVN(isdw, tmp2, src), ctx);
437 		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
438 		break;
439 	case BPF_OR | BPF_FETCH:
440 		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
441 		break;
442 	case BPF_XOR | BPF_FETCH:
443 		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
444 		break;
445 	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
446 	case BPF_XCHG:
447 		emit(A64_SWPAL(isdw, src, reg, src), ctx);
448 		break;
449 	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
450 	case BPF_CMPXCHG:
451 		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
452 		break;
453 	default:
454 		pr_err_once("unknown atomic op code %02x\n", insn->imm);
455 		return -EINVAL;
456 	}
457 
458 	return 0;
459 }
460 #else
461 static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
462 {
463 	return -EINVAL;
464 }
465 #endif
466 
467 static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
468 {
469 	const u8 code = insn->code;
470 	const u8 dst = bpf2a64[insn->dst_reg];
471 	const u8 src = bpf2a64[insn->src_reg];
472 	const u8 tmp = bpf2a64[TMP_REG_1];
473 	const u8 tmp2 = bpf2a64[TMP_REG_2];
474 	const u8 tmp3 = bpf2a64[TMP_REG_3];
475 	const int i = insn - ctx->prog->insnsi;
476 	const s32 imm = insn->imm;
477 	const s16 off = insn->off;
478 	const bool isdw = BPF_SIZE(code) == BPF_DW;
479 	u8 reg;
480 	s32 jmp_offset;
481 
482 	if (!off) {
483 		reg = dst;
484 	} else {
485 		emit_a64_mov_i(1, tmp, off, ctx);
486 		emit(A64_ADD(1, tmp, tmp, dst), ctx);
487 		reg = tmp;
488 	}
489 
490 	if (imm == BPF_ADD || imm == BPF_AND ||
491 	    imm == BPF_OR || imm == BPF_XOR) {
492 		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
493 		emit(A64_LDXR(isdw, tmp2, reg), ctx);
494 		if (imm == BPF_ADD)
495 			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
496 		else if (imm == BPF_AND)
497 			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
498 		else if (imm == BPF_OR)
499 			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
500 		else
501 			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
502 		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
503 		jmp_offset = -3;
504 		check_imm19(jmp_offset);
505 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
506 	} else if (imm == (BPF_ADD | BPF_FETCH) ||
507 		   imm == (BPF_AND | BPF_FETCH) ||
508 		   imm == (BPF_OR | BPF_FETCH) ||
509 		   imm == (BPF_XOR | BPF_FETCH)) {
510 		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
511 		const u8 ax = bpf2a64[BPF_REG_AX];
512 
513 		emit(A64_MOV(isdw, ax, src), ctx);
514 		emit(A64_LDXR(isdw, src, reg), ctx);
515 		if (imm == (BPF_ADD | BPF_FETCH))
516 			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
517 		else if (imm == (BPF_AND | BPF_FETCH))
518 			emit(A64_AND(isdw, tmp2, src, ax), ctx);
519 		else if (imm == (BPF_OR | BPF_FETCH))
520 			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
521 		else
522 			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
523 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
524 		jmp_offset = -3;
525 		check_imm19(jmp_offset);
526 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
527 		emit(A64_DMB_ISH, ctx);
528 	} else if (imm == BPF_XCHG) {
529 		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
530 		emit(A64_MOV(isdw, tmp2, src), ctx);
531 		emit(A64_LDXR(isdw, src, reg), ctx);
532 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
533 		jmp_offset = -2;
534 		check_imm19(jmp_offset);
535 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
536 		emit(A64_DMB_ISH, ctx);
537 	} else if (imm == BPF_CMPXCHG) {
538 		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
539 		const u8 r0 = bpf2a64[BPF_REG_0];
540 
541 		emit(A64_MOV(isdw, tmp2, r0), ctx);
542 		emit(A64_LDXR(isdw, r0, reg), ctx);
543 		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
544 		jmp_offset = 4;
545 		check_imm19(jmp_offset);
546 		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
547 		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
548 		jmp_offset = -4;
549 		check_imm19(jmp_offset);
550 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
551 		emit(A64_DMB_ISH, ctx);
552 	} else {
553 		pr_err_once("unknown atomic op code %02x\n", imm);
554 		return -EINVAL;
555 	}
556 
557 	return 0;
558 }
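
/*
 * Illustration of the two code paths for BPF_ADD on a u64 (no BPF_FETCH):
 * with LSE, emit_lse_atomic() produces a single "stadd xSRC, [xADDR]",
 * while the LL/SC fallback above emits roughly
 *   ldxr  x11, [xADDR]
 *   add   x11, x11, xSRC
 *   stxr  w12, x11, [xADDR]
 *   cbnz  w12, <back to the ldxr>
 * where xADDR is dst itself, or tmp when a non-zero BPF offset had to be
 * added first.
 */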
559 
560 static void build_epilogue(struct jit_ctx *ctx)
561 {
562 	const u8 r0 = bpf2a64[BPF_REG_0];
563 	const u8 r6 = bpf2a64[BPF_REG_6];
564 	const u8 r7 = bpf2a64[BPF_REG_7];
565 	const u8 r8 = bpf2a64[BPF_REG_8];
566 	const u8 r9 = bpf2a64[BPF_REG_9];
567 	const u8 fp = bpf2a64[BPF_REG_FP];
568 	const u8 fpb = bpf2a64[FP_BOTTOM];
569 
570 	/* We're done with BPF stack */
571 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
572 
573 	/* Restore x27 and x28 */
574 	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
575 	/* Restore fp (x25) and x26 */
576 	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
577 
578 	/* Restore callee-saved registers */
579 	emit(A64_POP(r8, r9, A64_SP), ctx);
580 	emit(A64_POP(r6, r7, A64_SP), ctx);
581 
582 	/* Restore FP/LR registers */
583 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
584 
585 	/* Set return value */
586 	emit(A64_MOV(1, A64_R(0), r0), ctx);
587 
588 	/* Authenticate lr */
589 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
590 		emit(A64_AUTIASP, ctx);
591 
592 	emit(A64_RET(A64_LR), ctx);
593 }
594 
595 #define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
596 #define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
597 
598 bool ex_handler_bpf(const struct exception_table_entry *ex,
599 		    struct pt_regs *regs)
600 {
601 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
602 	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
603 
604 	regs->regs[dst_reg] = 0;
605 	regs->pc = (unsigned long)&ex->fixup - offset;
606 	return true;
607 }
608 
609 /* For accesses to BTF pointers, add an entry to the exception table */
610 static int add_exception_handler(const struct bpf_insn *insn,
611 				 struct jit_ctx *ctx,
612 				 int dst_reg)
613 {
614 	off_t offset;
615 	unsigned long pc;
616 	struct exception_table_entry *ex;
617 
618 	if (!ctx->image)
619 		/* First pass */
620 		return 0;
621 
622 	if (BPF_MODE(insn->code) != BPF_PROBE_MEM)
623 		return 0;
624 
625 	if (!ctx->prog->aux->extable ||
626 	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
627 		return -EINVAL;
628 
629 	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
630 	pc = (unsigned long)&ctx->image[ctx->idx - 1];
631 
632 	offset = pc - (long)&ex->insn;
633 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
634 		return -ERANGE;
635 	ex->insn = offset;
636 
637 	/*
638 	 * Since the extable follows the program, the fixup offset is always
639 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
640 	 * to keep things simple, and put the destination register in the upper
641 	 * bits. We don't need to worry about buildtime or runtime sort
642 	 * modifying the upper bits because the table is already sorted, and
643 	 * isn't part of the main exception table.
644 	 */
645 	offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
646 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
647 		return -ERANGE;
648 
649 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
650 		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
651 
652 	ex->type = EX_TYPE_BPF;
653 
654 	ctx->exentry_idx++;
655 	return 0;
656 }
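
/*
 * Putting the two pieces together: for a BPF_PROBE_MEM load into, say, x7,
 * ex->fixup stores the (positive) distance from the instruction following
 * the load to &ex->fixup in bits 26:0 and the A64 destination register
 * number in bits 31:27. If the load faults, ex_handler_bpf() recomputes
 * that "next instruction" address, zeroes the destination register and
 * resumes there, so the load reads as 0 instead of oopsing.
 */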
657 
658 /* JITs an eBPF instruction.
659  * Returns:
660  * 0  - successfully JITed an 8-byte eBPF instruction.
661  * >0 - successfully JITed a 16-byte eBPF instruction.
662  * <0 - failed to JIT.
663  */
664 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
665 		      bool extra_pass)
666 {
667 	const u8 code = insn->code;
668 	const u8 dst = bpf2a64[insn->dst_reg];
669 	const u8 src = bpf2a64[insn->src_reg];
670 	const u8 tmp = bpf2a64[TMP_REG_1];
671 	const u8 tmp2 = bpf2a64[TMP_REG_2];
672 	const u8 fp = bpf2a64[BPF_REG_FP];
673 	const u8 fpb = bpf2a64[FP_BOTTOM];
674 	const s16 off = insn->off;
675 	const s32 imm = insn->imm;
676 	const int i = insn - ctx->prog->insnsi;
677 	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
678 			  BPF_CLASS(code) == BPF_JMP;
679 	u8 jmp_cond;
680 	s32 jmp_offset;
681 	u32 a64_insn;
682 	u8 src_adj;
683 	u8 dst_adj;
684 	int off_adj;
685 	int ret;
686 
687 	switch (code) {
688 	/* dst = src */
689 	case BPF_ALU | BPF_MOV | BPF_X:
690 	case BPF_ALU64 | BPF_MOV | BPF_X:
691 		emit(A64_MOV(is64, dst, src), ctx);
692 		break;
693 	/* dst = dst OP src */
694 	case BPF_ALU | BPF_ADD | BPF_X:
695 	case BPF_ALU64 | BPF_ADD | BPF_X:
696 		emit(A64_ADD(is64, dst, dst, src), ctx);
697 		break;
698 	case BPF_ALU | BPF_SUB | BPF_X:
699 	case BPF_ALU64 | BPF_SUB | BPF_X:
700 		emit(A64_SUB(is64, dst, dst, src), ctx);
701 		break;
702 	case BPF_ALU | BPF_AND | BPF_X:
703 	case BPF_ALU64 | BPF_AND | BPF_X:
704 		emit(A64_AND(is64, dst, dst, src), ctx);
705 		break;
706 	case BPF_ALU | BPF_OR | BPF_X:
707 	case BPF_ALU64 | BPF_OR | BPF_X:
708 		emit(A64_ORR(is64, dst, dst, src), ctx);
709 		break;
710 	case BPF_ALU | BPF_XOR | BPF_X:
711 	case BPF_ALU64 | BPF_XOR | BPF_X:
712 		emit(A64_EOR(is64, dst, dst, src), ctx);
713 		break;
714 	case BPF_ALU | BPF_MUL | BPF_X:
715 	case BPF_ALU64 | BPF_MUL | BPF_X:
716 		emit(A64_MUL(is64, dst, dst, src), ctx);
717 		break;
718 	case BPF_ALU | BPF_DIV | BPF_X:
719 	case BPF_ALU64 | BPF_DIV | BPF_X:
720 		emit(A64_UDIV(is64, dst, dst, src), ctx);
721 		break;
722 	case BPF_ALU | BPF_MOD | BPF_X:
723 	case BPF_ALU64 | BPF_MOD | BPF_X:
724 		emit(A64_UDIV(is64, tmp, dst, src), ctx);
725 		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
726 		break;
727 	case BPF_ALU | BPF_LSH | BPF_X:
728 	case BPF_ALU64 | BPF_LSH | BPF_X:
729 		emit(A64_LSLV(is64, dst, dst, src), ctx);
730 		break;
731 	case BPF_ALU | BPF_RSH | BPF_X:
732 	case BPF_ALU64 | BPF_RSH | BPF_X:
733 		emit(A64_LSRV(is64, dst, dst, src), ctx);
734 		break;
735 	case BPF_ALU | BPF_ARSH | BPF_X:
736 	case BPF_ALU64 | BPF_ARSH | BPF_X:
737 		emit(A64_ASRV(is64, dst, dst, src), ctx);
738 		break;
739 	/* dst = -dst */
740 	case BPF_ALU | BPF_NEG:
741 	case BPF_ALU64 | BPF_NEG:
742 		emit(A64_NEG(is64, dst, dst), ctx);
743 		break;
744 	/* dst = BSWAP##imm(dst) */
745 	case BPF_ALU | BPF_END | BPF_FROM_LE:
746 	case BPF_ALU | BPF_END | BPF_FROM_BE:
747 #ifdef CONFIG_CPU_BIG_ENDIAN
748 		if (BPF_SRC(code) == BPF_FROM_BE)
749 			goto emit_bswap_uxt;
750 #else /* !CONFIG_CPU_BIG_ENDIAN */
751 		if (BPF_SRC(code) == BPF_FROM_LE)
752 			goto emit_bswap_uxt;
753 #endif
754 		switch (imm) {
755 		case 16:
756 			emit(A64_REV16(is64, dst, dst), ctx);
757 			/* zero-extend 16 bits into 64 bits */
758 			emit(A64_UXTH(is64, dst, dst), ctx);
759 			break;
760 		case 32:
761 			emit(A64_REV32(is64, dst, dst), ctx);
762 			/* upper 32 bits already cleared */
763 			break;
764 		case 64:
765 			emit(A64_REV64(dst, dst), ctx);
766 			break;
767 		}
768 		break;
769 emit_bswap_uxt:
770 		switch (imm) {
771 		case 16:
772 			/* zero-extend 16 bits into 64 bits */
773 			emit(A64_UXTH(is64, dst, dst), ctx);
774 			break;
775 		case 32:
776 			/* zero-extend 32 bits into 64 bits */
777 			emit(A64_UXTW(is64, dst, dst), ctx);
778 			break;
779 		case 64:
780 			/* nop */
781 			break;
782 		}
783 		break;
784 	/* dst = imm */
785 	case BPF_ALU | BPF_MOV | BPF_K:
786 	case BPF_ALU64 | BPF_MOV | BPF_K:
787 		emit_a64_mov_i(is64, dst, imm, ctx);
788 		break;
789 	/* dst = dst OP imm */
790 	case BPF_ALU | BPF_ADD | BPF_K:
791 	case BPF_ALU64 | BPF_ADD | BPF_K:
792 		if (is_addsub_imm(imm)) {
793 			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
794 		} else if (is_addsub_imm(-imm)) {
795 			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
796 		} else {
797 			emit_a64_mov_i(is64, tmp, imm, ctx);
798 			emit(A64_ADD(is64, dst, dst, tmp), ctx);
799 		}
800 		break;
801 	case BPF_ALU | BPF_SUB | BPF_K:
802 	case BPF_ALU64 | BPF_SUB | BPF_K:
803 		if (is_addsub_imm(imm)) {
804 			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
805 		} else if (is_addsub_imm(-imm)) {
806 			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
807 		} else {
808 			emit_a64_mov_i(is64, tmp, imm, ctx);
809 			emit(A64_SUB(is64, dst, dst, tmp), ctx);
810 		}
811 		break;
812 	case BPF_ALU | BPF_AND | BPF_K:
813 	case BPF_ALU64 | BPF_AND | BPF_K:
814 		a64_insn = A64_AND_I(is64, dst, dst, imm);
815 		if (a64_insn != AARCH64_BREAK_FAULT) {
816 			emit(a64_insn, ctx);
817 		} else {
818 			emit_a64_mov_i(is64, tmp, imm, ctx);
819 			emit(A64_AND(is64, dst, dst, tmp), ctx);
820 		}
821 		break;
822 	case BPF_ALU | BPF_OR | BPF_K:
823 	case BPF_ALU64 | BPF_OR | BPF_K:
824 		a64_insn = A64_ORR_I(is64, dst, dst, imm);
825 		if (a64_insn != AARCH64_BREAK_FAULT) {
826 			emit(a64_insn, ctx);
827 		} else {
828 			emit_a64_mov_i(is64, tmp, imm, ctx);
829 			emit(A64_ORR(is64, dst, dst, tmp), ctx);
830 		}
831 		break;
832 	case BPF_ALU | BPF_XOR | BPF_K:
833 	case BPF_ALU64 | BPF_XOR | BPF_K:
834 		a64_insn = A64_EOR_I(is64, dst, dst, imm);
835 		if (a64_insn != AARCH64_BREAK_FAULT) {
836 			emit(a64_insn, ctx);
837 		} else {
838 			emit_a64_mov_i(is64, tmp, imm, ctx);
839 			emit(A64_EOR(is64, dst, dst, tmp), ctx);
840 		}
841 		break;
842 	case BPF_ALU | BPF_MUL | BPF_K:
843 	case BPF_ALU64 | BPF_MUL | BPF_K:
844 		emit_a64_mov_i(is64, tmp, imm, ctx);
845 		emit(A64_MUL(is64, dst, dst, tmp), ctx);
846 		break;
847 	case BPF_ALU | BPF_DIV | BPF_K:
848 	case BPF_ALU64 | BPF_DIV | BPF_K:
849 		emit_a64_mov_i(is64, tmp, imm, ctx);
850 		emit(A64_UDIV(is64, dst, dst, tmp), ctx);
851 		break;
852 	case BPF_ALU | BPF_MOD | BPF_K:
853 	case BPF_ALU64 | BPF_MOD | BPF_K:
854 		emit_a64_mov_i(is64, tmp2, imm, ctx);
855 		emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
856 		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
857 		break;
858 	case BPF_ALU | BPF_LSH | BPF_K:
859 	case BPF_ALU64 | BPF_LSH | BPF_K:
860 		emit(A64_LSL(is64, dst, dst, imm), ctx);
861 		break;
862 	case BPF_ALU | BPF_RSH | BPF_K:
863 	case BPF_ALU64 | BPF_RSH | BPF_K:
864 		emit(A64_LSR(is64, dst, dst, imm), ctx);
865 		break;
866 	case BPF_ALU | BPF_ARSH | BPF_K:
867 	case BPF_ALU64 | BPF_ARSH | BPF_K:
868 		emit(A64_ASR(is64, dst, dst, imm), ctx);
869 		break;
870 
871 	/* JUMP off */
872 	case BPF_JMP | BPF_JA:
873 		jmp_offset = bpf2a64_offset(i, off, ctx);
874 		check_imm26(jmp_offset);
875 		emit(A64_B(jmp_offset), ctx);
876 		break;
877 	/* IF (dst COND src) JUMP off */
878 	case BPF_JMP | BPF_JEQ | BPF_X:
879 	case BPF_JMP | BPF_JGT | BPF_X:
880 	case BPF_JMP | BPF_JLT | BPF_X:
881 	case BPF_JMP | BPF_JGE | BPF_X:
882 	case BPF_JMP | BPF_JLE | BPF_X:
883 	case BPF_JMP | BPF_JNE | BPF_X:
884 	case BPF_JMP | BPF_JSGT | BPF_X:
885 	case BPF_JMP | BPF_JSLT | BPF_X:
886 	case BPF_JMP | BPF_JSGE | BPF_X:
887 	case BPF_JMP | BPF_JSLE | BPF_X:
888 	case BPF_JMP32 | BPF_JEQ | BPF_X:
889 	case BPF_JMP32 | BPF_JGT | BPF_X:
890 	case BPF_JMP32 | BPF_JLT | BPF_X:
891 	case BPF_JMP32 | BPF_JGE | BPF_X:
892 	case BPF_JMP32 | BPF_JLE | BPF_X:
893 	case BPF_JMP32 | BPF_JNE | BPF_X:
894 	case BPF_JMP32 | BPF_JSGT | BPF_X:
895 	case BPF_JMP32 | BPF_JSLT | BPF_X:
896 	case BPF_JMP32 | BPF_JSGE | BPF_X:
897 	case BPF_JMP32 | BPF_JSLE | BPF_X:
898 		emit(A64_CMP(is64, dst, src), ctx);
899 emit_cond_jmp:
900 		jmp_offset = bpf2a64_offset(i, off, ctx);
901 		check_imm19(jmp_offset);
902 		switch (BPF_OP(code)) {
903 		case BPF_JEQ:
904 			jmp_cond = A64_COND_EQ;
905 			break;
906 		case BPF_JGT:
907 			jmp_cond = A64_COND_HI;
908 			break;
909 		case BPF_JLT:
910 			jmp_cond = A64_COND_CC;
911 			break;
912 		case BPF_JGE:
913 			jmp_cond = A64_COND_CS;
914 			break;
915 		case BPF_JLE:
916 			jmp_cond = A64_COND_LS;
917 			break;
918 		case BPF_JSET:
919 		case BPF_JNE:
920 			jmp_cond = A64_COND_NE;
921 			break;
922 		case BPF_JSGT:
923 			jmp_cond = A64_COND_GT;
924 			break;
925 		case BPF_JSLT:
926 			jmp_cond = A64_COND_LT;
927 			break;
928 		case BPF_JSGE:
929 			jmp_cond = A64_COND_GE;
930 			break;
931 		case BPF_JSLE:
932 			jmp_cond = A64_COND_LE;
933 			break;
934 		default:
935 			return -EFAULT;
936 		}
937 		emit(A64_B_(jmp_cond, jmp_offset), ctx);
938 		break;
939 	case BPF_JMP | BPF_JSET | BPF_X:
940 	case BPF_JMP32 | BPF_JSET | BPF_X:
941 		emit(A64_TST(is64, dst, src), ctx);
942 		goto emit_cond_jmp;
943 	/* IF (dst COND imm) JUMP off */
944 	case BPF_JMP | BPF_JEQ | BPF_K:
945 	case BPF_JMP | BPF_JGT | BPF_K:
946 	case BPF_JMP | BPF_JLT | BPF_K:
947 	case BPF_JMP | BPF_JGE | BPF_K:
948 	case BPF_JMP | BPF_JLE | BPF_K:
949 	case BPF_JMP | BPF_JNE | BPF_K:
950 	case BPF_JMP | BPF_JSGT | BPF_K:
951 	case BPF_JMP | BPF_JSLT | BPF_K:
952 	case BPF_JMP | BPF_JSGE | BPF_K:
953 	case BPF_JMP | BPF_JSLE | BPF_K:
954 	case BPF_JMP32 | BPF_JEQ | BPF_K:
955 	case BPF_JMP32 | BPF_JGT | BPF_K:
956 	case BPF_JMP32 | BPF_JLT | BPF_K:
957 	case BPF_JMP32 | BPF_JGE | BPF_K:
958 	case BPF_JMP32 | BPF_JLE | BPF_K:
959 	case BPF_JMP32 | BPF_JNE | BPF_K:
960 	case BPF_JMP32 | BPF_JSGT | BPF_K:
961 	case BPF_JMP32 | BPF_JSLT | BPF_K:
962 	case BPF_JMP32 | BPF_JSGE | BPF_K:
963 	case BPF_JMP32 | BPF_JSLE | BPF_K:
964 		if (is_addsub_imm(imm)) {
965 			emit(A64_CMP_I(is64, dst, imm), ctx);
966 		} else if (is_addsub_imm(-imm)) {
967 			emit(A64_CMN_I(is64, dst, -imm), ctx);
968 		} else {
969 			emit_a64_mov_i(is64, tmp, imm, ctx);
970 			emit(A64_CMP(is64, dst, tmp), ctx);
971 		}
972 		goto emit_cond_jmp;
973 	case BPF_JMP | BPF_JSET | BPF_K:
974 	case BPF_JMP32 | BPF_JSET | BPF_K:
975 		a64_insn = A64_TST_I(is64, dst, imm);
976 		if (a64_insn != AARCH64_BREAK_FAULT) {
977 			emit(a64_insn, ctx);
978 		} else {
979 			emit_a64_mov_i(is64, tmp, imm, ctx);
980 			emit(A64_TST(is64, dst, tmp), ctx);
981 		}
982 		goto emit_cond_jmp;
983 	/* function call */
984 	case BPF_JMP | BPF_CALL:
985 	{
986 		const u8 r0 = bpf2a64[BPF_REG_0];
987 		bool func_addr_fixed;
988 		u64 func_addr;
989 
990 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
991 					    &func_addr, &func_addr_fixed);
992 		if (ret < 0)
993 			return ret;
994 		emit_addr_mov_i64(tmp, func_addr, ctx);
995 		emit(A64_BLR(tmp), ctx);
996 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
997 		break;
998 	}
999 	/* tail call */
1000 	case BPF_JMP | BPF_TAIL_CALL:
1001 		if (emit_bpf_tail_call(ctx))
1002 			return -EFAULT;
1003 		break;
1004 	/* function return */
1005 	case BPF_JMP | BPF_EXIT:
1006 		/* Optimization: when the last instruction is EXIT,
1007 		   simply fall through to the epilogue. */
1008 		if (i == ctx->prog->len - 1)
1009 			break;
1010 		jmp_offset = epilogue_offset(ctx);
1011 		check_imm26(jmp_offset);
1012 		emit(A64_B(jmp_offset), ctx);
1013 		break;
1014 
1015 	/* dst = imm64 */
1016 	case BPF_LD | BPF_IMM | BPF_DW:
1017 	{
1018 		const struct bpf_insn insn1 = insn[1];
1019 		u64 imm64;
1020 
1021 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1022 		if (bpf_pseudo_func(insn))
1023 			emit_addr_mov_i64(dst, imm64, ctx);
1024 		else
1025 			emit_a64_mov_i64(dst, imm64, ctx);
1026 
1027 		return 1;
1028 	}
1029 
1030 	/* LDX: dst = *(size *)(src + off) */
1031 	case BPF_LDX | BPF_MEM | BPF_W:
1032 	case BPF_LDX | BPF_MEM | BPF_H:
1033 	case BPF_LDX | BPF_MEM | BPF_B:
1034 	case BPF_LDX | BPF_MEM | BPF_DW:
1035 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1036 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1037 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1038 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1039 		if (ctx->fpb_offset > 0 && src == fp) {
1040 			src_adj = fpb;
1041 			off_adj = off + ctx->fpb_offset;
1042 		} else {
1043 			src_adj = src;
1044 			off_adj = off;
1045 		}
1046 		switch (BPF_SIZE(code)) {
1047 		case BPF_W:
1048 			if (is_lsi_offset(off_adj, 2)) {
1049 				emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
1050 			} else {
1051 				emit_a64_mov_i(1, tmp, off, ctx);
1052 				emit(A64_LDR32(dst, src, tmp), ctx);
1053 			}
1054 			break;
1055 		case BPF_H:
1056 			if (is_lsi_offset(off_adj, 1)) {
1057 				emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
1058 			} else {
1059 				emit_a64_mov_i(1, tmp, off, ctx);
1060 				emit(A64_LDRH(dst, src, tmp), ctx);
1061 			}
1062 			break;
1063 		case BPF_B:
1064 			if (is_lsi_offset(off_adj, 0)) {
1065 				emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
1066 			} else {
1067 				emit_a64_mov_i(1, tmp, off, ctx);
1068 				emit(A64_LDRB(dst, src, tmp), ctx);
1069 			}
1070 			break;
1071 		case BPF_DW:
1072 			if (is_lsi_offset(off_adj, 3)) {
1073 				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
1074 			} else {
1075 				emit_a64_mov_i(1, tmp, off, ctx);
1076 				emit(A64_LDR64(dst, src, tmp), ctx);
1077 			}
1078 			break;
1079 		}
1080 
1081 		ret = add_exception_handler(insn, ctx, dst);
1082 		if (ret)
1083 			return ret;
1084 		break;
1085 
1086 	/* speculation barrier */
1087 	case BPF_ST | BPF_NOSPEC:
1088 		/*
1089 		 * Nothing required here.
1090 		 *
1091 		 * In case of arm64, we rely on the firmware mitigation of
1092 		 * Speculative Store Bypass as controlled via the ssbd kernel
1093 		 * parameter. Whenever the mitigation is enabled, it works
1094 		 * for all of the kernel code with no need to provide any
1095 		 * additional instructions.
1096 		 */
1097 		break;
1098 
1099 	/* ST: *(size *)(dst + off) = imm */
1100 	case BPF_ST | BPF_MEM | BPF_W:
1101 	case BPF_ST | BPF_MEM | BPF_H:
1102 	case BPF_ST | BPF_MEM | BPF_B:
1103 	case BPF_ST | BPF_MEM | BPF_DW:
1104 		if (ctx->fpb_offset > 0 && dst == fp) {
1105 			dst_adj = fpb;
1106 			off_adj = off + ctx->fpb_offset;
1107 		} else {
1108 			dst_adj = dst;
1109 			off_adj = off;
1110 		}
1111 		/* Load imm to a register then store it */
1112 		emit_a64_mov_i(1, tmp, imm, ctx);
1113 		switch (BPF_SIZE(code)) {
1114 		case BPF_W:
1115 			if (is_lsi_offset(off_adj, 2)) {
1116 				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
1117 			} else {
1118 				emit_a64_mov_i(1, tmp2, off, ctx);
1119 				emit(A64_STR32(tmp, dst, tmp2), ctx);
1120 			}
1121 			break;
1122 		case BPF_H:
1123 			if (is_lsi_offset(off_adj, 1)) {
1124 				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
1125 			} else {
1126 				emit_a64_mov_i(1, tmp2, off, ctx);
1127 				emit(A64_STRH(tmp, dst, tmp2), ctx);
1128 			}
1129 			break;
1130 		case BPF_B:
1131 			if (is_lsi_offset(off_adj, 0)) {
1132 				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
1133 			} else {
1134 				emit_a64_mov_i(1, tmp2, off, ctx);
1135 				emit(A64_STRB(tmp, dst, tmp2), ctx);
1136 			}
1137 			break;
1138 		case BPF_DW:
1139 			if (is_lsi_offset(off_adj, 3)) {
1140 				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
1141 			} else {
1142 				emit_a64_mov_i(1, tmp2, off, ctx);
1143 				emit(A64_STR64(tmp, dst, tmp2), ctx);
1144 			}
1145 			break;
1146 		}
1147 		break;
1148 
1149 	/* STX: *(size *)(dst + off) = src */
1150 	case BPF_STX | BPF_MEM | BPF_W:
1151 	case BPF_STX | BPF_MEM | BPF_H:
1152 	case BPF_STX | BPF_MEM | BPF_B:
1153 	case BPF_STX | BPF_MEM | BPF_DW:
1154 		if (ctx->fpb_offset > 0 && dst == fp) {
1155 			dst_adj = fpb;
1156 			off_adj = off + ctx->fpb_offset;
1157 		} else {
1158 			dst_adj = dst;
1159 			off_adj = off;
1160 		}
1161 		switch (BPF_SIZE(code)) {
1162 		case BPF_W:
1163 			if (is_lsi_offset(off_adj, 2)) {
1164 				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
1165 			} else {
1166 				emit_a64_mov_i(1, tmp, off, ctx);
1167 				emit(A64_STR32(src, dst, tmp), ctx);
1168 			}
1169 			break;
1170 		case BPF_H:
1171 			if (is_lsi_offset(off_adj, 1)) {
1172 				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
1173 			} else {
1174 				emit_a64_mov_i(1, tmp, off, ctx);
1175 				emit(A64_STRH(src, dst, tmp), ctx);
1176 			}
1177 			break;
1178 		case BPF_B:
1179 			if (is_lsi_offset(off_adj, 0)) {
1180 				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
1181 			} else {
1182 				emit_a64_mov_i(1, tmp, off, ctx);
1183 				emit(A64_STRB(src, dst, tmp), ctx);
1184 			}
1185 			break;
1186 		case BPF_DW:
1187 			if (is_lsi_offset(off_adj, 3)) {
1188 				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
1189 			} else {
1190 				emit_a64_mov_i(1, tmp, off, ctx);
1191 				emit(A64_STR64(src, dst, tmp), ctx);
1192 			}
1193 			break;
1194 		}
1195 		break;
1196 
1197 	case BPF_STX | BPF_ATOMIC | BPF_W:
1198 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1199 		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1200 			ret = emit_lse_atomic(insn, ctx);
1201 		else
1202 			ret = emit_ll_sc_atomic(insn, ctx);
1203 		if (ret)
1204 			return ret;
1205 		break;
1206 
1207 	default:
1208 		pr_err_once("unknown opcode %02x\n", code);
1209 		return -EINVAL;
1210 	}
1211 
1212 	return 0;
1213 }
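
/*
 * End-to-end example: a single eBPF instruction
 * BPF_ALU64_IMM(BPF_ADD, BPF_REG_6, 0x12345) cannot use the imm12 form,
 * so it becomes three A64 instructions:
 *   movz x10, #0x2345
 *   movk x10, #0x1, lsl #16
 *   add  x19, x19, x10
 * (x19 is BPF_REG_6 and x10 is TMP_REG_1 per the bpf2a64[] map above).
 */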
1214 
1215 /*
1216  * Return 0 if FP may change at runtime; otherwise find the minimum negative
1217  * offset to FP, convert it to a positive number, and align it down to 8 bytes.
1218  */
1219 static int find_fpb_offset(struct bpf_prog *prog)
1220 {
1221 	int i;
1222 	int offset = 0;
1223 
1224 	for (i = 0; i < prog->len; i++) {
1225 		const struct bpf_insn *insn = &prog->insnsi[i];
1226 		const u8 class = BPF_CLASS(insn->code);
1227 		const u8 mode = BPF_MODE(insn->code);
1228 		const u8 src = insn->src_reg;
1229 		const u8 dst = insn->dst_reg;
1230 		const s32 imm = insn->imm;
1231 		const s16 off = insn->off;
1232 
1233 		switch (class) {
1234 		case BPF_STX:
1235 		case BPF_ST:
1236 			/* fp holds atomic operation result */
1237 			if (class == BPF_STX && mode == BPF_ATOMIC &&
1238 			    ((imm == BPF_XCHG ||
1239 			      imm == (BPF_FETCH | BPF_ADD) ||
1240 			      imm == (BPF_FETCH | BPF_AND) ||
1241 			      imm == (BPF_FETCH | BPF_XOR) ||
1242 			      imm == (BPF_FETCH | BPF_OR)) &&
1243 			     src == BPF_REG_FP))
1244 				return 0;
1245 
1246 			if (mode == BPF_MEM && dst == BPF_REG_FP &&
1247 			    off < offset)
1248 				offset = insn->off;
1249 			break;
1250 
1251 		case BPF_JMP32:
1252 		case BPF_JMP:
1253 			break;
1254 
1255 		case BPF_LDX:
1256 		case BPF_LD:
1257 			/* fp holds load result */
1258 			if (dst == BPF_REG_FP)
1259 				return 0;
1260 
1261 			if (class == BPF_LDX && mode == BPF_MEM &&
1262 			    src == BPF_REG_FP && off < offset)
1263 				offset = off;
1264 			break;
1265 
1266 		case BPF_ALU:
1267 		case BPF_ALU64:
1268 		default:
1269 			/* fp holds ALU result */
1270 			if (dst == BPF_REG_FP)
1271 				return 0;
1272 		}
1273 	}
1274 
1275 	if (offset < 0) {
1276 		/*
1277 		 * 'offset' can safely be converted to a positive 'int', since
1278 		 * insn->off is an 's16'
1279 		 */
1280 		offset = -offset;
1281 		/* align down to 8 bytes */
1282 		offset = ALIGN_DOWN(offset, 8);
1283 	}
1284 
1285 	return offset;
1286 }
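
/*
 * Example: if the deepest FP-relative access in the program is
 * *(u64 *)(r10 - 512), this returns 512, so the prologue sets
 * fpb = BPF_FP - 512 and that access becomes "ldr x, [fpb]", while an
 * access at r10 - 8 becomes "ldr x, [fpb, #504]" - both expressible as
 * unsigned scaled immediates, which negative offsets from BPF_FP are not.
 */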
1287 
1288 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1289 {
1290 	const struct bpf_prog *prog = ctx->prog;
1291 	int i;
1292 
1293 	/*
1294 	 * - offset[0] - offset of the end of prologue,
1295 	 *   start of the 1st instruction.
1296 	 * - offset[1] - offset of the end of 1st instruction,
1297 	 *   start of the 2nd instruction
1298 	 * [....]
1299 	 * - offset[3] - offset of the end of 3rd instruction,
1300 	 *   start of 4th instruction
1301 	 */
1302 	for (i = 0; i < prog->len; i++) {
1303 		const struct bpf_insn *insn = &prog->insnsi[i];
1304 		int ret;
1305 
1306 		if (ctx->image == NULL)
1307 			ctx->offset[i] = ctx->idx;
1308 		ret = build_insn(insn, ctx, extra_pass);
1309 		if (ret > 0) {
1310 			i++;
1311 			if (ctx->image == NULL)
1312 				ctx->offset[i] = ctx->idx;
1313 			continue;
1314 		}
1315 		if (ret)
1316 			return ret;
1317 	}
1318 	/*
1319 	 * offset is allocated with prog->len + 1 so fill in
1320 	 * the last element with the offset after the last
1321 	 * instruction (end of program)
1322 	 */
1323 	if (ctx->image == NULL)
1324 		ctx->offset[i] = ctx->idx;
1325 
1326 	return 0;
1327 }
1328 
1329 static int validate_code(struct jit_ctx *ctx)
1330 {
1331 	int i;
1332 
1333 	for (i = 0; i < ctx->idx; i++) {
1334 		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1335 
1336 		if (a64_insn == AARCH64_BREAK_FAULT)
1337 			return -1;
1338 	}
1339 
1340 	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1341 		return -1;
1342 
1343 	return 0;
1344 }
1345 
1346 static inline void bpf_flush_icache(void *start, void *end)
1347 {
1348 	flush_icache_range((unsigned long)start, (unsigned long)end);
1349 }
1350 
1351 struct arm64_jit_data {
1352 	struct bpf_binary_header *header;
1353 	u8 *image;
1354 	struct jit_ctx ctx;
1355 };
1356 
1357 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1358 {
1359 	int image_size, prog_size, extable_size;
1360 	struct bpf_prog *tmp, *orig_prog = prog;
1361 	struct bpf_binary_header *header;
1362 	struct arm64_jit_data *jit_data;
1363 	bool was_classic = bpf_prog_was_classic(prog);
1364 	bool tmp_blinded = false;
1365 	bool extra_pass = false;
1366 	struct jit_ctx ctx;
1367 	u8 *image_ptr;
1368 
1369 	if (!prog->jit_requested)
1370 		return orig_prog;
1371 
1372 	tmp = bpf_jit_blind_constants(prog);
1373 	/* If blinding was requested and we failed during blinding,
1374 	 * we must fall back to the interpreter.
1375 	 */
1376 	if (IS_ERR(tmp))
1377 		return orig_prog;
1378 	if (tmp != prog) {
1379 		tmp_blinded = true;
1380 		prog = tmp;
1381 	}
1382 
1383 	jit_data = prog->aux->jit_data;
1384 	if (!jit_data) {
1385 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1386 		if (!jit_data) {
1387 			prog = orig_prog;
1388 			goto out;
1389 		}
1390 		prog->aux->jit_data = jit_data;
1391 	}
1392 	if (jit_data->ctx.offset) {
1393 		ctx = jit_data->ctx;
1394 		image_ptr = jit_data->image;
1395 		header = jit_data->header;
1396 		extra_pass = true;
1397 		prog_size = sizeof(u32) * ctx.idx;
1398 		goto skip_init_ctx;
1399 	}
1400 	memset(&ctx, 0, sizeof(ctx));
1401 	ctx.prog = prog;
1402 
1403 	ctx.offset = kcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1404 	if (ctx.offset == NULL) {
1405 		prog = orig_prog;
1406 		goto out_off;
1407 	}
1408 
1409 	ctx.fpb_offset = find_fpb_offset(prog);
1410 
1411 	/*
1412 	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
1413 	 *
1414 	 * BPF line info needs ctx->offset[i] to be the offset of
1415 	 * instruction[i] in jited image, so build prologue first.
1416 	 */
1417 	if (build_prologue(&ctx, was_classic)) {
1418 		prog = orig_prog;
1419 		goto out_off;
1420 	}
1421 
1422 	if (build_body(&ctx, extra_pass)) {
1423 		prog = orig_prog;
1424 		goto out_off;
1425 	}
1426 
1427 	ctx.epilogue_offset = ctx.idx;
1428 	build_epilogue(&ctx);
1429 
1430 	extable_size = prog->aux->num_exentries *
1431 		sizeof(struct exception_table_entry);
1432 
1433 	/* Now we know the actual image size. */
1434 	prog_size = sizeof(u32) * ctx.idx;
1435 	image_size = prog_size + extable_size;
1436 	header = bpf_jit_binary_alloc(image_size, &image_ptr,
1437 				      sizeof(u32), jit_fill_hole);
1438 	if (header == NULL) {
1439 		prog = orig_prog;
1440 		goto out_off;
1441 	}
1442 
1443 	/* 2. Now, the actual pass. */
1444 
1445 	ctx.image = (__le32 *)image_ptr;
1446 	if (extable_size)
1447 		prog->aux->extable = (void *)image_ptr + prog_size;
1448 skip_init_ctx:
1449 	ctx.idx = 0;
1450 	ctx.exentry_idx = 0;
1451 
1452 	build_prologue(&ctx, was_classic);
1453 
1454 	if (build_body(&ctx, extra_pass)) {
1455 		bpf_jit_binary_free(header);
1456 		prog = orig_prog;
1457 		goto out_off;
1458 	}
1459 
1460 	build_epilogue(&ctx);
1461 
1462 	/* 3. Extra pass to validate JITed code. */
1463 	if (validate_code(&ctx)) {
1464 		bpf_jit_binary_free(header);
1465 		prog = orig_prog;
1466 		goto out_off;
1467 	}
1468 
1469 	/* And we're done. */
1470 	if (bpf_jit_enable > 1)
1471 		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1472 
1473 	bpf_flush_icache(header, ctx.image + ctx.idx);
1474 
1475 	if (!prog->is_func || extra_pass) {
1476 		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1477 			pr_err_once("multi-func JIT bug %d != %d\n",
1478 				    ctx.idx, jit_data->ctx.idx);
1479 			bpf_jit_binary_free(header);
1480 			prog->bpf_func = NULL;
1481 			prog->jited = 0;
1482 			prog->jited_len = 0;
1483 			goto out_off;
1484 		}
1485 		bpf_jit_binary_lock_ro(header);
1486 	} else {
1487 		jit_data->ctx = ctx;
1488 		jit_data->image = image_ptr;
1489 		jit_data->header = header;
1490 	}
1491 	prog->bpf_func = (void *)ctx.image;
1492 	prog->jited = 1;
1493 	prog->jited_len = prog_size;
1494 
1495 	if (!prog->is_func || extra_pass) {
1496 		int i;
1497 
1498 		/* offset[prog->len] is the size of program */
1499 		for (i = 0; i <= prog->len; i++)
1500 			ctx.offset[i] *= AARCH64_INSN_SIZE;
1501 		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1502 out_off:
1503 		kfree(ctx.offset);
1504 		kfree(jit_data);
1505 		prog->aux->jit_data = NULL;
1506 	}
1507 out:
1508 	if (tmp_blinded)
1509 		bpf_jit_prog_release_other(prog, prog == orig_prog ?
1510 					   tmp : orig_prog);
1511 	return prog;
1512 }
1513 
1514 bool bpf_jit_supports_kfunc_call(void)
1515 {
1516 	return true;
1517 }
1518 
1519 u64 bpf_jit_alloc_exec_limit(void)
1520 {
1521 	return VMALLOC_END - VMALLOC_START;
1522 }
1523 
1524 void *bpf_jit_alloc_exec(unsigned long size)
1525 {
1526 	/* Memory is intended to be executable, reset the pointer tag. */
1527 	return kasan_reset_tag(vmalloc(size));
1528 }
1529 
1530 void bpf_jit_free_exec(void *addr)
1531 {
1532 	return vfree(addr);
1533 }
1534 
1535 /* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
1536 bool bpf_jit_supports_subprog_tailcalls(void)
1537 {
1538 	return true;
1539 }
1540