xref: /openbmc/linux/arch/loongarch/net/bpf_jit.c (revision ed4543328f7108e1047b83b96ca7f7208747d930)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for LoongArch
4  *
5  * Copyright (C) 2022 Loongson Technology Corporation Limited
6  */
7 #include "bpf_jit.h"
8 
/*
 * REG_TCC carries the tail call count (TCC). The first prologue
 * instruction initializes it, and a tail call enters the next program one
 * instruction later so the count is inherited from the caller. When the
 * program also makes calls, the live count is kept in the callee-saved
 * TCC_SAVED instead, because REG_TCC must be preserved across calls.
 */
#define REG_TCC		LOONGARCH_GPR_A6
#define TCC_SAVED	LOONGARCH_GPR_S5

/* Bits in jit_ctx->flags describing what the program contains. */
#define SAVE_RA		BIT(0)	/* set by mark_call(): program makes calls */
#define SAVE_TCC	BIT(1)	/* set by mark_tail_call(): program makes tail calls */

/* Mapping of eBPF registers onto LoongArch general-purpose registers. */
static const int regmap[] = {
	/*
	 * return value from in-kernel function, and exit value for eBPF program
	 * (the epilogue copies it into $a0 with addi.w before returning)
	 */
	[BPF_REG_0] = LOONGARCH_GPR_A5,
	/* arguments from eBPF program to in-kernel function */
	[BPF_REG_1] = LOONGARCH_GPR_A0,
	[BPF_REG_2] = LOONGARCH_GPR_A1,
	[BPF_REG_3] = LOONGARCH_GPR_A2,
	[BPF_REG_4] = LOONGARCH_GPR_A3,
	[BPF_REG_5] = LOONGARCH_GPR_A4,
	/* callee saved registers that in-kernel function will preserve */
	[BPF_REG_6] = LOONGARCH_GPR_S0,
	[BPF_REG_7] = LOONGARCH_GPR_S1,
	[BPF_REG_8] = LOONGARCH_GPR_S2,
	[BPF_REG_9] = LOONGARCH_GPR_S3,
	/* read-only frame pointer to access stack */
	[BPF_REG_FP] = LOONGARCH_GPR_S4,
	/* temporary register for blinding constants */
	[BPF_REG_AX] = LOONGARCH_GPR_T0,
};
34 
mark_call(struct jit_ctx * ctx)35 static void mark_call(struct jit_ctx *ctx)
36 {
37 	ctx->flags |= SAVE_RA;
38 }
39 
mark_tail_call(struct jit_ctx * ctx)40 static void mark_tail_call(struct jit_ctx *ctx)
41 {
42 	ctx->flags |= SAVE_TCC;
43 }
44 
seen_call(struct jit_ctx * ctx)45 static bool seen_call(struct jit_ctx *ctx)
46 {
47 	return (ctx->flags & SAVE_RA);
48 }
49 
seen_tail_call(struct jit_ctx * ctx)50 static bool seen_tail_call(struct jit_ctx *ctx)
51 {
52 	return (ctx->flags & SAVE_TCC);
53 }
54 
tail_call_reg(struct jit_ctx * ctx)55 static u8 tail_call_reg(struct jit_ctx *ctx)
56 {
57 	if (seen_call(ctx))
58 		return TCC_SAVED;
59 
60 	return REG_TCC;
61 }
62 
63 /*
64  * eBPF prog stack layout:
65  *
66  *                                        high
67  * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
68  *                            |           $ra           |
69  *                            +-------------------------+
70  *                            |           $fp           |
71  *                            +-------------------------+
72  *                            |           $s0           |
73  *                            +-------------------------+
74  *                            |           $s1           |
75  *                            +-------------------------+
76  *                            |           $s2           |
77  *                            +-------------------------+
78  *                            |           $s3           |
79  *                            +-------------------------+
80  *                            |           $s4           |
81  *                            +-------------------------+
82  *                            |           $s5           |
83  *                            +-------------------------+ <--BPF_REG_FP
84  *                            |  prog->aux->stack_depth |
85  *                            |        (optional)       |
86  * current $sp -------------> +-------------------------+
87  *                                        low
88  */
build_prologue(struct jit_ctx * ctx)89 static void build_prologue(struct jit_ctx *ctx)
90 {
91 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
92 
93 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94 
95 	/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96 	stack_adjust += sizeof(long) * 8;
97 
98 	stack_adjust = round_up(stack_adjust, 16);
99 	stack_adjust += bpf_stack_adjust;
100 
101 	/*
102 	 * First instruction initializes the tail call count (TCC).
103 	 * On tail call we skip this instruction, and the TCC is
104 	 * passed in REG_TCC from the caller.
105 	 */
106 	emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
107 
108 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109 
110 	store_offset = stack_adjust - sizeof(long);
111 	emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);
112 
113 	store_offset -= sizeof(long);
114 	emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);
115 
116 	store_offset -= sizeof(long);
117 	emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);
118 
119 	store_offset -= sizeof(long);
120 	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);
121 
122 	store_offset -= sizeof(long);
123 	emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);
124 
125 	store_offset -= sizeof(long);
126 	emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);
127 
128 	store_offset -= sizeof(long);
129 	emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);
130 
131 	store_offset -= sizeof(long);
132 	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133 
134 	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135 
136 	if (bpf_stack_adjust)
137 		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138 
139 	/*
140 	 * Program contains calls and tail calls, so REG_TCC need
141 	 * to be saved across calls.
142 	 */
143 	if (seen_tail_call(ctx) && seen_call(ctx))
144 		move_reg(ctx, TCC_SAVED, REG_TCC);
145 
146 	ctx->stack_size = stack_adjust;
147 }
148 
__build_epilogue(struct jit_ctx * ctx,bool is_tail_call)149 static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
150 {
151 	int stack_adjust = ctx->stack_size;
152 	int load_offset;
153 
154 	load_offset = stack_adjust - sizeof(long);
155 	emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);
156 
157 	load_offset -= sizeof(long);
158 	emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);
159 
160 	load_offset -= sizeof(long);
161 	emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);
162 
163 	load_offset -= sizeof(long);
164 	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);
165 
166 	load_offset -= sizeof(long);
167 	emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);
168 
169 	load_offset -= sizeof(long);
170 	emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);
171 
172 	load_offset -= sizeof(long);
173 	emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);
174 
175 	load_offset -= sizeof(long);
176 	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
177 
178 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
179 
180 	if (!is_tail_call) {
181 		/* Set return value */
182 		emit_insn(ctx, addiw, LOONGARCH_GPR_A0, regmap[BPF_REG_0], 0);
183 		/* Return to the caller */
184 		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_RA, 0);
185 	} else {
186 		/*
187 		 * Call the next bpf prog and skip the first instruction
188 		 * of TCC initialization.
189 		 */
190 		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1);
191 	}
192 }
193 
/* Emit the regular function-return epilogue (not a tail call). */
static void build_epilogue(struct jit_ctx *ctx)
{
	const bool via_tail_call = false;

	__build_epilogue(ctx, via_tail_call);
}
198 
/*
 * Report to the BPF core that this JIT can emit calls to kernel functions
 * (kfuncs). All BPF_JMP | BPF_CALL instructions, kfunc calls included, go
 * through the same call sequence in build_insn().
 */
bool bpf_jit_supports_kfunc_call(void)
{
	return true;
}
203 
204 /* initialized on the first pass of build_body() */
205 static int out_offset = -1;
emit_bpf_tail_call(struct jit_ctx * ctx)206 static int emit_bpf_tail_call(struct jit_ctx *ctx)
207 {
208 	int off;
209 	u8 tcc = tail_call_reg(ctx);
210 	u8 a1 = LOONGARCH_GPR_A1;
211 	u8 a2 = LOONGARCH_GPR_A2;
212 	u8 t1 = LOONGARCH_GPR_T1;
213 	u8 t2 = LOONGARCH_GPR_T2;
214 	u8 t3 = LOONGARCH_GPR_T3;
215 	const int idx0 = ctx->idx;
216 
217 #define cur_offset (ctx->idx - idx0)
218 #define jmp_offset (out_offset - (cur_offset))
219 
220 	/*
221 	 * a0: &ctx
222 	 * a1: &array
223 	 * a2: index
224 	 *
225 	 * if (index >= array->map.max_entries)
226 	 *	 goto out;
227 	 */
228 	off = offsetof(struct bpf_array, map.max_entries);
229 	emit_insn(ctx, ldwu, t1, a1, off);
230 	/* bgeu $a2, $t1, jmp_offset */
231 	if (emit_tailcall_jmp(ctx, BPF_JGE, a2, t1, jmp_offset) < 0)
232 		goto toofar;
233 
234 	/*
235 	 * if (--TCC < 0)
236 	 *	 goto out;
237 	 */
238 	emit_insn(ctx, addid, REG_TCC, tcc, -1);
239 	if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
240 		goto toofar;
241 
242 	/*
243 	 * prog = array->ptrs[index];
244 	 * if (!prog)
245 	 *	 goto out;
246 	 */
247 	emit_insn(ctx, alsld, t2, a2, a1, 2);
248 	off = offsetof(struct bpf_array, ptrs);
249 	emit_insn(ctx, ldd, t2, t2, off);
250 	/* beq $t2, $zero, jmp_offset */
251 	if (emit_tailcall_jmp(ctx, BPF_JEQ, t2, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
252 		goto toofar;
253 
254 	/* goto *(prog->bpf_func + 4); */
255 	off = offsetof(struct bpf_prog, bpf_func);
256 	emit_insn(ctx, ldd, t3, t2, off);
257 	__build_epilogue(ctx, true);
258 
259 	/* out: */
260 	if (out_offset == -1)
261 		out_offset = cur_offset;
262 	if (cur_offset != out_offset) {
263 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
264 			    cur_offset, out_offset);
265 		return -1;
266 	}
267 
268 	return 0;
269 
270 toofar:
271 	pr_info_once("tail_call: jump too far\n");
272 	return -1;
273 #undef cur_offset
274 #undef jmp_offset
275 }
276 
emit_atomic(const struct bpf_insn * insn,struct jit_ctx * ctx)277 static void emit_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
278 {
279 	const u8 t1 = LOONGARCH_GPR_T1;
280 	const u8 t2 = LOONGARCH_GPR_T2;
281 	const u8 t3 = LOONGARCH_GPR_T3;
282 	const u8 r0 = regmap[BPF_REG_0];
283 	const u8 src = regmap[insn->src_reg];
284 	const u8 dst = regmap[insn->dst_reg];
285 	const s16 off = insn->off;
286 	const s32 imm = insn->imm;
287 	const bool isdw = BPF_SIZE(insn->code) == BPF_DW;
288 
289 	move_imm(ctx, t1, off, false);
290 	emit_insn(ctx, addd, t1, dst, t1);
291 	move_reg(ctx, t3, src);
292 
293 	switch (imm) {
294 	/* lock *(size *)(dst + off) <op>= src */
295 	case BPF_ADD:
296 		if (isdw)
297 			emit_insn(ctx, amaddd, t2, t1, src);
298 		else
299 			emit_insn(ctx, amaddw, t2, t1, src);
300 		break;
301 	case BPF_AND:
302 		if (isdw)
303 			emit_insn(ctx, amandd, t2, t1, src);
304 		else
305 			emit_insn(ctx, amandw, t2, t1, src);
306 		break;
307 	case BPF_OR:
308 		if (isdw)
309 			emit_insn(ctx, amord, t2, t1, src);
310 		else
311 			emit_insn(ctx, amorw, t2, t1, src);
312 		break;
313 	case BPF_XOR:
314 		if (isdw)
315 			emit_insn(ctx, amxord, t2, t1, src);
316 		else
317 			emit_insn(ctx, amxorw, t2, t1, src);
318 		break;
319 	/* src = atomic_fetch_<op>(dst + off, src) */
320 	case BPF_ADD | BPF_FETCH:
321 		if (isdw) {
322 			emit_insn(ctx, amaddd, src, t1, t3);
323 		} else {
324 			emit_insn(ctx, amaddw, src, t1, t3);
325 			emit_zext_32(ctx, src, true);
326 		}
327 		break;
328 	case BPF_AND | BPF_FETCH:
329 		if (isdw) {
330 			emit_insn(ctx, amandd, src, t1, t3);
331 		} else {
332 			emit_insn(ctx, amandw, src, t1, t3);
333 			emit_zext_32(ctx, src, true);
334 		}
335 		break;
336 	case BPF_OR | BPF_FETCH:
337 		if (isdw) {
338 			emit_insn(ctx, amord, src, t1, t3);
339 		} else {
340 			emit_insn(ctx, amorw, src, t1, t3);
341 			emit_zext_32(ctx, src, true);
342 		}
343 		break;
344 	case BPF_XOR | BPF_FETCH:
345 		if (isdw) {
346 			emit_insn(ctx, amxord, src, t1, t3);
347 		} else {
348 			emit_insn(ctx, amxorw, src, t1, t3);
349 			emit_zext_32(ctx, src, true);
350 		}
351 		break;
352 	/* src = atomic_xchg(dst + off, src); */
353 	case BPF_XCHG:
354 		if (isdw) {
355 			emit_insn(ctx, amswapd, src, t1, t3);
356 		} else {
357 			emit_insn(ctx, amswapw, src, t1, t3);
358 			emit_zext_32(ctx, src, true);
359 		}
360 		break;
361 	/* r0 = atomic_cmpxchg(dst + off, r0, src); */
362 	case BPF_CMPXCHG:
363 		move_reg(ctx, t2, r0);
364 		if (isdw) {
365 			emit_insn(ctx, lld, r0, t1, 0);
366 			emit_insn(ctx, bne, t2, r0, 4);
367 			move_reg(ctx, t3, src);
368 			emit_insn(ctx, scd, t3, t1, 0);
369 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -4);
370 		} else {
371 			emit_insn(ctx, llw, r0, t1, 0);
372 			emit_zext_32(ctx, t2, true);
373 			emit_zext_32(ctx, r0, true);
374 			emit_insn(ctx, bne, t2, r0, 4);
375 			move_reg(ctx, t3, src);
376 			emit_insn(ctx, scw, t3, t1, 0);
377 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -6);
378 			emit_zext_32(ctx, r0, true);
379 		}
380 		break;
381 	}
382 }
383 
/* True for the conditional-jump operations that compare signed values. */
static bool is_signed_bpf_cond(u8 cond)
{
	switch (cond) {
	case BPF_JSGT:
	case BPF_JSLT:
	case BPF_JSGE:
	case BPF_JSLE:
		return true;
	default:
		return false;
	}
}
389 
390 #define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
391 #define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
392 
ex_handler_bpf(const struct exception_table_entry * ex,struct pt_regs * regs)393 bool ex_handler_bpf(const struct exception_table_entry *ex,
394 		    struct pt_regs *regs)
395 {
396 	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
397 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
398 
399 	regs->regs[dst_reg] = 0;
400 	regs->csr_era = (unsigned long)&ex->fixup - offset;
401 
402 	return true;
403 }
404 
405 /* For accesses to BTF pointers, add an entry to the exception table */
add_exception_handler(const struct bpf_insn * insn,struct jit_ctx * ctx,int dst_reg)406 static int add_exception_handler(const struct bpf_insn *insn,
407 				 struct jit_ctx *ctx,
408 				 int dst_reg)
409 {
410 	unsigned long pc;
411 	off_t offset;
412 	struct exception_table_entry *ex;
413 
414 	if (!ctx->image || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
415 		return 0;
416 
417 	if (WARN_ON_ONCE(ctx->num_exentries >= ctx->prog->aux->num_exentries))
418 		return -EINVAL;
419 
420 	ex = &ctx->prog->aux->extable[ctx->num_exentries];
421 	pc = (unsigned long)&ctx->image[ctx->idx - 1];
422 
423 	offset = pc - (long)&ex->insn;
424 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
425 		return -ERANGE;
426 
427 	ex->insn = offset;
428 
429 	/*
430 	 * Since the extable follows the program, the fixup offset is always
431 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
432 	 * to keep things simple, and put the destination register in the upper
433 	 * bits. We don't need to worry about buildtime or runtime sort
434 	 * modifying the upper bits because the table is already sorted, and
435 	 * isn't part of the main exception table.
436 	 */
437 	offset = (long)&ex->fixup - (pc + LOONGARCH_INSN_SIZE);
438 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
439 		return -ERANGE;
440 
441 	ex->type = EX_TYPE_BPF;
442 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) | FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
443 
444 	ctx->num_exentries++;
445 
446 	return 0;
447 }
448 
build_insn(const struct bpf_insn * insn,struct jit_ctx * ctx,bool extra_pass)449 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
450 {
451 	u8 tm = -1;
452 	u64 func_addr;
453 	bool func_addr_fixed;
454 	int i = insn - ctx->prog->insnsi;
455 	int ret, jmp_offset;
456 	const u8 code = insn->code;
457 	const u8 cond = BPF_OP(code);
458 	const u8 t1 = LOONGARCH_GPR_T1;
459 	const u8 t2 = LOONGARCH_GPR_T2;
460 	const u8 src = regmap[insn->src_reg];
461 	const u8 dst = regmap[insn->dst_reg];
462 	const s16 off = insn->off;
463 	const s32 imm = insn->imm;
464 	const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
465 
466 	switch (code) {
467 	/* dst = src */
468 	case BPF_ALU | BPF_MOV | BPF_X:
469 	case BPF_ALU64 | BPF_MOV | BPF_X:
470 		move_reg(ctx, dst, src);
471 		emit_zext_32(ctx, dst, is32);
472 		break;
473 
474 	/* dst = imm */
475 	case BPF_ALU | BPF_MOV | BPF_K:
476 	case BPF_ALU64 | BPF_MOV | BPF_K:
477 		move_imm(ctx, dst, imm, is32);
478 		break;
479 
480 	/* dst = dst + src */
481 	case BPF_ALU | BPF_ADD | BPF_X:
482 	case BPF_ALU64 | BPF_ADD | BPF_X:
483 		emit_insn(ctx, addd, dst, dst, src);
484 		emit_zext_32(ctx, dst, is32);
485 		break;
486 
487 	/* dst = dst + imm */
488 	case BPF_ALU | BPF_ADD | BPF_K:
489 	case BPF_ALU64 | BPF_ADD | BPF_K:
490 		if (is_signed_imm12(imm)) {
491 			emit_insn(ctx, addid, dst, dst, imm);
492 		} else {
493 			move_imm(ctx, t1, imm, is32);
494 			emit_insn(ctx, addd, dst, dst, t1);
495 		}
496 		emit_zext_32(ctx, dst, is32);
497 		break;
498 
499 	/* dst = dst - src */
500 	case BPF_ALU | BPF_SUB | BPF_X:
501 	case BPF_ALU64 | BPF_SUB | BPF_X:
502 		emit_insn(ctx, subd, dst, dst, src);
503 		emit_zext_32(ctx, dst, is32);
504 		break;
505 
506 	/* dst = dst - imm */
507 	case BPF_ALU | BPF_SUB | BPF_K:
508 	case BPF_ALU64 | BPF_SUB | BPF_K:
509 		if (is_signed_imm12(-imm)) {
510 			emit_insn(ctx, addid, dst, dst, -imm);
511 		} else {
512 			move_imm(ctx, t1, imm, is32);
513 			emit_insn(ctx, subd, dst, dst, t1);
514 		}
515 		emit_zext_32(ctx, dst, is32);
516 		break;
517 
518 	/* dst = dst * src */
519 	case BPF_ALU | BPF_MUL | BPF_X:
520 	case BPF_ALU64 | BPF_MUL | BPF_X:
521 		emit_insn(ctx, muld, dst, dst, src);
522 		emit_zext_32(ctx, dst, is32);
523 		break;
524 
525 	/* dst = dst * imm */
526 	case BPF_ALU | BPF_MUL | BPF_K:
527 	case BPF_ALU64 | BPF_MUL | BPF_K:
528 		move_imm(ctx, t1, imm, is32);
529 		emit_insn(ctx, muld, dst, dst, t1);
530 		emit_zext_32(ctx, dst, is32);
531 		break;
532 
533 	/* dst = dst / src */
534 	case BPF_ALU | BPF_DIV | BPF_X:
535 	case BPF_ALU64 | BPF_DIV | BPF_X:
536 		emit_zext_32(ctx, dst, is32);
537 		move_reg(ctx, t1, src);
538 		emit_zext_32(ctx, t1, is32);
539 		emit_insn(ctx, divdu, dst, dst, t1);
540 		emit_zext_32(ctx, dst, is32);
541 		break;
542 
543 	/* dst = dst / imm */
544 	case BPF_ALU | BPF_DIV | BPF_K:
545 	case BPF_ALU64 | BPF_DIV | BPF_K:
546 		move_imm(ctx, t1, imm, is32);
547 		emit_zext_32(ctx, dst, is32);
548 		emit_insn(ctx, divdu, dst, dst, t1);
549 		emit_zext_32(ctx, dst, is32);
550 		break;
551 
552 	/* dst = dst % src */
553 	case BPF_ALU | BPF_MOD | BPF_X:
554 	case BPF_ALU64 | BPF_MOD | BPF_X:
555 		emit_zext_32(ctx, dst, is32);
556 		move_reg(ctx, t1, src);
557 		emit_zext_32(ctx, t1, is32);
558 		emit_insn(ctx, moddu, dst, dst, t1);
559 		emit_zext_32(ctx, dst, is32);
560 		break;
561 
562 	/* dst = dst % imm */
563 	case BPF_ALU | BPF_MOD | BPF_K:
564 	case BPF_ALU64 | BPF_MOD | BPF_K:
565 		move_imm(ctx, t1, imm, is32);
566 		emit_zext_32(ctx, dst, is32);
567 		emit_insn(ctx, moddu, dst, dst, t1);
568 		emit_zext_32(ctx, dst, is32);
569 		break;
570 
571 	/* dst = -dst */
572 	case BPF_ALU | BPF_NEG:
573 	case BPF_ALU64 | BPF_NEG:
574 		move_imm(ctx, t1, imm, is32);
575 		emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
576 		emit_zext_32(ctx, dst, is32);
577 		break;
578 
579 	/* dst = dst & src */
580 	case BPF_ALU | BPF_AND | BPF_X:
581 	case BPF_ALU64 | BPF_AND | BPF_X:
582 		emit_insn(ctx, and, dst, dst, src);
583 		emit_zext_32(ctx, dst, is32);
584 		break;
585 
586 	/* dst = dst & imm */
587 	case BPF_ALU | BPF_AND | BPF_K:
588 	case BPF_ALU64 | BPF_AND | BPF_K:
589 		if (is_unsigned_imm12(imm)) {
590 			emit_insn(ctx, andi, dst, dst, imm);
591 		} else {
592 			move_imm(ctx, t1, imm, is32);
593 			emit_insn(ctx, and, dst, dst, t1);
594 		}
595 		emit_zext_32(ctx, dst, is32);
596 		break;
597 
598 	/* dst = dst | src */
599 	case BPF_ALU | BPF_OR | BPF_X:
600 	case BPF_ALU64 | BPF_OR | BPF_X:
601 		emit_insn(ctx, or, dst, dst, src);
602 		emit_zext_32(ctx, dst, is32);
603 		break;
604 
605 	/* dst = dst | imm */
606 	case BPF_ALU | BPF_OR | BPF_K:
607 	case BPF_ALU64 | BPF_OR | BPF_K:
608 		if (is_unsigned_imm12(imm)) {
609 			emit_insn(ctx, ori, dst, dst, imm);
610 		} else {
611 			move_imm(ctx, t1, imm, is32);
612 			emit_insn(ctx, or, dst, dst, t1);
613 		}
614 		emit_zext_32(ctx, dst, is32);
615 		break;
616 
617 	/* dst = dst ^ src */
618 	case BPF_ALU | BPF_XOR | BPF_X:
619 	case BPF_ALU64 | BPF_XOR | BPF_X:
620 		emit_insn(ctx, xor, dst, dst, src);
621 		emit_zext_32(ctx, dst, is32);
622 		break;
623 
624 	/* dst = dst ^ imm */
625 	case BPF_ALU | BPF_XOR | BPF_K:
626 	case BPF_ALU64 | BPF_XOR | BPF_K:
627 		if (is_unsigned_imm12(imm)) {
628 			emit_insn(ctx, xori, dst, dst, imm);
629 		} else {
630 			move_imm(ctx, t1, imm, is32);
631 			emit_insn(ctx, xor, dst, dst, t1);
632 		}
633 		emit_zext_32(ctx, dst, is32);
634 		break;
635 
636 	/* dst = dst << src (logical) */
637 	case BPF_ALU | BPF_LSH | BPF_X:
638 		emit_insn(ctx, sllw, dst, dst, src);
639 		emit_zext_32(ctx, dst, is32);
640 		break;
641 
642 	case BPF_ALU64 | BPF_LSH | BPF_X:
643 		emit_insn(ctx, slld, dst, dst, src);
644 		break;
645 
646 	/* dst = dst << imm (logical) */
647 	case BPF_ALU | BPF_LSH | BPF_K:
648 		emit_insn(ctx, slliw, dst, dst, imm);
649 		emit_zext_32(ctx, dst, is32);
650 		break;
651 
652 	case BPF_ALU64 | BPF_LSH | BPF_K:
653 		emit_insn(ctx, sllid, dst, dst, imm);
654 		break;
655 
656 	/* dst = dst >> src (logical) */
657 	case BPF_ALU | BPF_RSH | BPF_X:
658 		emit_insn(ctx, srlw, dst, dst, src);
659 		emit_zext_32(ctx, dst, is32);
660 		break;
661 
662 	case BPF_ALU64 | BPF_RSH | BPF_X:
663 		emit_insn(ctx, srld, dst, dst, src);
664 		break;
665 
666 	/* dst = dst >> imm (logical) */
667 	case BPF_ALU | BPF_RSH | BPF_K:
668 		emit_insn(ctx, srliw, dst, dst, imm);
669 		emit_zext_32(ctx, dst, is32);
670 		break;
671 
672 	case BPF_ALU64 | BPF_RSH | BPF_K:
673 		emit_insn(ctx, srlid, dst, dst, imm);
674 		break;
675 
676 	/* dst = dst >> src (arithmetic) */
677 	case BPF_ALU | BPF_ARSH | BPF_X:
678 		emit_insn(ctx, sraw, dst, dst, src);
679 		emit_zext_32(ctx, dst, is32);
680 		break;
681 
682 	case BPF_ALU64 | BPF_ARSH | BPF_X:
683 		emit_insn(ctx, srad, dst, dst, src);
684 		break;
685 
686 	/* dst = dst >> imm (arithmetic) */
687 	case BPF_ALU | BPF_ARSH | BPF_K:
688 		emit_insn(ctx, sraiw, dst, dst, imm);
689 		emit_zext_32(ctx, dst, is32);
690 		break;
691 
692 	case BPF_ALU64 | BPF_ARSH | BPF_K:
693 		emit_insn(ctx, sraid, dst, dst, imm);
694 		break;
695 
696 	/* dst = BSWAP##imm(dst) */
697 	case BPF_ALU | BPF_END | BPF_FROM_LE:
698 		switch (imm) {
699 		case 16:
700 			/* zero-extend 16 bits into 64 bits */
701 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
702 			break;
703 		case 32:
704 			/* zero-extend 32 bits into 64 bits */
705 			emit_zext_32(ctx, dst, is32);
706 			break;
707 		case 64:
708 			/* do nothing */
709 			break;
710 		}
711 		break;
712 
713 	case BPF_ALU | BPF_END | BPF_FROM_BE:
714 		switch (imm) {
715 		case 16:
716 			emit_insn(ctx, revb2h, dst, dst);
717 			/* zero-extend 16 bits into 64 bits */
718 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
719 			break;
720 		case 32:
721 			emit_insn(ctx, revb2w, dst, dst);
722 			/* zero-extend 32 bits into 64 bits */
723 			emit_zext_32(ctx, dst, is32);
724 			break;
725 		case 64:
726 			emit_insn(ctx, revbd, dst, dst);
727 			break;
728 		}
729 		break;
730 
731 	/* PC += off if dst cond src */
732 	case BPF_JMP | BPF_JEQ | BPF_X:
733 	case BPF_JMP | BPF_JNE | BPF_X:
734 	case BPF_JMP | BPF_JGT | BPF_X:
735 	case BPF_JMP | BPF_JGE | BPF_X:
736 	case BPF_JMP | BPF_JLT | BPF_X:
737 	case BPF_JMP | BPF_JLE | BPF_X:
738 	case BPF_JMP | BPF_JSGT | BPF_X:
739 	case BPF_JMP | BPF_JSGE | BPF_X:
740 	case BPF_JMP | BPF_JSLT | BPF_X:
741 	case BPF_JMP | BPF_JSLE | BPF_X:
742 	case BPF_JMP32 | BPF_JEQ | BPF_X:
743 	case BPF_JMP32 | BPF_JNE | BPF_X:
744 	case BPF_JMP32 | BPF_JGT | BPF_X:
745 	case BPF_JMP32 | BPF_JGE | BPF_X:
746 	case BPF_JMP32 | BPF_JLT | BPF_X:
747 	case BPF_JMP32 | BPF_JLE | BPF_X:
748 	case BPF_JMP32 | BPF_JSGT | BPF_X:
749 	case BPF_JMP32 | BPF_JSGE | BPF_X:
750 	case BPF_JMP32 | BPF_JSLT | BPF_X:
751 	case BPF_JMP32 | BPF_JSLE | BPF_X:
752 		jmp_offset = bpf2la_offset(i, off, ctx);
753 		move_reg(ctx, t1, dst);
754 		move_reg(ctx, t2, src);
755 		if (is_signed_bpf_cond(BPF_OP(code))) {
756 			emit_sext_32(ctx, t1, is32);
757 			emit_sext_32(ctx, t2, is32);
758 		} else {
759 			emit_zext_32(ctx, t1, is32);
760 			emit_zext_32(ctx, t2, is32);
761 		}
762 		if (emit_cond_jmp(ctx, cond, t1, t2, jmp_offset) < 0)
763 			goto toofar;
764 		break;
765 
766 	/* PC += off if dst cond imm */
767 	case BPF_JMP | BPF_JEQ | BPF_K:
768 	case BPF_JMP | BPF_JNE | BPF_K:
769 	case BPF_JMP | BPF_JGT | BPF_K:
770 	case BPF_JMP | BPF_JGE | BPF_K:
771 	case BPF_JMP | BPF_JLT | BPF_K:
772 	case BPF_JMP | BPF_JLE | BPF_K:
773 	case BPF_JMP | BPF_JSGT | BPF_K:
774 	case BPF_JMP | BPF_JSGE | BPF_K:
775 	case BPF_JMP | BPF_JSLT | BPF_K:
776 	case BPF_JMP | BPF_JSLE | BPF_K:
777 	case BPF_JMP32 | BPF_JEQ | BPF_K:
778 	case BPF_JMP32 | BPF_JNE | BPF_K:
779 	case BPF_JMP32 | BPF_JGT | BPF_K:
780 	case BPF_JMP32 | BPF_JGE | BPF_K:
781 	case BPF_JMP32 | BPF_JLT | BPF_K:
782 	case BPF_JMP32 | BPF_JLE | BPF_K:
783 	case BPF_JMP32 | BPF_JSGT | BPF_K:
784 	case BPF_JMP32 | BPF_JSGE | BPF_K:
785 	case BPF_JMP32 | BPF_JSLT | BPF_K:
786 	case BPF_JMP32 | BPF_JSLE | BPF_K:
787 		jmp_offset = bpf2la_offset(i, off, ctx);
788 		if (imm) {
789 			move_imm(ctx, t1, imm, false);
790 			tm = t1;
791 		} else {
792 			/* If imm is 0, simply use zero register. */
793 			tm = LOONGARCH_GPR_ZERO;
794 		}
795 		move_reg(ctx, t2, dst);
796 		if (is_signed_bpf_cond(BPF_OP(code))) {
797 			emit_sext_32(ctx, tm, is32);
798 			emit_sext_32(ctx, t2, is32);
799 		} else {
800 			emit_zext_32(ctx, tm, is32);
801 			emit_zext_32(ctx, t2, is32);
802 		}
803 		if (emit_cond_jmp(ctx, cond, t2, tm, jmp_offset) < 0)
804 			goto toofar;
805 		break;
806 
807 	/* PC += off if dst & src */
808 	case BPF_JMP | BPF_JSET | BPF_X:
809 	case BPF_JMP32 | BPF_JSET | BPF_X:
810 		jmp_offset = bpf2la_offset(i, off, ctx);
811 		emit_insn(ctx, and, t1, dst, src);
812 		emit_zext_32(ctx, t1, is32);
813 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
814 			goto toofar;
815 		break;
816 
817 	/* PC += off if dst & imm */
818 	case BPF_JMP | BPF_JSET | BPF_K:
819 	case BPF_JMP32 | BPF_JSET | BPF_K:
820 		jmp_offset = bpf2la_offset(i, off, ctx);
821 		move_imm(ctx, t1, imm, is32);
822 		emit_insn(ctx, and, t1, dst, t1);
823 		emit_zext_32(ctx, t1, is32);
824 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
825 			goto toofar;
826 		break;
827 
828 	/* PC += off */
829 	case BPF_JMP | BPF_JA:
830 		jmp_offset = bpf2la_offset(i, off, ctx);
831 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
832 			goto toofar;
833 		break;
834 
835 	/* function call */
836 	case BPF_JMP | BPF_CALL:
837 		mark_call(ctx);
838 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
839 					    &func_addr, &func_addr_fixed);
840 		if (ret < 0)
841 			return ret;
842 
843 		move_addr(ctx, t1, func_addr);
844 		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
845 		move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
846 		break;
847 
848 	/* tail call */
849 	case BPF_JMP | BPF_TAIL_CALL:
850 		mark_tail_call(ctx);
851 		if (emit_bpf_tail_call(ctx) < 0)
852 			return -EINVAL;
853 		break;
854 
855 	/* function return */
856 	case BPF_JMP | BPF_EXIT:
857 		if (i == ctx->prog->len - 1)
858 			break;
859 
860 		jmp_offset = epilogue_offset(ctx);
861 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
862 			goto toofar;
863 		break;
864 
865 	/* dst = imm64 */
866 	case BPF_LD | BPF_IMM | BPF_DW:
867 	{
868 		const u64 imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
869 
870 		move_imm(ctx, dst, imm64, is32);
871 		return 1;
872 	}
873 
874 	/* dst = *(size *)(src + off) */
875 	case BPF_LDX | BPF_MEM | BPF_B:
876 	case BPF_LDX | BPF_MEM | BPF_H:
877 	case BPF_LDX | BPF_MEM | BPF_W:
878 	case BPF_LDX | BPF_MEM | BPF_DW:
879 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
880 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
881 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
882 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
883 		switch (BPF_SIZE(code)) {
884 		case BPF_B:
885 			if (is_signed_imm12(off)) {
886 				emit_insn(ctx, ldbu, dst, src, off);
887 			} else {
888 				move_imm(ctx, t1, off, is32);
889 				emit_insn(ctx, ldxbu, dst, src, t1);
890 			}
891 			break;
892 		case BPF_H:
893 			if (is_signed_imm12(off)) {
894 				emit_insn(ctx, ldhu, dst, src, off);
895 			} else {
896 				move_imm(ctx, t1, off, is32);
897 				emit_insn(ctx, ldxhu, dst, src, t1);
898 			}
899 			break;
900 		case BPF_W:
901 			if (is_signed_imm12(off)) {
902 				emit_insn(ctx, ldwu, dst, src, off);
903 			} else if (is_signed_imm14(off)) {
904 				emit_insn(ctx, ldptrw, dst, src, off);
905 			} else {
906 				move_imm(ctx, t1, off, is32);
907 				emit_insn(ctx, ldxwu, dst, src, t1);
908 			}
909 			break;
910 		case BPF_DW:
911 			move_imm(ctx, t1, off, is32);
912 			emit_insn(ctx, ldxd, dst, src, t1);
913 			break;
914 		}
915 
916 		ret = add_exception_handler(insn, ctx, dst);
917 		if (ret)
918 			return ret;
919 		break;
920 
921 	/* *(size *)(dst + off) = imm */
922 	case BPF_ST | BPF_MEM | BPF_B:
923 	case BPF_ST | BPF_MEM | BPF_H:
924 	case BPF_ST | BPF_MEM | BPF_W:
925 	case BPF_ST | BPF_MEM | BPF_DW:
926 		switch (BPF_SIZE(code)) {
927 		case BPF_B:
928 			move_imm(ctx, t1, imm, is32);
929 			if (is_signed_imm12(off)) {
930 				emit_insn(ctx, stb, t1, dst, off);
931 			} else {
932 				move_imm(ctx, t2, off, is32);
933 				emit_insn(ctx, stxb, t1, dst, t2);
934 			}
935 			break;
936 		case BPF_H:
937 			move_imm(ctx, t1, imm, is32);
938 			if (is_signed_imm12(off)) {
939 				emit_insn(ctx, sth, t1, dst, off);
940 			} else {
941 				move_imm(ctx, t2, off, is32);
942 				emit_insn(ctx, stxh, t1, dst, t2);
943 			}
944 			break;
945 		case BPF_W:
946 			move_imm(ctx, t1, imm, is32);
947 			if (is_signed_imm12(off)) {
948 				emit_insn(ctx, stw, t1, dst, off);
949 			} else if (is_signed_imm14(off)) {
950 				emit_insn(ctx, stptrw, t1, dst, off);
951 			} else {
952 				move_imm(ctx, t2, off, is32);
953 				emit_insn(ctx, stxw, t1, dst, t2);
954 			}
955 			break;
956 		case BPF_DW:
957 			move_imm(ctx, t1, imm, is32);
958 			if (is_signed_imm12(off)) {
959 				emit_insn(ctx, std, t1, dst, off);
960 			} else if (is_signed_imm14(off)) {
961 				emit_insn(ctx, stptrd, t1, dst, off);
962 			} else {
963 				move_imm(ctx, t2, off, is32);
964 				emit_insn(ctx, stxd, t1, dst, t2);
965 			}
966 			break;
967 		}
968 		break;
969 
970 	/* *(size *)(dst + off) = src */
971 	case BPF_STX | BPF_MEM | BPF_B:
972 	case BPF_STX | BPF_MEM | BPF_H:
973 	case BPF_STX | BPF_MEM | BPF_W:
974 	case BPF_STX | BPF_MEM | BPF_DW:
975 		switch (BPF_SIZE(code)) {
976 		case BPF_B:
977 			if (is_signed_imm12(off)) {
978 				emit_insn(ctx, stb, src, dst, off);
979 			} else {
980 				move_imm(ctx, t1, off, is32);
981 				emit_insn(ctx, stxb, src, dst, t1);
982 			}
983 			break;
984 		case BPF_H:
985 			if (is_signed_imm12(off)) {
986 				emit_insn(ctx, sth, src, dst, off);
987 			} else {
988 				move_imm(ctx, t1, off, is32);
989 				emit_insn(ctx, stxh, src, dst, t1);
990 			}
991 			break;
992 		case BPF_W:
993 			if (is_signed_imm12(off)) {
994 				emit_insn(ctx, stw, src, dst, off);
995 			} else if (is_signed_imm14(off)) {
996 				emit_insn(ctx, stptrw, src, dst, off);
997 			} else {
998 				move_imm(ctx, t1, off, is32);
999 				emit_insn(ctx, stxw, src, dst, t1);
1000 			}
1001 			break;
1002 		case BPF_DW:
1003 			if (is_signed_imm12(off)) {
1004 				emit_insn(ctx, std, src, dst, off);
1005 			} else if (is_signed_imm14(off)) {
1006 				emit_insn(ctx, stptrd, src, dst, off);
1007 			} else {
1008 				move_imm(ctx, t1, off, is32);
1009 				emit_insn(ctx, stxd, src, dst, t1);
1010 			}
1011 			break;
1012 		}
1013 		break;
1014 
1015 	case BPF_STX | BPF_ATOMIC | BPF_W:
1016 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1017 		emit_atomic(insn, ctx);
1018 		break;
1019 
1020 	/* Speculation barrier */
1021 	case BPF_ST | BPF_NOSPEC:
1022 		break;
1023 
1024 	default:
1025 		pr_err("bpf_jit: unknown opcode %02x\n", code);
1026 		return -EINVAL;
1027 	}
1028 
1029 	return 0;
1030 
1031 toofar:
1032 	pr_info_once("bpf_jit: opcode %02x, jump too far\n", code);
1033 	return -E2BIG;
1034 }
1035 
build_body(struct jit_ctx * ctx,bool extra_pass)1036 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1037 {
1038 	int i;
1039 	const struct bpf_prog *prog = ctx->prog;
1040 
1041 	for (i = 0; i < prog->len; i++) {
1042 		const struct bpf_insn *insn = &prog->insnsi[i];
1043 		int ret;
1044 
1045 		if (ctx->image == NULL)
1046 			ctx->offset[i] = ctx->idx;
1047 
1048 		ret = build_insn(insn, ctx, extra_pass);
1049 		if (ret > 0) {
1050 			i++;
1051 			if (ctx->image == NULL)
1052 				ctx->offset[i] = ctx->idx;
1053 			continue;
1054 		}
1055 		if (ret)
1056 			return ret;
1057 	}
1058 
1059 	if (ctx->image == NULL)
1060 		ctx->offset[i] = ctx->idx;
1061 
1062 	return 0;
1063 }
1064 
1065 /* Fill space with break instructions */
jit_fill_hole(void * area,unsigned int size)1066 static void jit_fill_hole(void *area, unsigned int size)
1067 {
1068 	u32 *ptr;
1069 
1070 	/* We are guaranteed to have aligned memory */
1071 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
1072 		*ptr++ = INSN_BREAK;
1073 }
1074 
validate_code(struct jit_ctx * ctx)1075 static int validate_code(struct jit_ctx *ctx)
1076 {
1077 	int i;
1078 	union loongarch_instruction insn;
1079 
1080 	for (i = 0; i < ctx->idx; i++) {
1081 		insn = ctx->image[i];
1082 		/* Check INSN_BREAK */
1083 		if (insn.word == INSN_BREAK)
1084 			return -1;
1085 	}
1086 
1087 	if (WARN_ON_ONCE(ctx->num_exentries != ctx->prog->aux->num_exentries))
1088 		return -1;
1089 
1090 	return 0;
1091 }
1092 
/*
 * bpf_int_jit_compile() - JIT-compile an eBPF program for LoongArch.
 * @prog: program to compile. It may be replaced internally by a
 *        constant-blinded clone.
 *
 * Compilation has three steps:
 *  1. A sizing pass with no image. It records per-instruction offsets and
 *     ctx.flags.
 *  2. A code-generation pass into an image obtained from
 *     bpf_jit_binary_alloc().
 *  3. A validation pass over the emitted words.
 *
 * For a subprogram (prog->is_func), the first invocation stores ctx, image
 * and header in prog->aux->jit_data and returns without releasing them. A
 * later invocation finds jit_data->ctx.offset set and performs an extra pass
 * (extra_pass == true) that re-emits code into the same image.
 *
 * Return: on success, the program with prog->jited set and prog->bpf_func
 * pointing at the image. On most failures, the original program, so the
 * caller can fall back to the interpreter. The multi-func size-mismatch path
 * instead clears bpf_func/jited/jited_len on the program it returns.
 */
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	bool tmp_blinded = false, extra_pass = false;
	u8 *image_ptr;
	int image_size, prog_size, extable_size;
	struct jit_ctx ctx;
	struct jit_data *jit_data;
	struct bpf_binary_header *header;
	struct bpf_prog *tmp, *orig_prog = prog;

	/*
	 * If BPF JIT was not enabled then we must fall back to
	 * the interpreter.
	 */
	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter. Otherwise, we save
	 * the new JITed code.
	 */
	if (IS_ERR(tmp))
		return orig_prog;

	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	/* Per-program JIT state that is kept between invocations (subprogs). */
	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	/*
	 * A saved offset table means an earlier invocation already sized this
	 * program and emitted it into an image. Restore that state and only
	 * re-run the code-generation pass into the existing image.
	 */
	if (jit_data->ctx.offset) {
		ctx = jit_data->ctx;
		image_ptr = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		prog_size = sizeof(u32) * ctx.idx;
		goto skip_init_ctx;
	}

	memset(&ctx, 0, sizeof(ctx));
	ctx.prog = prog;

	/* One entry per BPF instruction, plus offset[prog->len] for the end. */
	ctx.offset = kvcalloc(prog->len + 1, sizeof(u32), GFP_KERNEL);
	if (ctx.offset == NULL) {
		prog = orig_prog;
		goto out_offset;
	}

	/*
	 * 1. Initial fake pass to compute ctx->idx and set ctx->flags.
	 * ctx.image is still NULL here (memset above), so build_body()
	 * records ctx.offset[] during this pass.
	 */
	build_prologue(&ctx);
	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_offset;
	}
	/* Instruction index where the epilogue starts; presumably a jump target for exits. */
	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx);

	extable_size = prog->aux->num_exentries * sizeof(struct exception_table_entry);

	/* Now we know the actual image size.
	 * Each LoongArch instruction is 32 bits long, so the number of
	 * JITed instructions is converted into the number of bytes
	 * needed to store the JITed code.
	 */
	prog_size = sizeof(u32) * ctx.idx;
	/* The exception table is placed directly after the code. */
	image_size = prog_size + extable_size;
	/* Now we know the size of the structure to make */
	header = bpf_jit_binary_alloc(image_size, &image_ptr,
				      sizeof(u32), jit_fill_hole);
	if (header == NULL) {
		prog = orig_prog;
		goto out_offset;
	}

	/* 2. Now, the actual pass to generate final JIT code */
	ctx.image = (union loongarch_instruction *)image_ptr;
	if (extable_size)
		prog->aux->extable = (void *)image_ptr + prog_size;

skip_init_ctx:
	/* Restart emission from the first instruction and extable entry. */
	ctx.idx = 0;
	ctx.num_exentries = 0;

	build_prologue(&ctx);
	if (build_body(&ctx, extra_pass)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_offset;
	}
	build_epilogue(&ctx);

	/*
	 * 3. Extra pass to validate JITed code. This rejects leftover
	 * INSN_BREAK fill words and an extable count that does not match.
	 */
	if (validate_code(&ctx)) {
		bpf_jit_binary_free(header);
		prog = orig_prog;
		goto out_offset;
	}

	/* And we're done */
	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);

	/* Update the icache */
	flush_icache_range((unsigned long)header, (unsigned long)(ctx.image + ctx.idx));

	if (!prog->is_func || extra_pass) {
		/*
		 * Final pass: the image was allocated from the size found
		 * by the first pass, so the instruction count must not have
		 * changed since then.
		 */
		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
			pr_err_once("multi-func JIT bug %d != %d\n",
				    ctx.idx, jit_data->ctx.idx);
			bpf_jit_binary_free(header);
			prog->bpf_func = NULL;
			prog->jited = 0;
			prog->jited_len = 0;
			goto out_offset;
		}
		bpf_jit_binary_lock_ro(header);
	} else {
		/*
		 * First pass of a subprogram: keep the state for the extra
		 * pass. ctx.offset and jit_data are deliberately not freed.
		 */
		jit_data->ctx = ctx;
		jit_data->image = image_ptr;
		jit_data->header = header;
	}
	prog->jited = 1;
	prog->jited_len = prog_size;
	prog->bpf_func = (void *)ctx.image;

	if (!prog->is_func || extra_pass) {
		int i;

		/* offset[prog->len] is the size of program */
		/* Convert instruction indices to byte offsets for line info. */
		for (i = 0; i <= prog->len; i++)
			ctx.offset[i] *= LOONGARCH_INSN_SIZE;
		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);

	/*
	 * Error paths above jump straight to this label, bypassing the
	 * enclosing condition. As a result, the offset table and jit_data
	 * are released on every failure, and also after the final pass.
	 */
out_offset:
		kvfree(ctx.offset);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}

out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ? tmp : orig_prog);

	/*
	 * NOTE(review): this assigns an object named out_offset declared
	 * outside this view. In C, labels live in a separate namespace from
	 * objects, so this is not the out_offset: label above. It presumably
	 * resets state shared with tail-call emission; confirm at its
	 * declaration.
	 */
	out_offset = -1;

	return prog;
}
1250 
/*
 * Report that this JIT backend can handle programs that combine bpf-to-bpf
 * calls (subprograms) with tail calls.
 */
bool bpf_jit_supports_subprog_tailcalls(void)
{
	const bool mixing_supported = true;

	return mixing_supported;
}
1255 }
1256