xref: /openbmc/linux/arch/arm64/net/bpf_jit_comp.c (revision 7fc96d71)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for ARM64
4  *
5  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6  */
7 
8 #define pr_fmt(fmt) "bpf_jit: " fmt
9 
10 #include <linux/bitfield.h>
11 #include <linux/bpf.h>
12 #include <linux/filter.h>
13 #include <linux/printk.h>
14 #include <linux/slab.h>
15 
16 #include <asm/asm-extable.h>
17 #include <asm/byteorder.h>
18 #include <asm/cacheflush.h>
19 #include <asm/debug-monitors.h>
20 #include <asm/insn.h>
21 #include <asm/set_memory.h>
22 
23 #include "bpf_jit.h"
24 
25 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
26 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
27 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
28 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
29 #define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
30 
/*
 * Range-check a branch immediate before it is encoded.
 *
 * A positive @imm is rejected when any bit at position @bits or above is
 * set; a negative @imm is rejected when the same holds for its one's
 * complement. On rejection the macro logs the value and makes the
 * *enclosing function* return -EINVAL.
 *
 * The expansion site must have an integer `i` in scope (the index of the
 * BPF instruction being translated); it is only used for the log line.
 */
#define check_imm(bits, imm) do {					\
	if ((imm) > 0 ? ((imm) >> (bits)) != 0 :			\
	    ((imm) < 0 && (~(imm) >> (bits)) != 0)) {			\
		pr_info("[%2d] imm=%d(0x%x) out of range\n",		\
			i, imm, imm);					\
		return -EINVAL;						\
	}								\
} while (0)
/* Convenience wrappers for the two immediate widths used by this file */
#define check_imm19(imm) check_imm(19, imm)
#define check_imm26(imm) check_imm(26, imm)
41 
42 /* Map BPF registers to A64 registers */
43 static const int bpf2a64[] = {
44 	/* return value from in-kernel function, and exit value from eBPF */
45 	[BPF_REG_0] = A64_R(7),
46 	/* arguments from eBPF program to in-kernel function */
47 	[BPF_REG_1] = A64_R(0),
48 	[BPF_REG_2] = A64_R(1),
49 	[BPF_REG_3] = A64_R(2),
50 	[BPF_REG_4] = A64_R(3),
51 	[BPF_REG_5] = A64_R(4),
52 	/* callee saved registers that in-kernel function will preserve */
53 	[BPF_REG_6] = A64_R(19),
54 	[BPF_REG_7] = A64_R(20),
55 	[BPF_REG_8] = A64_R(21),
56 	[BPF_REG_9] = A64_R(22),
57 	/* read-only frame pointer to access stack */
58 	[BPF_REG_FP] = A64_R(25),
59 	/* temporary registers for BPF JIT */
60 	[TMP_REG_1] = A64_R(10),
61 	[TMP_REG_2] = A64_R(11),
62 	[TMP_REG_3] = A64_R(12),
63 	/* tail_call_cnt */
64 	[TCALL_CNT] = A64_R(26),
65 	/* temporary register for blinding constants */
66 	[BPF_REG_AX] = A64_R(9),
67 	[FP_BOTTOM] = A64_R(27),
68 };
69 
/* Per-program state threaded through every emit/build helper in this file. */
struct jit_ctx {
	/* BPF program being translated */
	const struct bpf_prog *prog;
	/* index of the next A64 instruction slot; advanced by emit() */
	int idx;
	/* instruction index of the epilogue, used as branch target for EXIT */
	int epilogue_offset;
	/* per-BPF-instruction table consulted by bpf2a64_offset() */
	int *offset;
	/* next unused slot in prog->aux->extable */
	int exentry_idx;
	/* output buffer; while NULL, emit() only counts instructions */
	__le32 *image;
	/* prog->aux->stack_depth rounded up to 16 bytes by the prologue */
	u32 stack_size;
	/* distance subtracted from the BPF FP to form the FP_BOTTOM register */
	int fpb_offset;
};
80 
81 static inline void emit(const u32 insn, struct jit_ctx *ctx)
82 {
83 	if (ctx->image != NULL)
84 		ctx->image[ctx->idx] = cpu_to_le32(insn);
85 
86 	ctx->idx++;
87 }
88 
89 static inline void emit_a64_mov_i(const int is64, const int reg,
90 				  const s32 val, struct jit_ctx *ctx)
91 {
92 	u16 hi = val >> 16;
93 	u16 lo = val & 0xffff;
94 
95 	if (hi & 0x8000) {
96 		if (hi == 0xffff) {
97 			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
98 		} else {
99 			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
100 			if (lo != 0xffff)
101 				emit(A64_MOVK(is64, reg, lo, 0), ctx);
102 		}
103 	} else {
104 		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
105 		if (hi)
106 			emit(A64_MOVK(is64, reg, hi, 16), ctx);
107 	}
108 }
109 
110 static int i64_i16_blocks(const u64 val, bool inverse)
111 {
112 	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
113 	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
114 	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
115 	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
116 }
117 
118 static inline void emit_a64_mov_i64(const int reg, const u64 val,
119 				    struct jit_ctx *ctx)
120 {
121 	u64 nrm_tmp = val, rev_tmp = ~val;
122 	bool inverse;
123 	int shift;
124 
125 	if (!(nrm_tmp >> 32))
126 		return emit_a64_mov_i(0, reg, (u32)val, ctx);
127 
128 	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
129 	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
130 					  (fls64(nrm_tmp) - 1)), 16), 0);
131 	if (inverse)
132 		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
133 	else
134 		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
135 	shift -= 16;
136 	while (shift >= 0) {
137 		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
138 			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
139 		shift -= 16;
140 	}
141 }
142 
143 /*
144  * Kernel addresses in the vmalloc space use at most 48 bits, and the
145  * remaining bits are guaranteed to be 0x1. So we can compose the address
146  * with a fixed length movn/movk/movk sequence.
147  */
148 static inline void emit_addr_mov_i64(const int reg, const u64 val,
149 				     struct jit_ctx *ctx)
150 {
151 	u64 tmp = val;
152 	int shift = 0;
153 
154 	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
155 	while (shift < 32) {
156 		tmp >>= 16;
157 		shift += 16;
158 		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
159 	}
160 }
161 
162 static inline int bpf2a64_offset(int bpf_insn, int off,
163 				 const struct jit_ctx *ctx)
164 {
165 	/* BPF JMP offset is relative to the next instruction */
166 	bpf_insn++;
167 	/*
168 	 * Whereas arm64 branch instructions encode the offset
169 	 * from the branch itself, so we must subtract 1 from the
170 	 * instruction offset.
171 	 */
172 	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
173 }
174 
175 static void jit_fill_hole(void *area, unsigned int size)
176 {
177 	__le32 *ptr;
178 	/* We are guaranteed to have aligned memory. */
179 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
180 		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
181 }
182 
183 static inline int epilogue_offset(const struct jit_ctx *ctx)
184 {
185 	int to = ctx->epilogue_offset;
186 	int from = ctx->idx;
187 
188 	return to - from;
189 }
190 
191 static bool is_addsub_imm(u32 imm)
192 {
193 	/* Either imm12 or shifted imm12. */
194 	return !(imm & ~0xfff) || !(imm & ~0xfff000);
195 }
196 
197 /*
198  * There are 3 types of AArch64 LDR/STR (immediate) instruction:
199  * Post-index, Pre-index, Unsigned offset.
200  *
201  * For BPF ldr/str, the "unsigned offset" type is sufficient.
202  *
203  * "Unsigned offset" type LDR(immediate) format:
204  *
205  *    3                   2                   1                   0
206  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
207  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
208  * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
209  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
210  * scale
211  *
212  * "Unsigned offset" type STR(immediate) format:
213  *    3                   2                   1                   0
214  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
215  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
216  * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
217  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
218  * scale
219  *
220  * The offset is calculated from imm12 and scale in the following way:
221  *
222  * offset = (u64)imm12 << scale
223  */
static bool is_lsi_offset(int offset, int scale)
{
	/*
	 * Encodable as "unsigned offset" iff the offset is non-negative,
	 * no larger than the biggest scaled imm12, and a multiple of the
	 * access size (1 << scale).
	 */
	return offset >= 0 &&
	       offset <= (0xFFF << scale) &&
	       !(offset & ((1 << scale) - 1));
}
237 
238 /* Tail call offset to jump into */
239 #if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) || \
240 	IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)
241 #define PROLOGUE_OFFSET 9
242 #else
243 #define PROLOGUE_OFFSET 8
244 #endif
245 
246 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
247 {
248 	const struct bpf_prog *prog = ctx->prog;
249 	const u8 r6 = bpf2a64[BPF_REG_6];
250 	const u8 r7 = bpf2a64[BPF_REG_7];
251 	const u8 r8 = bpf2a64[BPF_REG_8];
252 	const u8 r9 = bpf2a64[BPF_REG_9];
253 	const u8 fp = bpf2a64[BPF_REG_FP];
254 	const u8 tcc = bpf2a64[TCALL_CNT];
255 	const u8 fpb = bpf2a64[FP_BOTTOM];
256 	const int idx0 = ctx->idx;
257 	int cur_offset;
258 
259 	/*
260 	 * BPF prog stack layout
261 	 *
262 	 *                         high
263 	 * original A64_SP =>   0:+-----+ BPF prologue
264 	 *                        |FP/LR|
265 	 * current A64_FP =>  -16:+-----+
266 	 *                        | ... | callee saved registers
267 	 * BPF fp register => -64:+-----+ <= (BPF_FP)
268 	 *                        |     |
269 	 *                        | ... | BPF prog stack
270 	 *                        |     |
271 	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
272 	 *                        |RSVD | padding
273 	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
274 	 *                        |     |
275 	 *                        | ... | Function call stack
276 	 *                        |     |
277 	 *                        +-----+
278 	 *                          low
279 	 *
280 	 */
281 
282 	/* Sign lr */
283 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
284 		emit(A64_PACIASP, ctx);
285 	/* BTI landing pad */
286 	else if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
287 		emit(A64_BTI_C, ctx);
288 
289 	/* Save FP and LR registers to stay align with ARM64 AAPCS */
290 	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
291 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
292 
293 	/* Save callee-saved registers */
294 	emit(A64_PUSH(r6, r7, A64_SP), ctx);
295 	emit(A64_PUSH(r8, r9, A64_SP), ctx);
296 	emit(A64_PUSH(fp, tcc, A64_SP), ctx);
297 	emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
298 
299 	/* Set up BPF prog stack base register */
300 	emit(A64_MOV(1, fp, A64_SP), ctx);
301 
302 	if (!ebpf_from_cbpf) {
303 		/* Initialize tail_call_cnt */
304 		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
305 
306 		cur_offset = ctx->idx - idx0;
307 		if (cur_offset != PROLOGUE_OFFSET) {
308 			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
309 				    cur_offset, PROLOGUE_OFFSET);
310 			return -1;
311 		}
312 
313 		/* BTI landing pad for the tail call, done with a BR */
314 		if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
315 			emit(A64_BTI_J, ctx);
316 	}
317 
318 	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
319 
320 	/* Stack must be multiples of 16B */
321 	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
322 
323 	/* Set up function call stack */
324 	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
325 	return 0;
326 }
327 
328 static int out_offset = -1; /* initialized on the first pass of build_body() */
329 static int emit_bpf_tail_call(struct jit_ctx *ctx)
330 {
331 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
332 	const u8 r2 = bpf2a64[BPF_REG_2];
333 	const u8 r3 = bpf2a64[BPF_REG_3];
334 
335 	const u8 tmp = bpf2a64[TMP_REG_1];
336 	const u8 prg = bpf2a64[TMP_REG_2];
337 	const u8 tcc = bpf2a64[TCALL_CNT];
338 	const int idx0 = ctx->idx;
339 #define cur_offset (ctx->idx - idx0)
340 #define jmp_offset (out_offset - (cur_offset))
341 	size_t off;
342 
343 	/* if (index >= array->map.max_entries)
344 	 *     goto out;
345 	 */
346 	off = offsetof(struct bpf_array, map.max_entries);
347 	emit_a64_mov_i64(tmp, off, ctx);
348 	emit(A64_LDR32(tmp, r2, tmp), ctx);
349 	emit(A64_MOV(0, r3, r3), ctx);
350 	emit(A64_CMP(0, r3, tmp), ctx);
351 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
352 
353 	/*
354 	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
355 	 *     goto out;
356 	 * tail_call_cnt++;
357 	 */
358 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
359 	emit(A64_CMP(1, tcc, tmp), ctx);
360 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
361 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
362 
363 	/* prog = array->ptrs[index];
364 	 * if (prog == NULL)
365 	 *     goto out;
366 	 */
367 	off = offsetof(struct bpf_array, ptrs);
368 	emit_a64_mov_i64(tmp, off, ctx);
369 	emit(A64_ADD(1, tmp, r2, tmp), ctx);
370 	emit(A64_LSL(1, prg, r3, 3), ctx);
371 	emit(A64_LDR64(prg, tmp, prg), ctx);
372 	emit(A64_CBZ(1, prg, jmp_offset), ctx);
373 
374 	/* goto *(prog->bpf_func + prologue_offset); */
375 	off = offsetof(struct bpf_prog, bpf_func);
376 	emit_a64_mov_i64(tmp, off, ctx);
377 	emit(A64_LDR64(tmp, prg, tmp), ctx);
378 	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
379 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
380 	emit(A64_BR(tmp), ctx);
381 
382 	/* out: */
383 	if (out_offset == -1)
384 		out_offset = cur_offset;
385 	if (cur_offset != out_offset) {
386 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
387 			    cur_offset, out_offset);
388 		return -1;
389 	}
390 	return 0;
391 #undef cur_offset
392 #undef jmp_offset
393 }
394 
395 #ifdef CONFIG_ARM64_LSE_ATOMICS
396 static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
397 {
398 	const u8 code = insn->code;
399 	const u8 dst = bpf2a64[insn->dst_reg];
400 	const u8 src = bpf2a64[insn->src_reg];
401 	const u8 tmp = bpf2a64[TMP_REG_1];
402 	const u8 tmp2 = bpf2a64[TMP_REG_2];
403 	const bool isdw = BPF_SIZE(code) == BPF_DW;
404 	const s16 off = insn->off;
405 	u8 reg;
406 
407 	if (!off) {
408 		reg = dst;
409 	} else {
410 		emit_a64_mov_i(1, tmp, off, ctx);
411 		emit(A64_ADD(1, tmp, tmp, dst), ctx);
412 		reg = tmp;
413 	}
414 
415 	switch (insn->imm) {
416 	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
417 	case BPF_ADD:
418 		emit(A64_STADD(isdw, reg, src), ctx);
419 		break;
420 	case BPF_AND:
421 		emit(A64_MVN(isdw, tmp2, src), ctx);
422 		emit(A64_STCLR(isdw, reg, tmp2), ctx);
423 		break;
424 	case BPF_OR:
425 		emit(A64_STSET(isdw, reg, src), ctx);
426 		break;
427 	case BPF_XOR:
428 		emit(A64_STEOR(isdw, reg, src), ctx);
429 		break;
430 	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
431 	case BPF_ADD | BPF_FETCH:
432 		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
433 		break;
434 	case BPF_AND | BPF_FETCH:
435 		emit(A64_MVN(isdw, tmp2, src), ctx);
436 		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
437 		break;
438 	case BPF_OR | BPF_FETCH:
439 		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
440 		break;
441 	case BPF_XOR | BPF_FETCH:
442 		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
443 		break;
444 	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
445 	case BPF_XCHG:
446 		emit(A64_SWPAL(isdw, src, reg, src), ctx);
447 		break;
448 	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
449 	case BPF_CMPXCHG:
450 		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
451 		break;
452 	default:
453 		pr_err_once("unknown atomic op code %02x\n", insn->imm);
454 		return -EINVAL;
455 	}
456 
457 	return 0;
458 }
459 #else
460 static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
461 {
462 	return -EINVAL;
463 }
464 #endif
465 
466 static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
467 {
468 	const u8 code = insn->code;
469 	const u8 dst = bpf2a64[insn->dst_reg];
470 	const u8 src = bpf2a64[insn->src_reg];
471 	const u8 tmp = bpf2a64[TMP_REG_1];
472 	const u8 tmp2 = bpf2a64[TMP_REG_2];
473 	const u8 tmp3 = bpf2a64[TMP_REG_3];
474 	const int i = insn - ctx->prog->insnsi;
475 	const s32 imm = insn->imm;
476 	const s16 off = insn->off;
477 	const bool isdw = BPF_SIZE(code) == BPF_DW;
478 	u8 reg;
479 	s32 jmp_offset;
480 
481 	if (!off) {
482 		reg = dst;
483 	} else {
484 		emit_a64_mov_i(1, tmp, off, ctx);
485 		emit(A64_ADD(1, tmp, tmp, dst), ctx);
486 		reg = tmp;
487 	}
488 
489 	if (imm == BPF_ADD || imm == BPF_AND ||
490 	    imm == BPF_OR || imm == BPF_XOR) {
491 		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
492 		emit(A64_LDXR(isdw, tmp2, reg), ctx);
493 		if (imm == BPF_ADD)
494 			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
495 		else if (imm == BPF_AND)
496 			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
497 		else if (imm == BPF_OR)
498 			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
499 		else
500 			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
501 		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
502 		jmp_offset = -3;
503 		check_imm19(jmp_offset);
504 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
505 	} else if (imm == (BPF_ADD | BPF_FETCH) ||
506 		   imm == (BPF_AND | BPF_FETCH) ||
507 		   imm == (BPF_OR | BPF_FETCH) ||
508 		   imm == (BPF_XOR | BPF_FETCH)) {
509 		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
510 		const u8 ax = bpf2a64[BPF_REG_AX];
511 
512 		emit(A64_MOV(isdw, ax, src), ctx);
513 		emit(A64_LDXR(isdw, src, reg), ctx);
514 		if (imm == (BPF_ADD | BPF_FETCH))
515 			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
516 		else if (imm == (BPF_AND | BPF_FETCH))
517 			emit(A64_AND(isdw, tmp2, src, ax), ctx);
518 		else if (imm == (BPF_OR | BPF_FETCH))
519 			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
520 		else
521 			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
522 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
523 		jmp_offset = -3;
524 		check_imm19(jmp_offset);
525 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
526 		emit(A64_DMB_ISH, ctx);
527 	} else if (imm == BPF_XCHG) {
528 		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
529 		emit(A64_MOV(isdw, tmp2, src), ctx);
530 		emit(A64_LDXR(isdw, src, reg), ctx);
531 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
532 		jmp_offset = -2;
533 		check_imm19(jmp_offset);
534 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
535 		emit(A64_DMB_ISH, ctx);
536 	} else if (imm == BPF_CMPXCHG) {
537 		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
538 		const u8 r0 = bpf2a64[BPF_REG_0];
539 
540 		emit(A64_MOV(isdw, tmp2, r0), ctx);
541 		emit(A64_LDXR(isdw, r0, reg), ctx);
542 		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
543 		jmp_offset = 4;
544 		check_imm19(jmp_offset);
545 		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
546 		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
547 		jmp_offset = -4;
548 		check_imm19(jmp_offset);
549 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
550 		emit(A64_DMB_ISH, ctx);
551 	} else {
552 		pr_err_once("unknown atomic op code %02x\n", imm);
553 		return -EINVAL;
554 	}
555 
556 	return 0;
557 }
558 
559 static void build_epilogue(struct jit_ctx *ctx)
560 {
561 	const u8 r0 = bpf2a64[BPF_REG_0];
562 	const u8 r6 = bpf2a64[BPF_REG_6];
563 	const u8 r7 = bpf2a64[BPF_REG_7];
564 	const u8 r8 = bpf2a64[BPF_REG_8];
565 	const u8 r9 = bpf2a64[BPF_REG_9];
566 	const u8 fp = bpf2a64[BPF_REG_FP];
567 	const u8 fpb = bpf2a64[FP_BOTTOM];
568 
569 	/* We're done with BPF stack */
570 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
571 
572 	/* Restore x27 and x28 */
573 	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
574 	/* Restore fs (x25) and x26 */
575 	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
576 
577 	/* Restore callee-saved register */
578 	emit(A64_POP(r8, r9, A64_SP), ctx);
579 	emit(A64_POP(r6, r7, A64_SP), ctx);
580 
581 	/* Restore FP/LR registers */
582 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
583 
584 	/* Set return value */
585 	emit(A64_MOV(1, A64_R(0), r0), ctx);
586 
587 	/* Authenticate lr */
588 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
589 		emit(A64_AUTIASP, ctx);
590 
591 	emit(A64_RET(A64_LR), ctx);
592 }
593 
594 #define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
595 #define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
596 
597 bool ex_handler_bpf(const struct exception_table_entry *ex,
598 		    struct pt_regs *regs)
599 {
600 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
601 	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
602 
603 	regs->regs[dst_reg] = 0;
604 	regs->pc = (unsigned long)&ex->fixup - offset;
605 	return true;
606 }
607 
608 /* For accesses to BTF pointers, add an entry to the exception table */
609 static int add_exception_handler(const struct bpf_insn *insn,
610 				 struct jit_ctx *ctx,
611 				 int dst_reg)
612 {
613 	off_t offset;
614 	unsigned long pc;
615 	struct exception_table_entry *ex;
616 
617 	if (!ctx->image)
618 		/* First pass */
619 		return 0;
620 
621 	if (BPF_MODE(insn->code) != BPF_PROBE_MEM)
622 		return 0;
623 
624 	if (!ctx->prog->aux->extable ||
625 	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
626 		return -EINVAL;
627 
628 	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
629 	pc = (unsigned long)&ctx->image[ctx->idx - 1];
630 
631 	offset = pc - (long)&ex->insn;
632 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
633 		return -ERANGE;
634 	ex->insn = offset;
635 
636 	/*
637 	 * Since the extable follows the program, the fixup offset is always
638 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
639 	 * to keep things simple, and put the destination register in the upper
640 	 * bits. We don't need to worry about buildtime or runtime sort
641 	 * modifying the upper bits because the table is already sorted, and
642 	 * isn't part of the main exception table.
643 	 */
644 	offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
645 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
646 		return -ERANGE;
647 
648 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
649 		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
650 
651 	ex->type = EX_TYPE_BPF;
652 
653 	ctx->exentry_idx++;
654 	return 0;
655 }
656 
657 /* JITs an eBPF instruction.
658  * Returns:
659  * 0  - successfully JITed an 8-byte eBPF instruction.
660  * >0 - successfully JITed a 16-byte eBPF instruction.
661  * <0 - failed to JIT.
662  */
663 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
664 		      bool extra_pass)
665 {
666 	const u8 code = insn->code;
667 	const u8 dst = bpf2a64[insn->dst_reg];
668 	const u8 src = bpf2a64[insn->src_reg];
669 	const u8 tmp = bpf2a64[TMP_REG_1];
670 	const u8 tmp2 = bpf2a64[TMP_REG_2];
671 	const u8 fp = bpf2a64[BPF_REG_FP];
672 	const u8 fpb = bpf2a64[FP_BOTTOM];
673 	const s16 off = insn->off;
674 	const s32 imm = insn->imm;
675 	const int i = insn - ctx->prog->insnsi;
676 	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
677 			  BPF_CLASS(code) == BPF_JMP;
678 	u8 jmp_cond;
679 	s32 jmp_offset;
680 	u32 a64_insn;
681 	u8 src_adj;
682 	u8 dst_adj;
683 	int off_adj;
684 	int ret;
685 
686 	switch (code) {
687 	/* dst = src */
688 	case BPF_ALU | BPF_MOV | BPF_X:
689 	case BPF_ALU64 | BPF_MOV | BPF_X:
690 		emit(A64_MOV(is64, dst, src), ctx);
691 		break;
692 	/* dst = dst OP src */
693 	case BPF_ALU | BPF_ADD | BPF_X:
694 	case BPF_ALU64 | BPF_ADD | BPF_X:
695 		emit(A64_ADD(is64, dst, dst, src), ctx);
696 		break;
697 	case BPF_ALU | BPF_SUB | BPF_X:
698 	case BPF_ALU64 | BPF_SUB | BPF_X:
699 		emit(A64_SUB(is64, dst, dst, src), ctx);
700 		break;
701 	case BPF_ALU | BPF_AND | BPF_X:
702 	case BPF_ALU64 | BPF_AND | BPF_X:
703 		emit(A64_AND(is64, dst, dst, src), ctx);
704 		break;
705 	case BPF_ALU | BPF_OR | BPF_X:
706 	case BPF_ALU64 | BPF_OR | BPF_X:
707 		emit(A64_ORR(is64, dst, dst, src), ctx);
708 		break;
709 	case BPF_ALU | BPF_XOR | BPF_X:
710 	case BPF_ALU64 | BPF_XOR | BPF_X:
711 		emit(A64_EOR(is64, dst, dst, src), ctx);
712 		break;
713 	case BPF_ALU | BPF_MUL | BPF_X:
714 	case BPF_ALU64 | BPF_MUL | BPF_X:
715 		emit(A64_MUL(is64, dst, dst, src), ctx);
716 		break;
717 	case BPF_ALU | BPF_DIV | BPF_X:
718 	case BPF_ALU64 | BPF_DIV | BPF_X:
719 		emit(A64_UDIV(is64, dst, dst, src), ctx);
720 		break;
721 	case BPF_ALU | BPF_MOD | BPF_X:
722 	case BPF_ALU64 | BPF_MOD | BPF_X:
723 		emit(A64_UDIV(is64, tmp, dst, src), ctx);
724 		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
725 		break;
726 	case BPF_ALU | BPF_LSH | BPF_X:
727 	case BPF_ALU64 | BPF_LSH | BPF_X:
728 		emit(A64_LSLV(is64, dst, dst, src), ctx);
729 		break;
730 	case BPF_ALU | BPF_RSH | BPF_X:
731 	case BPF_ALU64 | BPF_RSH | BPF_X:
732 		emit(A64_LSRV(is64, dst, dst, src), ctx);
733 		break;
734 	case BPF_ALU | BPF_ARSH | BPF_X:
735 	case BPF_ALU64 | BPF_ARSH | BPF_X:
736 		emit(A64_ASRV(is64, dst, dst, src), ctx);
737 		break;
738 	/* dst = -dst */
739 	case BPF_ALU | BPF_NEG:
740 	case BPF_ALU64 | BPF_NEG:
741 		emit(A64_NEG(is64, dst, dst), ctx);
742 		break;
743 	/* dst = BSWAP##imm(dst) */
744 	case BPF_ALU | BPF_END | BPF_FROM_LE:
745 	case BPF_ALU | BPF_END | BPF_FROM_BE:
746 #ifdef CONFIG_CPU_BIG_ENDIAN
747 		if (BPF_SRC(code) == BPF_FROM_BE)
748 			goto emit_bswap_uxt;
749 #else /* !CONFIG_CPU_BIG_ENDIAN */
750 		if (BPF_SRC(code) == BPF_FROM_LE)
751 			goto emit_bswap_uxt;
752 #endif
753 		switch (imm) {
754 		case 16:
755 			emit(A64_REV16(is64, dst, dst), ctx);
756 			/* zero-extend 16 bits into 64 bits */
757 			emit(A64_UXTH(is64, dst, dst), ctx);
758 			break;
759 		case 32:
760 			emit(A64_REV32(is64, dst, dst), ctx);
761 			/* upper 32 bits already cleared */
762 			break;
763 		case 64:
764 			emit(A64_REV64(dst, dst), ctx);
765 			break;
766 		}
767 		break;
768 emit_bswap_uxt:
769 		switch (imm) {
770 		case 16:
771 			/* zero-extend 16 bits into 64 bits */
772 			emit(A64_UXTH(is64, dst, dst), ctx);
773 			break;
774 		case 32:
775 			/* zero-extend 32 bits into 64 bits */
776 			emit(A64_UXTW(is64, dst, dst), ctx);
777 			break;
778 		case 64:
779 			/* nop */
780 			break;
781 		}
782 		break;
783 	/* dst = imm */
784 	case BPF_ALU | BPF_MOV | BPF_K:
785 	case BPF_ALU64 | BPF_MOV | BPF_K:
786 		emit_a64_mov_i(is64, dst, imm, ctx);
787 		break;
788 	/* dst = dst OP imm */
789 	case BPF_ALU | BPF_ADD | BPF_K:
790 	case BPF_ALU64 | BPF_ADD | BPF_K:
791 		if (is_addsub_imm(imm)) {
792 			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
793 		} else if (is_addsub_imm(-imm)) {
794 			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
795 		} else {
796 			emit_a64_mov_i(is64, tmp, imm, ctx);
797 			emit(A64_ADD(is64, dst, dst, tmp), ctx);
798 		}
799 		break;
800 	case BPF_ALU | BPF_SUB | BPF_K:
801 	case BPF_ALU64 | BPF_SUB | BPF_K:
802 		if (is_addsub_imm(imm)) {
803 			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
804 		} else if (is_addsub_imm(-imm)) {
805 			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
806 		} else {
807 			emit_a64_mov_i(is64, tmp, imm, ctx);
808 			emit(A64_SUB(is64, dst, dst, tmp), ctx);
809 		}
810 		break;
811 	case BPF_ALU | BPF_AND | BPF_K:
812 	case BPF_ALU64 | BPF_AND | BPF_K:
813 		a64_insn = A64_AND_I(is64, dst, dst, imm);
814 		if (a64_insn != AARCH64_BREAK_FAULT) {
815 			emit(a64_insn, ctx);
816 		} else {
817 			emit_a64_mov_i(is64, tmp, imm, ctx);
818 			emit(A64_AND(is64, dst, dst, tmp), ctx);
819 		}
820 		break;
821 	case BPF_ALU | BPF_OR | BPF_K:
822 	case BPF_ALU64 | BPF_OR | BPF_K:
823 		a64_insn = A64_ORR_I(is64, dst, dst, imm);
824 		if (a64_insn != AARCH64_BREAK_FAULT) {
825 			emit(a64_insn, ctx);
826 		} else {
827 			emit_a64_mov_i(is64, tmp, imm, ctx);
828 			emit(A64_ORR(is64, dst, dst, tmp), ctx);
829 		}
830 		break;
831 	case BPF_ALU | BPF_XOR | BPF_K:
832 	case BPF_ALU64 | BPF_XOR | BPF_K:
833 		a64_insn = A64_EOR_I(is64, dst, dst, imm);
834 		if (a64_insn != AARCH64_BREAK_FAULT) {
835 			emit(a64_insn, ctx);
836 		} else {
837 			emit_a64_mov_i(is64, tmp, imm, ctx);
838 			emit(A64_EOR(is64, dst, dst, tmp), ctx);
839 		}
840 		break;
841 	case BPF_ALU | BPF_MUL | BPF_K:
842 	case BPF_ALU64 | BPF_MUL | BPF_K:
843 		emit_a64_mov_i(is64, tmp, imm, ctx);
844 		emit(A64_MUL(is64, dst, dst, tmp), ctx);
845 		break;
846 	case BPF_ALU | BPF_DIV | BPF_K:
847 	case BPF_ALU64 | BPF_DIV | BPF_K:
848 		emit_a64_mov_i(is64, tmp, imm, ctx);
849 		emit(A64_UDIV(is64, dst, dst, tmp), ctx);
850 		break;
851 	case BPF_ALU | BPF_MOD | BPF_K:
852 	case BPF_ALU64 | BPF_MOD | BPF_K:
853 		emit_a64_mov_i(is64, tmp2, imm, ctx);
854 		emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
855 		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
856 		break;
857 	case BPF_ALU | BPF_LSH | BPF_K:
858 	case BPF_ALU64 | BPF_LSH | BPF_K:
859 		emit(A64_LSL(is64, dst, dst, imm), ctx);
860 		break;
861 	case BPF_ALU | BPF_RSH | BPF_K:
862 	case BPF_ALU64 | BPF_RSH | BPF_K:
863 		emit(A64_LSR(is64, dst, dst, imm), ctx);
864 		break;
865 	case BPF_ALU | BPF_ARSH | BPF_K:
866 	case BPF_ALU64 | BPF_ARSH | BPF_K:
867 		emit(A64_ASR(is64, dst, dst, imm), ctx);
868 		break;
869 
870 	/* JUMP off */
871 	case BPF_JMP | BPF_JA:
872 		jmp_offset = bpf2a64_offset(i, off, ctx);
873 		check_imm26(jmp_offset);
874 		emit(A64_B(jmp_offset), ctx);
875 		break;
876 	/* IF (dst COND src) JUMP off */
877 	case BPF_JMP | BPF_JEQ | BPF_X:
878 	case BPF_JMP | BPF_JGT | BPF_X:
879 	case BPF_JMP | BPF_JLT | BPF_X:
880 	case BPF_JMP | BPF_JGE | BPF_X:
881 	case BPF_JMP | BPF_JLE | BPF_X:
882 	case BPF_JMP | BPF_JNE | BPF_X:
883 	case BPF_JMP | BPF_JSGT | BPF_X:
884 	case BPF_JMP | BPF_JSLT | BPF_X:
885 	case BPF_JMP | BPF_JSGE | BPF_X:
886 	case BPF_JMP | BPF_JSLE | BPF_X:
887 	case BPF_JMP32 | BPF_JEQ | BPF_X:
888 	case BPF_JMP32 | BPF_JGT | BPF_X:
889 	case BPF_JMP32 | BPF_JLT | BPF_X:
890 	case BPF_JMP32 | BPF_JGE | BPF_X:
891 	case BPF_JMP32 | BPF_JLE | BPF_X:
892 	case BPF_JMP32 | BPF_JNE | BPF_X:
893 	case BPF_JMP32 | BPF_JSGT | BPF_X:
894 	case BPF_JMP32 | BPF_JSLT | BPF_X:
895 	case BPF_JMP32 | BPF_JSGE | BPF_X:
896 	case BPF_JMP32 | BPF_JSLE | BPF_X:
897 		emit(A64_CMP(is64, dst, src), ctx);
898 emit_cond_jmp:
899 		jmp_offset = bpf2a64_offset(i, off, ctx);
900 		check_imm19(jmp_offset);
901 		switch (BPF_OP(code)) {
902 		case BPF_JEQ:
903 			jmp_cond = A64_COND_EQ;
904 			break;
905 		case BPF_JGT:
906 			jmp_cond = A64_COND_HI;
907 			break;
908 		case BPF_JLT:
909 			jmp_cond = A64_COND_CC;
910 			break;
911 		case BPF_JGE:
912 			jmp_cond = A64_COND_CS;
913 			break;
914 		case BPF_JLE:
915 			jmp_cond = A64_COND_LS;
916 			break;
917 		case BPF_JSET:
918 		case BPF_JNE:
919 			jmp_cond = A64_COND_NE;
920 			break;
921 		case BPF_JSGT:
922 			jmp_cond = A64_COND_GT;
923 			break;
924 		case BPF_JSLT:
925 			jmp_cond = A64_COND_LT;
926 			break;
927 		case BPF_JSGE:
928 			jmp_cond = A64_COND_GE;
929 			break;
930 		case BPF_JSLE:
931 			jmp_cond = A64_COND_LE;
932 			break;
933 		default:
934 			return -EFAULT;
935 		}
936 		emit(A64_B_(jmp_cond, jmp_offset), ctx);
937 		break;
938 	case BPF_JMP | BPF_JSET | BPF_X:
939 	case BPF_JMP32 | BPF_JSET | BPF_X:
940 		emit(A64_TST(is64, dst, src), ctx);
941 		goto emit_cond_jmp;
942 	/* IF (dst COND imm) JUMP off */
943 	case BPF_JMP | BPF_JEQ | BPF_K:
944 	case BPF_JMP | BPF_JGT | BPF_K:
945 	case BPF_JMP | BPF_JLT | BPF_K:
946 	case BPF_JMP | BPF_JGE | BPF_K:
947 	case BPF_JMP | BPF_JLE | BPF_K:
948 	case BPF_JMP | BPF_JNE | BPF_K:
949 	case BPF_JMP | BPF_JSGT | BPF_K:
950 	case BPF_JMP | BPF_JSLT | BPF_K:
951 	case BPF_JMP | BPF_JSGE | BPF_K:
952 	case BPF_JMP | BPF_JSLE | BPF_K:
953 	case BPF_JMP32 | BPF_JEQ | BPF_K:
954 	case BPF_JMP32 | BPF_JGT | BPF_K:
955 	case BPF_JMP32 | BPF_JLT | BPF_K:
956 	case BPF_JMP32 | BPF_JGE | BPF_K:
957 	case BPF_JMP32 | BPF_JLE | BPF_K:
958 	case BPF_JMP32 | BPF_JNE | BPF_K:
959 	case BPF_JMP32 | BPF_JSGT | BPF_K:
960 	case BPF_JMP32 | BPF_JSLT | BPF_K:
961 	case BPF_JMP32 | BPF_JSGE | BPF_K:
962 	case BPF_JMP32 | BPF_JSLE | BPF_K:
963 		if (is_addsub_imm(imm)) {
964 			emit(A64_CMP_I(is64, dst, imm), ctx);
965 		} else if (is_addsub_imm(-imm)) {
966 			emit(A64_CMN_I(is64, dst, -imm), ctx);
967 		} else {
968 			emit_a64_mov_i(is64, tmp, imm, ctx);
969 			emit(A64_CMP(is64, dst, tmp), ctx);
970 		}
971 		goto emit_cond_jmp;
972 	case BPF_JMP | BPF_JSET | BPF_K:
973 	case BPF_JMP32 | BPF_JSET | BPF_K:
974 		a64_insn = A64_TST_I(is64, dst, imm);
975 		if (a64_insn != AARCH64_BREAK_FAULT) {
976 			emit(a64_insn, ctx);
977 		} else {
978 			emit_a64_mov_i(is64, tmp, imm, ctx);
979 			emit(A64_TST(is64, dst, tmp), ctx);
980 		}
981 		goto emit_cond_jmp;
982 	/* function call */
983 	case BPF_JMP | BPF_CALL:
984 	{
985 		const u8 r0 = bpf2a64[BPF_REG_0];
986 		bool func_addr_fixed;
987 		u64 func_addr;
988 
989 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
990 					    &func_addr, &func_addr_fixed);
991 		if (ret < 0)
992 			return ret;
993 		emit_addr_mov_i64(tmp, func_addr, ctx);
994 		emit(A64_BLR(tmp), ctx);
995 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
996 		break;
997 	}
998 	/* tail call */
999 	case BPF_JMP | BPF_TAIL_CALL:
1000 		if (emit_bpf_tail_call(ctx))
1001 			return -EFAULT;
1002 		break;
1003 	/* function return */
1004 	case BPF_JMP | BPF_EXIT:
1005 		/* Optimization: when last instruction is EXIT,
1006 		   simply fallthrough to epilogue. */
1007 		if (i == ctx->prog->len - 1)
1008 			break;
1009 		jmp_offset = epilogue_offset(ctx);
1010 		check_imm26(jmp_offset);
1011 		emit(A64_B(jmp_offset), ctx);
1012 		break;
1013 
1014 	/* dst = imm64 */
1015 	case BPF_LD | BPF_IMM | BPF_DW:
1016 	{
1017 		const struct bpf_insn insn1 = insn[1];
1018 		u64 imm64;
1019 
1020 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1021 		if (bpf_pseudo_func(insn))
1022 			emit_addr_mov_i64(dst, imm64, ctx);
1023 		else
1024 			emit_a64_mov_i64(dst, imm64, ctx);
1025 
1026 		return 1;
1027 	}
1028 
1029 	/* LDX: dst = *(size *)(src + off) */
1030 	case BPF_LDX | BPF_MEM | BPF_W:
1031 	case BPF_LDX | BPF_MEM | BPF_H:
1032 	case BPF_LDX | BPF_MEM | BPF_B:
1033 	case BPF_LDX | BPF_MEM | BPF_DW:
1034 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1035 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1036 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1037 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1038 		if (ctx->fpb_offset > 0 && src == fp) {
1039 			src_adj = fpb;
1040 			off_adj = off + ctx->fpb_offset;
1041 		} else {
1042 			src_adj = src;
1043 			off_adj = off;
1044 		}
1045 		switch (BPF_SIZE(code)) {
1046 		case BPF_W:
1047 			if (is_lsi_offset(off_adj, 2)) {
1048 				emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
1049 			} else {
1050 				emit_a64_mov_i(1, tmp, off, ctx);
1051 				emit(A64_LDR32(dst, src, tmp), ctx);
1052 			}
1053 			break;
1054 		case BPF_H:
1055 			if (is_lsi_offset(off_adj, 1)) {
1056 				emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
1057 			} else {
1058 				emit_a64_mov_i(1, tmp, off, ctx);
1059 				emit(A64_LDRH(dst, src, tmp), ctx);
1060 			}
1061 			break;
1062 		case BPF_B:
1063 			if (is_lsi_offset(off_adj, 0)) {
1064 				emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
1065 			} else {
1066 				emit_a64_mov_i(1, tmp, off, ctx);
1067 				emit(A64_LDRB(dst, src, tmp), ctx);
1068 			}
1069 			break;
1070 		case BPF_DW:
1071 			if (is_lsi_offset(off_adj, 3)) {
1072 				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
1073 			} else {
1074 				emit_a64_mov_i(1, tmp, off, ctx);
1075 				emit(A64_LDR64(dst, src, tmp), ctx);
1076 			}
1077 			break;
1078 		}
1079 
1080 		ret = add_exception_handler(insn, ctx, dst);
1081 		if (ret)
1082 			return ret;
1083 		break;
1084 
1085 	/* speculation barrier */
1086 	case BPF_ST | BPF_NOSPEC:
1087 		/*
1088 		 * Nothing required here.
1089 		 *
1090 		 * In case of arm64, we rely on the firmware mitigation of
1091 		 * Speculative Store Bypass as controlled via the ssbd kernel
1092 		 * parameter. Whenever the mitigation is enabled, it works
1093 		 * for all of the kernel code with no need to provide any
1094 		 * additional instructions.
1095 		 */
1096 		break;
1097 
1098 	/* ST: *(size *)(dst + off) = imm */
1099 	case BPF_ST | BPF_MEM | BPF_W:
1100 	case BPF_ST | BPF_MEM | BPF_H:
1101 	case BPF_ST | BPF_MEM | BPF_B:
1102 	case BPF_ST | BPF_MEM | BPF_DW:
1103 		if (ctx->fpb_offset > 0 && dst == fp) {
1104 			dst_adj = fpb;
1105 			off_adj = off + ctx->fpb_offset;
1106 		} else {
1107 			dst_adj = dst;
1108 			off_adj = off;
1109 		}
1110 		/* Load imm to a register then store it */
1111 		emit_a64_mov_i(1, tmp, imm, ctx);
1112 		switch (BPF_SIZE(code)) {
1113 		case BPF_W:
1114 			if (is_lsi_offset(off_adj, 2)) {
1115 				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
1116 			} else {
1117 				emit_a64_mov_i(1, tmp2, off, ctx);
1118 				emit(A64_STR32(tmp, dst, tmp2), ctx);
1119 			}
1120 			break;
1121 		case BPF_H:
1122 			if (is_lsi_offset(off_adj, 1)) {
1123 				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
1124 			} else {
1125 				emit_a64_mov_i(1, tmp2, off, ctx);
1126 				emit(A64_STRH(tmp, dst, tmp2), ctx);
1127 			}
1128 			break;
1129 		case BPF_B:
1130 			if (is_lsi_offset(off_adj, 0)) {
1131 				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
1132 			} else {
1133 				emit_a64_mov_i(1, tmp2, off, ctx);
1134 				emit(A64_STRB(tmp, dst, tmp2), ctx);
1135 			}
1136 			break;
1137 		case BPF_DW:
1138 			if (is_lsi_offset(off_adj, 3)) {
1139 				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
1140 			} else {
1141 				emit_a64_mov_i(1, tmp2, off, ctx);
1142 				emit(A64_STR64(tmp, dst, tmp2), ctx);
1143 			}
1144 			break;
1145 		}
1146 		break;
1147 
1148 	/* STX: *(size *)(dst + off) = src */
1149 	case BPF_STX | BPF_MEM | BPF_W:
1150 	case BPF_STX | BPF_MEM | BPF_H:
1151 	case BPF_STX | BPF_MEM | BPF_B:
1152 	case BPF_STX | BPF_MEM | BPF_DW:
1153 		if (ctx->fpb_offset > 0 && dst == fp) {
1154 			dst_adj = fpb;
1155 			off_adj = off + ctx->fpb_offset;
1156 		} else {
1157 			dst_adj = dst;
1158 			off_adj = off;
1159 		}
1160 		switch (BPF_SIZE(code)) {
1161 		case BPF_W:
1162 			if (is_lsi_offset(off_adj, 2)) {
1163 				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
1164 			} else {
1165 				emit_a64_mov_i(1, tmp, off, ctx);
1166 				emit(A64_STR32(src, dst, tmp), ctx);
1167 			}
1168 			break;
1169 		case BPF_H:
1170 			if (is_lsi_offset(off_adj, 1)) {
1171 				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
1172 			} else {
1173 				emit_a64_mov_i(1, tmp, off, ctx);
1174 				emit(A64_STRH(src, dst, tmp), ctx);
1175 			}
1176 			break;
1177 		case BPF_B:
1178 			if (is_lsi_offset(off_adj, 0)) {
1179 				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
1180 			} else {
1181 				emit_a64_mov_i(1, tmp, off, ctx);
1182 				emit(A64_STRB(src, dst, tmp), ctx);
1183 			}
1184 			break;
1185 		case BPF_DW:
1186 			if (is_lsi_offset(off_adj, 3)) {
1187 				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
1188 			} else {
1189 				emit_a64_mov_i(1, tmp, off, ctx);
1190 				emit(A64_STR64(src, dst, tmp), ctx);
1191 			}
1192 			break;
1193 		}
1194 		break;
1195 
1196 	case BPF_STX | BPF_ATOMIC | BPF_W:
1197 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1198 		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1199 			ret = emit_lse_atomic(insn, ctx);
1200 		else
1201 			ret = emit_ll_sc_atomic(insn, ctx);
1202 		if (ret)
1203 			return ret;
1204 		break;
1205 
1206 	default:
1207 		pr_err_once("unknown opcode %02x\n", code);
1208 		return -EINVAL;
1209 	}
1210 
1211 	return 0;
1212 }
1213 
1214 /*
1215  * Return 0 if FP may change at runtime, otherwise find the minimum negative
1216  * offset to FP, converts it to positive number, and align down to 8 bytes.
1217  */
1218 static int find_fpb_offset(struct bpf_prog *prog)
1219 {
1220 	int i;
1221 	int offset = 0;
1222 
1223 	for (i = 0; i < prog->len; i++) {
1224 		const struct bpf_insn *insn = &prog->insnsi[i];
1225 		const u8 class = BPF_CLASS(insn->code);
1226 		const u8 mode = BPF_MODE(insn->code);
1227 		const u8 src = insn->src_reg;
1228 		const u8 dst = insn->dst_reg;
1229 		const s32 imm = insn->imm;
1230 		const s16 off = insn->off;
1231 
1232 		switch (class) {
1233 		case BPF_STX:
1234 		case BPF_ST:
1235 			/* fp holds atomic operation result */
1236 			if (class == BPF_STX && mode == BPF_ATOMIC &&
1237 			    ((imm == BPF_XCHG ||
1238 			      imm == (BPF_FETCH | BPF_ADD) ||
1239 			      imm == (BPF_FETCH | BPF_AND) ||
1240 			      imm == (BPF_FETCH | BPF_XOR) ||
1241 			      imm == (BPF_FETCH | BPF_OR)) &&
1242 			     src == BPF_REG_FP))
1243 				return 0;
1244 
1245 			if (mode == BPF_MEM && dst == BPF_REG_FP &&
1246 			    off < offset)
1247 				offset = insn->off;
1248 			break;
1249 
1250 		case BPF_JMP32:
1251 		case BPF_JMP:
1252 			break;
1253 
1254 		case BPF_LDX:
1255 		case BPF_LD:
1256 			/* fp holds load result */
1257 			if (dst == BPF_REG_FP)
1258 				return 0;
1259 
1260 			if (class == BPF_LDX && mode == BPF_MEM &&
1261 			    src == BPF_REG_FP && off < offset)
1262 				offset = off;
1263 			break;
1264 
1265 		case BPF_ALU:
1266 		case BPF_ALU64:
1267 		default:
1268 			/* fp holds ALU result */
1269 			if (dst == BPF_REG_FP)
1270 				return 0;
1271 		}
1272 	}
1273 
1274 	if (offset < 0) {
1275 		/*
1276 		 * safely be converted to a positive 'int', since insn->off
1277 		 * is 's16'
1278 		 */
1279 		offset = -offset;
1280 		/* align down to 8 bytes */
1281 		offset = ALIGN_DOWN(offset, 8);
1282 	}
1283 
1284 	return offset;
1285 }
1286 
1287 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1288 {
1289 	const struct bpf_prog *prog = ctx->prog;
1290 	int i;
1291 
1292 	/*
1293 	 * - offset[0] offset of the end of prologue,
1294 	 *   start of the 1st instruction.
1295 	 * - offset[1] - offset of the end of 1st instruction,
1296 	 *   start of the 2nd instruction
1297 	 * [....]
1298 	 * - offset[3] - offset of the end of 3rd instruction,
1299 	 *   start of 4th instruction
1300 	 */
1301 	for (i = 0; i < prog->len; i++) {
1302 		const struct bpf_insn *insn = &prog->insnsi[i];
1303 		int ret;
1304 
1305 		if (ctx->image == NULL)
1306 			ctx->offset[i] = ctx->idx;
1307 		ret = build_insn(insn, ctx, extra_pass);
1308 		if (ret > 0) {
1309 			i++;
1310 			if (ctx->image == NULL)
1311 				ctx->offset[i] = ctx->idx;
1312 			continue;
1313 		}
1314 		if (ret)
1315 			return ret;
1316 	}
1317 	/*
1318 	 * offset is allocated with prog->len + 1 so fill in
1319 	 * the last element with the offset after the last
1320 	 * instruction (end of program)
1321 	 */
1322 	if (ctx->image == NULL)
1323 		ctx->offset[i] = ctx->idx;
1324 
1325 	return 0;
1326 }
1327 
1328 static int validate_code(struct jit_ctx *ctx)
1329 {
1330 	int i;
1331 
1332 	for (i = 0; i < ctx->idx; i++) {
1333 		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1334 
1335 		if (a64_insn == AARCH64_BREAK_FAULT)
1336 			return -1;
1337 	}
1338 
1339 	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1340 		return -1;
1341 
1342 	return 0;
1343 }
1344 
/* Pass the address range [start, end) on to flush_icache_range(). */
static inline void bpf_flush_icache(void *start, void *end)
{
	const unsigned long first = (unsigned long)start;
	const unsigned long last = (unsigned long)end;

	flush_icache_range(first, last);
}
1349 
1350 struct arm64_jit_data {
1351 	struct bpf_binary_header *header;
1352 	u8 *image;
1353 	struct jit_ctx ctx;
1354 };
1355 
1356 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1357 {
1358 	int image_size, prog_size, extable_size;
1359 	struct bpf_prog *tmp, *orig_prog = prog;
1360 	struct bpf_binary_header *header;
1361 	struct arm64_jit_data *jit_data;
1362 	bool was_classic = bpf_prog_was_classic(prog);
1363 	bool tmp_blinded = false;
1364 	bool extra_pass = false;
1365 	struct jit_ctx ctx;
1366 	u8 *image_ptr;
1367 
1368 	if (!prog->jit_requested)
1369 		return orig_prog;
1370 
1371 	tmp = bpf_jit_blind_constants(prog);
1372 	/* If blinding was requested and we failed during blinding,
1373 	 * we must fall back to the interpreter.
1374 	 */
1375 	if (IS_ERR(tmp))
1376 		return orig_prog;
1377 	if (tmp != prog) {
1378 		tmp_blinded = true;
1379 		prog = tmp;
1380 	}
1381 
1382 	jit_data = prog->aux->jit_data;
1383 	if (!jit_data) {
1384 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1385 		if (!jit_data) {
1386 			prog = orig_prog;
1387 			goto out;
1388 		}
1389 		prog->aux->jit_data = jit_data;
1390 	}
1391 	if (jit_data->ctx.offset) {
1392 		ctx = jit_data->ctx;
1393 		image_ptr = jit_data->image;
1394 		header = jit_data->header;
1395 		extra_pass = true;
1396 		prog_size = sizeof(u32) * ctx.idx;
1397 		goto skip_init_ctx;
1398 	}
1399 	memset(&ctx, 0, sizeof(ctx));
1400 	ctx.prog = prog;
1401 
1402 	ctx.offset = kcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1403 	if (ctx.offset == NULL) {
1404 		prog = orig_prog;
1405 		goto out_off;
1406 	}
1407 
1408 	ctx.fpb_offset = find_fpb_offset(prog);
1409 
1410 	/*
1411 	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
1412 	 *
1413 	 * BPF line info needs ctx->offset[i] to be the offset of
1414 	 * instruction[i] in jited image, so build prologue first.
1415 	 */
1416 	if (build_prologue(&ctx, was_classic)) {
1417 		prog = orig_prog;
1418 		goto out_off;
1419 	}
1420 
1421 	if (build_body(&ctx, extra_pass)) {
1422 		prog = orig_prog;
1423 		goto out_off;
1424 	}
1425 
1426 	ctx.epilogue_offset = ctx.idx;
1427 	build_epilogue(&ctx);
1428 
1429 	extable_size = prog->aux->num_exentries *
1430 		sizeof(struct exception_table_entry);
1431 
1432 	/* Now we know the actual image size. */
1433 	prog_size = sizeof(u32) * ctx.idx;
1434 	image_size = prog_size + extable_size;
1435 	header = bpf_jit_binary_alloc(image_size, &image_ptr,
1436 				      sizeof(u32), jit_fill_hole);
1437 	if (header == NULL) {
1438 		prog = orig_prog;
1439 		goto out_off;
1440 	}
1441 
1442 	/* 2. Now, the actual pass. */
1443 
1444 	ctx.image = (__le32 *)image_ptr;
1445 	if (extable_size)
1446 		prog->aux->extable = (void *)image_ptr + prog_size;
1447 skip_init_ctx:
1448 	ctx.idx = 0;
1449 	ctx.exentry_idx = 0;
1450 
1451 	build_prologue(&ctx, was_classic);
1452 
1453 	if (build_body(&ctx, extra_pass)) {
1454 		bpf_jit_binary_free(header);
1455 		prog = orig_prog;
1456 		goto out_off;
1457 	}
1458 
1459 	build_epilogue(&ctx);
1460 
1461 	/* 3. Extra pass to validate JITed code. */
1462 	if (validate_code(&ctx)) {
1463 		bpf_jit_binary_free(header);
1464 		prog = orig_prog;
1465 		goto out_off;
1466 	}
1467 
1468 	/* And we're done. */
1469 	if (bpf_jit_enable > 1)
1470 		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1471 
1472 	bpf_flush_icache(header, ctx.image + ctx.idx);
1473 
1474 	if (!prog->is_func || extra_pass) {
1475 		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1476 			pr_err_once("multi-func JIT bug %d != %d\n",
1477 				    ctx.idx, jit_data->ctx.idx);
1478 			bpf_jit_binary_free(header);
1479 			prog->bpf_func = NULL;
1480 			prog->jited = 0;
1481 			goto out_off;
1482 		}
1483 		bpf_jit_binary_lock_ro(header);
1484 	} else {
1485 		jit_data->ctx = ctx;
1486 		jit_data->image = image_ptr;
1487 		jit_data->header = header;
1488 	}
1489 	prog->bpf_func = (void *)ctx.image;
1490 	prog->jited = 1;
1491 	prog->jited_len = prog_size;
1492 
1493 	if (!prog->is_func || extra_pass) {
1494 		int i;
1495 
1496 		/* offset[prog->len] is the size of program */
1497 		for (i = 0; i <= prog->len; i++)
1498 			ctx.offset[i] *= AARCH64_INSN_SIZE;
1499 		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1500 out_off:
1501 		kfree(ctx.offset);
1502 		kfree(jit_data);
1503 		prog->aux->jit_data = NULL;
1504 	}
1505 out:
1506 	if (tmp_blinded)
1507 		bpf_jit_prog_release_other(prog, prog == orig_prog ?
1508 					   tmp : orig_prog);
1509 	return prog;
1510 }
1511 
/* Report whether this JIT supports kfunc calls; it unconditionally does. */
bool bpf_jit_supports_kfunc_call(void)
{
	const bool supported = true;

	return supported;
}
1516 
1517 u64 bpf_jit_alloc_exec_limit(void)
1518 {
1519 	return VMALLOC_END - VMALLOC_START;
1520 }
1521 
/*
 * Allocate @size bytes with vmalloc() for a JIT image and return the
 * pointer with its KASAN tag reset.
 */
void *bpf_jit_alloc_exec(unsigned long size)
{
	void *mem = vmalloc(size);

	/* Memory is intended to be executable, reset the pointer tag. */
	return kasan_reset_tag(mem);
}
1527 
/*
 * Release JIT memory with vfree(), the counterpart of the vmalloc() call
 * made in bpf_jit_alloc_exec().
 */
void bpf_jit_free_exec(void *addr)
{
	/*
	 * Plain call rather than "return vfree(addr);": ISO C does not
	 * allow a return statement with an expression in a void function.
	 */
	vfree(addr);
}
1532