xref: /openbmc/linux/arch/loongarch/net/bpf_jit.c (revision 6246ed09111fbb17168619006b4380103c6673c3)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for LoongArch
4  *
5  * Copyright (C) 2022 Loongson Technology Corporation Limited
6  */
7 #include "bpf_jit.h"
8 
9 #define REG_TCC		LOONGARCH_GPR_A6	/* tail call count (TCC), set in the prologue or passed in by a tail-calling program */
10 #define TCC_SAVED	LOONGARCH_GPR_S5	/* copy of REG_TCC made by the prologue when the program has both calls and tail calls */
11 
12 #define SAVE_RA		BIT(0)	/* ctx->flags bit set by mark_call() */
13 #define SAVE_TCC	BIT(1)	/* ctx->flags bit set by mark_tail_call() */
14 
15 static const int regmap[] = {	/* eBPF register number -> LoongArch GPR used for it */
16 	/* return value from in-kernel function, and exit value for eBPF program */
17 	[BPF_REG_0] = LOONGARCH_GPR_A5,
18 	/* arguments from eBPF program to in-kernel function */
19 	[BPF_REG_1] = LOONGARCH_GPR_A0,
20 	[BPF_REG_2] = LOONGARCH_GPR_A1,
21 	[BPF_REG_3] = LOONGARCH_GPR_A2,
22 	[BPF_REG_4] = LOONGARCH_GPR_A3,
23 	[BPF_REG_5] = LOONGARCH_GPR_A4,
24 	/* callee saved registers that in-kernel function will preserve */
25 	[BPF_REG_6] = LOONGARCH_GPR_S0,
26 	[BPF_REG_7] = LOONGARCH_GPR_S1,
27 	[BPF_REG_8] = LOONGARCH_GPR_S2,
28 	[BPF_REG_9] = LOONGARCH_GPR_S3,
29 	/* read-only frame pointer to access stack */
30 	[BPF_REG_FP] = LOONGARCH_GPR_S4,
31 	/* temporary register for blinding constants */
32 	[BPF_REG_AX] = LOONGARCH_GPR_T0,
33 };
34 
35 static void mark_call(struct jit_ctx *ctx)
36 {
37 	ctx->flags |= SAVE_RA;
38 }
39 
40 static void mark_tail_call(struct jit_ctx *ctx)
41 {
42 	ctx->flags |= SAVE_TCC;
43 }
44 
45 static bool seen_call(struct jit_ctx *ctx)
46 {
47 	return (ctx->flags & SAVE_RA);
48 }
49 
50 static bool seen_tail_call(struct jit_ctx *ctx)
51 {
52 	return (ctx->flags & SAVE_TCC);
53 }
54 
55 static u8 tail_call_reg(struct jit_ctx *ctx)
56 {
57 	if (seen_call(ctx))
58 		return TCC_SAVED;
59 
60 	return REG_TCC;
61 }
62 
63 /*
64  * eBPF prog stack layout:
65  *
66  *                                        high
67  * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
68  *                            |           $ra           |
69  *                            +-------------------------+
70  *                            |           $fp           |
71  *                            +-------------------------+
72  *                            |           $s0           |
73  *                            +-------------------------+
74  *                            |           $s1           |
75  *                            +-------------------------+
76  *                            |           $s2           |
77  *                            +-------------------------+
78  *                            |           $s3           |
79  *                            +-------------------------+
80  *                            |           $s4           |
81  *                            +-------------------------+
82  *                            |           $s5           |
83  *                            +-------------------------+ <--BPF_REG_FP
84  *                            |  prog->aux->stack_depth |
85  *                            |        (optional)       |
86  * current $sp -------------> +-------------------------+
87  *                                        low
88  */
89 static void build_prologue(struct jit_ctx *ctx)
90 {
91 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
92 
93 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94 
95 	/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96 	stack_adjust += sizeof(long) * 8;
97 
98 	stack_adjust = round_up(stack_adjust, 16);
99 	stack_adjust += bpf_stack_adjust;
100 
101 	/*
102 	 * First instruction initializes the tail call count (TCC).
103 	 * On tail call we skip this instruction, and the TCC is
104 	 * passed in REG_TCC from the caller.
105 	 */
106 	emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
107 
108 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109 
110 	store_offset = stack_adjust - sizeof(long);
111 	emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);
112 
113 	store_offset -= sizeof(long);
114 	emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);
115 
116 	store_offset -= sizeof(long);
117 	emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);
118 
119 	store_offset -= sizeof(long);
120 	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);
121 
122 	store_offset -= sizeof(long);
123 	emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);
124 
125 	store_offset -= sizeof(long);
126 	emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);
127 
128 	store_offset -= sizeof(long);
129 	emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);
130 
131 	store_offset -= sizeof(long);
132 	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133 
134 	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135 
136 	if (bpf_stack_adjust)
137 		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138 
139 	/*
140 	 * Program contains calls and tail calls, so REG_TCC need
141 	 * to be saved across calls.
142 	 */
143 	if (seen_tail_call(ctx) && seen_call(ctx))
144 		move_reg(ctx, TCC_SAVED, REG_TCC);
145 
146 	ctx->stack_size = stack_adjust;
147 }
148 
149 static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
150 {
151 	int stack_adjust = ctx->stack_size;
152 	int load_offset;
153 
154 	load_offset = stack_adjust - sizeof(long);
155 	emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);
156 
157 	load_offset -= sizeof(long);
158 	emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);
159 
160 	load_offset -= sizeof(long);
161 	emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);
162 
163 	load_offset -= sizeof(long);
164 	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);
165 
166 	load_offset -= sizeof(long);
167 	emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);
168 
169 	load_offset -= sizeof(long);
170 	emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);
171 
172 	load_offset -= sizeof(long);
173 	emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);
174 
175 	load_offset -= sizeof(long);
176 	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
177 
178 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
179 
180 	if (!is_tail_call) {
181 		/* Set return value */
182 		move_reg(ctx, LOONGARCH_GPR_A0, regmap[BPF_REG_0]);
183 		/* Return to the caller */
184 		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, LOONGARCH_GPR_ZERO, 0);
185 	} else {
186 		/*
187 		 * Call the next bpf prog and skip the first instruction
188 		 * of TCC initialization.
189 		 */
190 		emit_insn(ctx, jirl, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, 1);
191 	}
192 }
193 
194 static void build_epilogue(struct jit_ctx *ctx)
195 {
196 	__build_epilogue(ctx, false);
197 }
198 
/*
 * Always returns true.
 * NOTE(review): presumably queried by the BPF core to learn whether this
 * JIT can translate kfunc calls - confirm against the caller.
 */
bool bpf_jit_supports_kfunc_call(void)
{
	return true;
}
203 
204 /* initialized on the first pass of build_body() */
205 static int out_offset = -1;
206 static int emit_bpf_tail_call(struct jit_ctx *ctx)
207 {
208 	int off;
209 	u8 tcc = tail_call_reg(ctx);
210 	u8 a1 = LOONGARCH_GPR_A1;
211 	u8 a2 = LOONGARCH_GPR_A2;
212 	u8 t1 = LOONGARCH_GPR_T1;
213 	u8 t2 = LOONGARCH_GPR_T2;
214 	u8 t3 = LOONGARCH_GPR_T3;
215 	const int idx0 = ctx->idx;
216 
217 #define cur_offset (ctx->idx - idx0)
218 #define jmp_offset (out_offset - (cur_offset))
219 
220 	/*
221 	 * a0: &ctx
222 	 * a1: &array
223 	 * a2: index
224 	 *
225 	 * if (index >= array->map.max_entries)
226 	 *	 goto out;
227 	 */
228 	off = offsetof(struct bpf_array, map.max_entries);
229 	emit_insn(ctx, ldwu, t1, a1, off);
230 	/* bgeu $a2, $t1, jmp_offset */
231 	if (emit_tailcall_jmp(ctx, BPF_JGE, a2, t1, jmp_offset) < 0)
232 		goto toofar;
233 
234 	/*
235 	 * if (--TCC < 0)
236 	 *	 goto out;
237 	 */
238 	emit_insn(ctx, addid, REG_TCC, tcc, -1);
239 	if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
240 		goto toofar;
241 
242 	/*
243 	 * prog = array->ptrs[index];
244 	 * if (!prog)
245 	 *	 goto out;
246 	 */
247 	emit_insn(ctx, alsld, t2, a2, a1, 2);
248 	off = offsetof(struct bpf_array, ptrs);
249 	emit_insn(ctx, ldd, t2, t2, off);
250 	/* beq $t2, $zero, jmp_offset */
251 	if (emit_tailcall_jmp(ctx, BPF_JEQ, t2, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
252 		goto toofar;
253 
254 	/* goto *(prog->bpf_func + 4); */
255 	off = offsetof(struct bpf_prog, bpf_func);
256 	emit_insn(ctx, ldd, t3, t2, off);
257 	__build_epilogue(ctx, true);
258 
259 	/* out: */
260 	if (out_offset == -1)
261 		out_offset = cur_offset;
262 	if (cur_offset != out_offset) {
263 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
264 			    cur_offset, out_offset);
265 		return -1;
266 	}
267 
268 	return 0;
269 
270 toofar:
271 	pr_info_once("tail_call: jump too far\n");
272 	return -1;
273 #undef cur_offset
274 #undef jmp_offset
275 }
276 
277 static void emit_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
278 {
279 	const u8 t1 = LOONGARCH_GPR_T1;
280 	const u8 t2 = LOONGARCH_GPR_T2;
281 	const u8 t3 = LOONGARCH_GPR_T3;
282 	const u8 src = regmap[insn->src_reg];
283 	const u8 dst = regmap[insn->dst_reg];
284 	const s16 off = insn->off;
285 	const s32 imm = insn->imm;
286 	const bool isdw = BPF_SIZE(insn->code) == BPF_DW;
287 
288 	move_imm(ctx, t1, off, false);
289 	emit_insn(ctx, addd, t1, dst, t1);
290 	move_reg(ctx, t3, src);
291 
292 	switch (imm) {
293 	/* lock *(size *)(dst + off) <op>= src */
294 	case BPF_ADD:
295 		if (isdw)
296 			emit_insn(ctx, amaddd, t2, t1, src);
297 		else
298 			emit_insn(ctx, amaddw, t2, t1, src);
299 		break;
300 	case BPF_AND:
301 		if (isdw)
302 			emit_insn(ctx, amandd, t2, t1, src);
303 		else
304 			emit_insn(ctx, amandw, t2, t1, src);
305 		break;
306 	case BPF_OR:
307 		if (isdw)
308 			emit_insn(ctx, amord, t2, t1, src);
309 		else
310 			emit_insn(ctx, amorw, t2, t1, src);
311 		break;
312 	case BPF_XOR:
313 		if (isdw)
314 			emit_insn(ctx, amxord, t2, t1, src);
315 		else
316 			emit_insn(ctx, amxorw, t2, t1, src);
317 		break;
318 	/* src = atomic_fetch_<op>(dst + off, src) */
319 	case BPF_ADD | BPF_FETCH:
320 		if (isdw) {
321 			emit_insn(ctx, amaddd, src, t1, t3);
322 		} else {
323 			emit_insn(ctx, amaddw, src, t1, t3);
324 			emit_zext_32(ctx, src, true);
325 		}
326 		break;
327 	case BPF_AND | BPF_FETCH:
328 		if (isdw) {
329 			emit_insn(ctx, amandd, src, t1, t3);
330 		} else {
331 			emit_insn(ctx, amandw, src, t1, t3);
332 			emit_zext_32(ctx, src, true);
333 		}
334 		break;
335 	case BPF_OR | BPF_FETCH:
336 		if (isdw) {
337 			emit_insn(ctx, amord, src, t1, t3);
338 		} else {
339 			emit_insn(ctx, amorw, src, t1, t3);
340 			emit_zext_32(ctx, src, true);
341 		}
342 		break;
343 	case BPF_XOR | BPF_FETCH:
344 		if (isdw) {
345 			emit_insn(ctx, amxord, src, t1, t3);
346 		} else {
347 			emit_insn(ctx, amxorw, src, t1, t3);
348 			emit_zext_32(ctx, src, true);
349 		}
350 		break;
351 	/* src = atomic_xchg(dst + off, src); */
352 	case BPF_XCHG:
353 		if (isdw) {
354 			emit_insn(ctx, amswapd, src, t1, t3);
355 		} else {
356 			emit_insn(ctx, amswapw, src, t1, t3);
357 			emit_zext_32(ctx, src, true);
358 		}
359 		break;
360 	/* r0 = atomic_cmpxchg(dst + off, r0, src); */
361 	case BPF_CMPXCHG:
362 		u8 r0 = regmap[BPF_REG_0];
363 
364 		move_reg(ctx, t2, r0);
365 		if (isdw) {
366 			emit_insn(ctx, lld, r0, t1, 0);
367 			emit_insn(ctx, bne, t2, r0, 4);
368 			move_reg(ctx, t3, src);
369 			emit_insn(ctx, scd, t3, t1, 0);
370 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -4);
371 		} else {
372 			emit_insn(ctx, llw, r0, t1, 0);
373 			emit_zext_32(ctx, t2, true);
374 			emit_zext_32(ctx, r0, true);
375 			emit_insn(ctx, bne, t2, r0, 4);
376 			move_reg(ctx, t3, src);
377 			emit_insn(ctx, scw, t3, t1, 0);
378 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -6);
379 			emit_zext_32(ctx, r0, true);
380 		}
381 		break;
382 	}
383 }
384 
385 static bool is_signed_bpf_cond(u8 cond)
386 {
387 	return cond == BPF_JSGT || cond == BPF_JSLT ||
388 	       cond == BPF_JSGE || cond == BPF_JSLE;
389 }
390 
391 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
392 {
393 	const bool is32 = BPF_CLASS(insn->code) == BPF_ALU ||
394 			  BPF_CLASS(insn->code) == BPF_JMP32;
395 	const u8 code = insn->code;
396 	const u8 cond = BPF_OP(code);
397 	const u8 t1 = LOONGARCH_GPR_T1;
398 	const u8 t2 = LOONGARCH_GPR_T2;
399 	const u8 src = regmap[insn->src_reg];
400 	const u8 dst = regmap[insn->dst_reg];
401 	const s16 off = insn->off;
402 	const s32 imm = insn->imm;
403 	int jmp_offset;
404 	int i = insn - ctx->prog->insnsi;
405 
406 	switch (code) {
407 	/* dst = src */
408 	case BPF_ALU | BPF_MOV | BPF_X:
409 	case BPF_ALU64 | BPF_MOV | BPF_X:
410 		move_reg(ctx, dst, src);
411 		emit_zext_32(ctx, dst, is32);
412 		break;
413 
414 	/* dst = imm */
415 	case BPF_ALU | BPF_MOV | BPF_K:
416 	case BPF_ALU64 | BPF_MOV | BPF_K:
417 		move_imm(ctx, dst, imm, is32);
418 		break;
419 
420 	/* dst = dst + src */
421 	case BPF_ALU | BPF_ADD | BPF_X:
422 	case BPF_ALU64 | BPF_ADD | BPF_X:
423 		emit_insn(ctx, addd, dst, dst, src);
424 		emit_zext_32(ctx, dst, is32);
425 		break;
426 
427 	/* dst = dst + imm */
428 	case BPF_ALU | BPF_ADD | BPF_K:
429 	case BPF_ALU64 | BPF_ADD | BPF_K:
430 		if (is_signed_imm12(imm)) {
431 			emit_insn(ctx, addid, dst, dst, imm);
432 		} else {
433 			move_imm(ctx, t1, imm, is32);
434 			emit_insn(ctx, addd, dst, dst, t1);
435 		}
436 		emit_zext_32(ctx, dst, is32);
437 		break;
438 
439 	/* dst = dst - src */
440 	case BPF_ALU | BPF_SUB | BPF_X:
441 	case BPF_ALU64 | BPF_SUB | BPF_X:
442 		emit_insn(ctx, subd, dst, dst, src);
443 		emit_zext_32(ctx, dst, is32);
444 		break;
445 
446 	/* dst = dst - imm */
447 	case BPF_ALU | BPF_SUB | BPF_K:
448 	case BPF_ALU64 | BPF_SUB | BPF_K:
449 		if (is_signed_imm12(-imm)) {
450 			emit_insn(ctx, addid, dst, dst, -imm);
451 		} else {
452 			move_imm(ctx, t1, imm, is32);
453 			emit_insn(ctx, subd, dst, dst, t1);
454 		}
455 		emit_zext_32(ctx, dst, is32);
456 		break;
457 
458 	/* dst = dst * src */
459 	case BPF_ALU | BPF_MUL | BPF_X:
460 	case BPF_ALU64 | BPF_MUL | BPF_X:
461 		emit_insn(ctx, muld, dst, dst, src);
462 		emit_zext_32(ctx, dst, is32);
463 		break;
464 
465 	/* dst = dst * imm */
466 	case BPF_ALU | BPF_MUL | BPF_K:
467 	case BPF_ALU64 | BPF_MUL | BPF_K:
468 		move_imm(ctx, t1, imm, is32);
469 		emit_insn(ctx, muld, dst, dst, t1);
470 		emit_zext_32(ctx, dst, is32);
471 		break;
472 
473 	/* dst = dst / src */
474 	case BPF_ALU | BPF_DIV | BPF_X:
475 	case BPF_ALU64 | BPF_DIV | BPF_X:
476 		emit_zext_32(ctx, dst, is32);
477 		move_reg(ctx, t1, src);
478 		emit_zext_32(ctx, t1, is32);
479 		emit_insn(ctx, divdu, dst, dst, t1);
480 		emit_zext_32(ctx, dst, is32);
481 		break;
482 
483 	/* dst = dst / imm */
484 	case BPF_ALU | BPF_DIV | BPF_K:
485 	case BPF_ALU64 | BPF_DIV | BPF_K:
486 		move_imm(ctx, t1, imm, is32);
487 		emit_zext_32(ctx, dst, is32);
488 		emit_insn(ctx, divdu, dst, dst, t1);
489 		emit_zext_32(ctx, dst, is32);
490 		break;
491 
492 	/* dst = dst % src */
493 	case BPF_ALU | BPF_MOD | BPF_X:
494 	case BPF_ALU64 | BPF_MOD | BPF_X:
495 		emit_zext_32(ctx, dst, is32);
496 		move_reg(ctx, t1, src);
497 		emit_zext_32(ctx, t1, is32);
498 		emit_insn(ctx, moddu, dst, dst, t1);
499 		emit_zext_32(ctx, dst, is32);
500 		break;
501 
502 	/* dst = dst % imm */
503 	case BPF_ALU | BPF_MOD | BPF_K:
504 	case BPF_ALU64 | BPF_MOD | BPF_K:
505 		move_imm(ctx, t1, imm, is32);
506 		emit_zext_32(ctx, dst, is32);
507 		emit_insn(ctx, moddu, dst, dst, t1);
508 		emit_zext_32(ctx, dst, is32);
509 		break;
510 
511 	/* dst = -dst */
512 	case BPF_ALU | BPF_NEG:
513 	case BPF_ALU64 | BPF_NEG:
514 		move_imm(ctx, t1, imm, is32);
515 		emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
516 		emit_zext_32(ctx, dst, is32);
517 		break;
518 
519 	/* dst = dst & src */
520 	case BPF_ALU | BPF_AND | BPF_X:
521 	case BPF_ALU64 | BPF_AND | BPF_X:
522 		emit_insn(ctx, and, dst, dst, src);
523 		emit_zext_32(ctx, dst, is32);
524 		break;
525 
526 	/* dst = dst & imm */
527 	case BPF_ALU | BPF_AND | BPF_K:
528 	case BPF_ALU64 | BPF_AND | BPF_K:
529 		if (is_unsigned_imm12(imm)) {
530 			emit_insn(ctx, andi, dst, dst, imm);
531 		} else {
532 			move_imm(ctx, t1, imm, is32);
533 			emit_insn(ctx, and, dst, dst, t1);
534 		}
535 		emit_zext_32(ctx, dst, is32);
536 		break;
537 
538 	/* dst = dst | src */
539 	case BPF_ALU | BPF_OR | BPF_X:
540 	case BPF_ALU64 | BPF_OR | BPF_X:
541 		emit_insn(ctx, or, dst, dst, src);
542 		emit_zext_32(ctx, dst, is32);
543 		break;
544 
545 	/* dst = dst | imm */
546 	case BPF_ALU | BPF_OR | BPF_K:
547 	case BPF_ALU64 | BPF_OR | BPF_K:
548 		if (is_unsigned_imm12(imm)) {
549 			emit_insn(ctx, ori, dst, dst, imm);
550 		} else {
551 			move_imm(ctx, t1, imm, is32);
552 			emit_insn(ctx, or, dst, dst, t1);
553 		}
554 		emit_zext_32(ctx, dst, is32);
555 		break;
556 
557 	/* dst = dst ^ src */
558 	case BPF_ALU | BPF_XOR | BPF_X:
559 	case BPF_ALU64 | BPF_XOR | BPF_X:
560 		emit_insn(ctx, xor, dst, dst, src);
561 		emit_zext_32(ctx, dst, is32);
562 		break;
563 
564 	/* dst = dst ^ imm */
565 	case BPF_ALU | BPF_XOR | BPF_K:
566 	case BPF_ALU64 | BPF_XOR | BPF_K:
567 		if (is_unsigned_imm12(imm)) {
568 			emit_insn(ctx, xori, dst, dst, imm);
569 		} else {
570 			move_imm(ctx, t1, imm, is32);
571 			emit_insn(ctx, xor, dst, dst, t1);
572 		}
573 		emit_zext_32(ctx, dst, is32);
574 		break;
575 
576 	/* dst = dst << src (logical) */
577 	case BPF_ALU | BPF_LSH | BPF_X:
578 		emit_insn(ctx, sllw, dst, dst, src);
579 		emit_zext_32(ctx, dst, is32);
580 		break;
581 
582 	case BPF_ALU64 | BPF_LSH | BPF_X:
583 		emit_insn(ctx, slld, dst, dst, src);
584 		break;
585 
586 	/* dst = dst << imm (logical) */
587 	case BPF_ALU | BPF_LSH | BPF_K:
588 		emit_insn(ctx, slliw, dst, dst, imm);
589 		emit_zext_32(ctx, dst, is32);
590 		break;
591 
592 	case BPF_ALU64 | BPF_LSH | BPF_K:
593 		emit_insn(ctx, sllid, dst, dst, imm);
594 		break;
595 
596 	/* dst = dst >> src (logical) */
597 	case BPF_ALU | BPF_RSH | BPF_X:
598 		emit_insn(ctx, srlw, dst, dst, src);
599 		emit_zext_32(ctx, dst, is32);
600 		break;
601 
602 	case BPF_ALU64 | BPF_RSH | BPF_X:
603 		emit_insn(ctx, srld, dst, dst, src);
604 		break;
605 
606 	/* dst = dst >> imm (logical) */
607 	case BPF_ALU | BPF_RSH | BPF_K:
608 		emit_insn(ctx, srliw, dst, dst, imm);
609 		emit_zext_32(ctx, dst, is32);
610 		break;
611 
612 	case BPF_ALU64 | BPF_RSH | BPF_K:
613 		emit_insn(ctx, srlid, dst, dst, imm);
614 		break;
615 
616 	/* dst = dst >> src (arithmetic) */
617 	case BPF_ALU | BPF_ARSH | BPF_X:
618 		emit_insn(ctx, sraw, dst, dst, src);
619 		emit_zext_32(ctx, dst, is32);
620 		break;
621 
622 	case BPF_ALU64 | BPF_ARSH | BPF_X:
623 		emit_insn(ctx, srad, dst, dst, src);
624 		break;
625 
626 	/* dst = dst >> imm (arithmetic) */
627 	case BPF_ALU | BPF_ARSH | BPF_K:
628 		emit_insn(ctx, sraiw, dst, dst, imm);
629 		emit_zext_32(ctx, dst, is32);
630 		break;
631 
632 	case BPF_ALU64 | BPF_ARSH | BPF_K:
633 		emit_insn(ctx, sraid, dst, dst, imm);
634 		break;
635 
636 	/* dst = BSWAP##imm(dst) */
637 	case BPF_ALU | BPF_END | BPF_FROM_LE:
638 		switch (imm) {
639 		case 16:
640 			/* zero-extend 16 bits into 64 bits */
641 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
642 			break;
643 		case 32:
644 			/* zero-extend 32 bits into 64 bits */
645 			emit_zext_32(ctx, dst, is32);
646 			break;
647 		case 64:
648 			/* do nothing */
649 			break;
650 		}
651 		break;
652 
653 	case BPF_ALU | BPF_END | BPF_FROM_BE:
654 		switch (imm) {
655 		case 16:
656 			emit_insn(ctx, revb2h, dst, dst);
657 			/* zero-extend 16 bits into 64 bits */
658 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
659 			break;
660 		case 32:
661 			emit_insn(ctx, revb2w, dst, dst);
662 			/* zero-extend 32 bits into 64 bits */
663 			emit_zext_32(ctx, dst, is32);
664 			break;
665 		case 64:
666 			emit_insn(ctx, revbd, dst, dst);
667 			break;
668 		}
669 		break;
670 
671 	/* PC += off if dst cond src */
672 	case BPF_JMP | BPF_JEQ | BPF_X:
673 	case BPF_JMP | BPF_JNE | BPF_X:
674 	case BPF_JMP | BPF_JGT | BPF_X:
675 	case BPF_JMP | BPF_JGE | BPF_X:
676 	case BPF_JMP | BPF_JLT | BPF_X:
677 	case BPF_JMP | BPF_JLE | BPF_X:
678 	case BPF_JMP | BPF_JSGT | BPF_X:
679 	case BPF_JMP | BPF_JSGE | BPF_X:
680 	case BPF_JMP | BPF_JSLT | BPF_X:
681 	case BPF_JMP | BPF_JSLE | BPF_X:
682 	case BPF_JMP32 | BPF_JEQ | BPF_X:
683 	case BPF_JMP32 | BPF_JNE | BPF_X:
684 	case BPF_JMP32 | BPF_JGT | BPF_X:
685 	case BPF_JMP32 | BPF_JGE | BPF_X:
686 	case BPF_JMP32 | BPF_JLT | BPF_X:
687 	case BPF_JMP32 | BPF_JLE | BPF_X:
688 	case BPF_JMP32 | BPF_JSGT | BPF_X:
689 	case BPF_JMP32 | BPF_JSGE | BPF_X:
690 	case BPF_JMP32 | BPF_JSLT | BPF_X:
691 	case BPF_JMP32 | BPF_JSLE | BPF_X:
692 		jmp_offset = bpf2la_offset(i, off, ctx);
693 		move_reg(ctx, t1, dst);
694 		move_reg(ctx, t2, src);
695 		if (is_signed_bpf_cond(BPF_OP(code))) {
696 			emit_sext_32(ctx, t1, is32);
697 			emit_sext_32(ctx, t2, is32);
698 		} else {
699 			emit_zext_32(ctx, t1, is32);
700 			emit_zext_32(ctx, t2, is32);
701 		}
702 		if (emit_cond_jmp(ctx, cond, t1, t2, jmp_offset) < 0)
703 			goto toofar;
704 		break;
705 
706 	/* PC += off if dst cond imm */
707 	case BPF_JMP | BPF_JEQ | BPF_K:
708 	case BPF_JMP | BPF_JNE | BPF_K:
709 	case BPF_JMP | BPF_JGT | BPF_K:
710 	case BPF_JMP | BPF_JGE | BPF_K:
711 	case BPF_JMP | BPF_JLT | BPF_K:
712 	case BPF_JMP | BPF_JLE | BPF_K:
713 	case BPF_JMP | BPF_JSGT | BPF_K:
714 	case BPF_JMP | BPF_JSGE | BPF_K:
715 	case BPF_JMP | BPF_JSLT | BPF_K:
716 	case BPF_JMP | BPF_JSLE | BPF_K:
717 	case BPF_JMP32 | BPF_JEQ | BPF_K:
718 	case BPF_JMP32 | BPF_JNE | BPF_K:
719 	case BPF_JMP32 | BPF_JGT | BPF_K:
720 	case BPF_JMP32 | BPF_JGE | BPF_K:
721 	case BPF_JMP32 | BPF_JLT | BPF_K:
722 	case BPF_JMP32 | BPF_JLE | BPF_K:
723 	case BPF_JMP32 | BPF_JSGT | BPF_K:
724 	case BPF_JMP32 | BPF_JSGE | BPF_K:
725 	case BPF_JMP32 | BPF_JSLT | BPF_K:
726 	case BPF_JMP32 | BPF_JSLE | BPF_K:
727 		u8 t7 = -1;
728 		jmp_offset = bpf2la_offset(i, off, ctx);
729 		if (imm) {
730 			move_imm(ctx, t1, imm, false);
731 			t7 = t1;
732 		} else {
733 			/* If imm is 0, simply use zero register. */
734 			t7 = LOONGARCH_GPR_ZERO;
735 		}
736 		move_reg(ctx, t2, dst);
737 		if (is_signed_bpf_cond(BPF_OP(code))) {
738 			emit_sext_32(ctx, t7, is32);
739 			emit_sext_32(ctx, t2, is32);
740 		} else {
741 			emit_zext_32(ctx, t7, is32);
742 			emit_zext_32(ctx, t2, is32);
743 		}
744 		if (emit_cond_jmp(ctx, cond, t2, t7, jmp_offset) < 0)
745 			goto toofar;
746 		break;
747 
748 	/* PC += off if dst & src */
749 	case BPF_JMP | BPF_JSET | BPF_X:
750 	case BPF_JMP32 | BPF_JSET | BPF_X:
751 		jmp_offset = bpf2la_offset(i, off, ctx);
752 		emit_insn(ctx, and, t1, dst, src);
753 		emit_zext_32(ctx, t1, is32);
754 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
755 			goto toofar;
756 		break;
757 
758 	/* PC += off if dst & imm */
759 	case BPF_JMP | BPF_JSET | BPF_K:
760 	case BPF_JMP32 | BPF_JSET | BPF_K:
761 		jmp_offset = bpf2la_offset(i, off, ctx);
762 		move_imm(ctx, t1, imm, is32);
763 		emit_insn(ctx, and, t1, dst, t1);
764 		emit_zext_32(ctx, t1, is32);
765 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
766 			goto toofar;
767 		break;
768 
769 	/* PC += off */
770 	case BPF_JMP | BPF_JA:
771 		jmp_offset = bpf2la_offset(i, off, ctx);
772 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
773 			goto toofar;
774 		break;
775 
776 	/* function call */
777 	case BPF_JMP | BPF_CALL:
778 		int ret;
779 		u64 func_addr;
780 		bool func_addr_fixed;
781 
782 		mark_call(ctx);
783 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
784 					    &func_addr, &func_addr_fixed);
785 		if (ret < 0)
786 			return ret;
787 
788 		move_imm(ctx, t1, func_addr, is32);
789 		emit_insn(ctx, jirl, t1, LOONGARCH_GPR_RA, 0);
790 		move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
791 		break;
792 
793 	/* tail call */
794 	case BPF_JMP | BPF_TAIL_CALL:
795 		mark_tail_call(ctx);
796 		if (emit_bpf_tail_call(ctx) < 0)
797 			return -EINVAL;
798 		break;
799 
800 	/* function return */
801 	case BPF_JMP | BPF_EXIT:
802 		emit_sext_32(ctx, regmap[BPF_REG_0], true);
803 
804 		if (i == ctx->prog->len - 1)
805 			break;
806 
807 		jmp_offset = epilogue_offset(ctx);
808 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
809 			goto toofar;
810 		break;
811 
812 	/* dst = imm64 */
813 	case BPF_LD | BPF_IMM | BPF_DW:
814 		u64 imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
815 
816 		move_imm(ctx, dst, imm64, is32);
817 		return 1;
818 
819 	/* dst = *(size *)(src + off) */
820 	case BPF_LDX | BPF_MEM | BPF_B:
821 	case BPF_LDX | BPF_MEM | BPF_H:
822 	case BPF_LDX | BPF_MEM | BPF_W:
823 	case BPF_LDX | BPF_MEM | BPF_DW:
824 		switch (BPF_SIZE(code)) {
825 		case BPF_B:
826 			if (is_signed_imm12(off)) {
827 				emit_insn(ctx, ldbu, dst, src, off);
828 			} else {
829 				move_imm(ctx, t1, off, is32);
830 				emit_insn(ctx, ldxbu, dst, src, t1);
831 			}
832 			break;
833 		case BPF_H:
834 			if (is_signed_imm12(off)) {
835 				emit_insn(ctx, ldhu, dst, src, off);
836 			} else {
837 				move_imm(ctx, t1, off, is32);
838 				emit_insn(ctx, ldxhu, dst, src, t1);
839 			}
840 			break;
841 		case BPF_W:
842 			if (is_signed_imm12(off)) {
843 				emit_insn(ctx, ldwu, dst, src, off);
844 			} else if (is_signed_imm14(off)) {
845 				emit_insn(ctx, ldptrw, dst, src, off);
846 			} else {
847 				move_imm(ctx, t1, off, is32);
848 				emit_insn(ctx, ldxwu, dst, src, t1);
849 			}
850 			break;
851 		case BPF_DW:
852 			if (is_signed_imm12(off)) {
853 				emit_insn(ctx, ldd, dst, src, off);
854 			} else if (is_signed_imm14(off)) {
855 				emit_insn(ctx, ldptrd, dst, src, off);
856 			} else {
857 				move_imm(ctx, t1, off, is32);
858 				emit_insn(ctx, ldxd, dst, src, t1);
859 			}
860 			break;
861 		}
862 		break;
863 
864 	/* *(size *)(dst + off) = imm */
865 	case BPF_ST | BPF_MEM | BPF_B:
866 	case BPF_ST | BPF_MEM | BPF_H:
867 	case BPF_ST | BPF_MEM | BPF_W:
868 	case BPF_ST | BPF_MEM | BPF_DW:
869 		switch (BPF_SIZE(code)) {
870 		case BPF_B:
871 			move_imm(ctx, t1, imm, is32);
872 			if (is_signed_imm12(off)) {
873 				emit_insn(ctx, stb, t1, dst, off);
874 			} else {
875 				move_imm(ctx, t2, off, is32);
876 				emit_insn(ctx, stxb, t1, dst, t2);
877 			}
878 			break;
879 		case BPF_H:
880 			move_imm(ctx, t1, imm, is32);
881 			if (is_signed_imm12(off)) {
882 				emit_insn(ctx, sth, t1, dst, off);
883 			} else {
884 				move_imm(ctx, t2, off, is32);
885 				emit_insn(ctx, stxh, t1, dst, t2);
886 			}
887 			break;
888 		case BPF_W:
889 			move_imm(ctx, t1, imm, is32);
890 			if (is_signed_imm12(off)) {
891 				emit_insn(ctx, stw, t1, dst, off);
892 			} else if (is_signed_imm14(off)) {
893 				emit_insn(ctx, stptrw, t1, dst, off);
894 			} else {
895 				move_imm(ctx, t2, off, is32);
896 				emit_insn(ctx, stxw, t1, dst, t2);
897 			}
898 			break;
899 		case BPF_DW:
900 			move_imm(ctx, t1, imm, is32);
901 			if (is_signed_imm12(off)) {
902 				emit_insn(ctx, std, t1, dst, off);
903 			} else if (is_signed_imm14(off)) {
904 				emit_insn(ctx, stptrd, t1, dst, off);
905 			} else {
906 				move_imm(ctx, t2, off, is32);
907 				emit_insn(ctx, stxd, t1, dst, t2);
908 			}
909 			break;
910 		}
911 		break;
912 
913 	/* *(size *)(dst + off) = src */
914 	case BPF_STX | BPF_MEM | BPF_B:
915 	case BPF_STX | BPF_MEM | BPF_H:
916 	case BPF_STX | BPF_MEM | BPF_W:
917 	case BPF_STX | BPF_MEM | BPF_DW:
918 		switch (BPF_SIZE(code)) {
919 		case BPF_B:
920 			if (is_signed_imm12(off)) {
921 				emit_insn(ctx, stb, src, dst, off);
922 			} else {
923 				move_imm(ctx, t1, off, is32);
924 				emit_insn(ctx, stxb, src, dst, t1);
925 			}
926 			break;
927 		case BPF_H:
928 			if (is_signed_imm12(off)) {
929 				emit_insn(ctx, sth, src, dst, off);
930 			} else {
931 				move_imm(ctx, t1, off, is32);
932 				emit_insn(ctx, stxh, src, dst, t1);
933 			}
934 			break;
935 		case BPF_W:
936 			if (is_signed_imm12(off)) {
937 				emit_insn(ctx, stw, src, dst, off);
938 			} else if (is_signed_imm14(off)) {
939 				emit_insn(ctx, stptrw, src, dst, off);
940 			} else {
941 				move_imm(ctx, t1, off, is32);
942 				emit_insn(ctx, stxw, src, dst, t1);
943 			}
944 			break;
945 		case BPF_DW:
946 			if (is_signed_imm12(off)) {
947 				emit_insn(ctx, std, src, dst, off);
948 			} else if (is_signed_imm14(off)) {
949 				emit_insn(ctx, stptrd, src, dst, off);
950 			} else {
951 				move_imm(ctx, t1, off, is32);
952 				emit_insn(ctx, stxd, src, dst, t1);
953 			}
954 			break;
955 		}
956 		break;
957 
958 	case BPF_STX | BPF_ATOMIC | BPF_W:
959 	case BPF_STX | BPF_ATOMIC | BPF_DW:
960 		emit_atomic(insn, ctx);
961 		break;
962 
963 	default:
964 		pr_err("bpf_jit: unknown opcode %02x\n", code);
965 		return -EINVAL;
966 	}
967 
968 	return 0;
969 
970 toofar:
971 	pr_info_once("bpf_jit: opcode %02x, jump too far\n", code);
972 	return -E2BIG;
973 }
974 
975 static int build_body(struct jit_ctx *ctx, bool extra_pass)
976 {
977 	int i;
978 	const struct bpf_prog *prog = ctx->prog;
979 
980 	for (i = 0; i < prog->len; i++) {
981 		const struct bpf_insn *insn = &prog->insnsi[i];
982 		int ret;
983 
984 		if (ctx->image == NULL)
985 			ctx->offset[i] = ctx->idx;
986 
987 		ret = build_insn(insn, ctx, extra_pass);
988 		if (ret > 0) {
989 			i++;
990 			if (ctx->image == NULL)
991 				ctx->offset[i] = ctx->idx;
992 			continue;
993 		}
994 		if (ret)
995 			return ret;
996 	}
997 
998 	if (ctx->image == NULL)
999 		ctx->offset[i] = ctx->idx;
1000 
1001 	return 0;
1002 }
1003 
1004 /* Fill space with break instructions */
1005 static void jit_fill_hole(void *area, unsigned int size)
1006 {
1007 	u32 *ptr;
1008 
1009 	/* We are guaranteed to have aligned memory */
1010 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
1011 		*ptr++ = INSN_BREAK;
1012 }
1013 
1014 static int validate_code(struct jit_ctx *ctx)
1015 {
1016 	int i;
1017 	union loongarch_instruction insn;
1018 
1019 	for (i = 0; i < ctx->idx; i++) {
1020 		insn = ctx->image[i];
1021 		/* Check INSN_BREAK */
1022 		if (insn.word == INSN_BREAK)
1023 			return -1;
1024 	}
1025 
1026 	return 0;
1027 }
1028 
1029 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1030 {
1031 	bool tmp_blinded = false, extra_pass = false;
1032 	u8 *image_ptr;
1033 	int image_size;
1034 	struct jit_ctx ctx;
1035 	struct jit_data *jit_data;
1036 	struct bpf_binary_header *header;
1037 	struct bpf_prog *tmp, *orig_prog = prog;
1038 
1039 	/*
1040 	 * If BPF JIT was not enabled then we must fall back to
1041 	 * the interpreter.
1042 	 */
1043 	if (!prog->jit_requested)
1044 		return orig_prog;
1045 
1046 	tmp = bpf_jit_blind_constants(prog);
1047 	/*
1048 	 * If blinding was requested and we failed during blinding,
1049 	 * we must fall back to the interpreter. Otherwise, we save
1050 	 * the new JITed code.
1051 	 */
1052 	if (IS_ERR(tmp))
1053 		return orig_prog;
1054 
1055 	if (tmp != prog) {
1056 		tmp_blinded = true;
1057 		prog = tmp;
1058 	}
1059 
1060 	jit_data = prog->aux->jit_data;
1061 	if (!jit_data) {
1062 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1063 		if (!jit_data) {
1064 			prog = orig_prog;
1065 			goto out;
1066 		}
1067 		prog->aux->jit_data = jit_data;
1068 	}
1069 	if (jit_data->ctx.offset) {
1070 		ctx = jit_data->ctx;
1071 		image_ptr = jit_data->image;
1072 		header = jit_data->header;
1073 		extra_pass = true;
1074 		image_size = sizeof(u32) * ctx.idx;
1075 		goto skip_init_ctx;
1076 	}
1077 
1078 	memset(&ctx, 0, sizeof(ctx));
1079 	ctx.prog = prog;
1080 
1081 	ctx.offset = kvcalloc(prog->len + 1, sizeof(u32), GFP_KERNEL);
1082 	if (ctx.offset == NULL) {
1083 		prog = orig_prog;
1084 		goto out_offset;
1085 	}
1086 
1087 	/* 1. Initial fake pass to compute ctx->idx and set ctx->flags */
1088 	build_prologue(&ctx);
1089 	if (build_body(&ctx, extra_pass)) {
1090 		prog = orig_prog;
1091 		goto out_offset;
1092 	}
1093 	ctx.epilogue_offset = ctx.idx;
1094 	build_epilogue(&ctx);
1095 
1096 	/* Now we know the actual image size.
1097 	 * As each LoongArch instruction is of length 32bit,
1098 	 * we are translating number of JITed intructions into
1099 	 * the size required to store these JITed code.
1100 	 */
1101 	image_size = sizeof(u32) * ctx.idx;
1102 	/* Now we know the size of the structure to make */
1103 	header = bpf_jit_binary_alloc(image_size, &image_ptr,
1104 				      sizeof(u32), jit_fill_hole);
1105 	if (header == NULL) {
1106 		prog = orig_prog;
1107 		goto out_offset;
1108 	}
1109 
1110 	/* 2. Now, the actual pass to generate final JIT code */
1111 	ctx.image = (union loongarch_instruction *)image_ptr;
1112 
1113 skip_init_ctx:
1114 	ctx.idx = 0;
1115 
1116 	build_prologue(&ctx);
1117 	if (build_body(&ctx, extra_pass)) {
1118 		bpf_jit_binary_free(header);
1119 		prog = orig_prog;
1120 		goto out_offset;
1121 	}
1122 	build_epilogue(&ctx);
1123 
1124 	/* 3. Extra pass to validate JITed code */
1125 	if (validate_code(&ctx)) {
1126 		bpf_jit_binary_free(header);
1127 		prog = orig_prog;
1128 		goto out_offset;
1129 	}
1130 
1131 	/* And we're done */
1132 	if (bpf_jit_enable > 1)
1133 		bpf_jit_dump(prog->len, image_size, 2, ctx.image);
1134 
1135 	/* Update the icache */
1136 	flush_icache_range((unsigned long)header, (unsigned long)(ctx.image + ctx.idx));
1137 
1138 	if (!prog->is_func || extra_pass) {
1139 		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1140 			pr_err_once("multi-func JIT bug %d != %d\n",
1141 				    ctx.idx, jit_data->ctx.idx);
1142 			bpf_jit_binary_free(header);
1143 			prog->bpf_func = NULL;
1144 			prog->jited = 0;
1145 			prog->jited_len = 0;
1146 			goto out_offset;
1147 		}
1148 		bpf_jit_binary_lock_ro(header);
1149 	} else {
1150 		jit_data->ctx = ctx;
1151 		jit_data->image = image_ptr;
1152 		jit_data->header = header;
1153 	}
1154 	prog->jited = 1;
1155 	prog->jited_len = image_size;
1156 	prog->bpf_func = (void *)ctx.image;
1157 
1158 	if (!prog->is_func || extra_pass) {
1159 		int i;
1160 
1161 		/* offset[prog->len] is the size of program */
1162 		for (i = 0; i <= prog->len; i++)
1163 			ctx.offset[i] *= LOONGARCH_INSN_SIZE;
1164 		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1165 
1166 out_offset:
1167 		kvfree(ctx.offset);
1168 		kfree(jit_data);
1169 		prog->aux->jit_data = NULL;
1170 	}
1171 
1172 out:
1173 	if (tmp_blinded)
1174 		bpf_jit_prog_release_other(prog, prog == orig_prog ? tmp : orig_prog);
1175 
1176 	out_offset = -1;
1177 
1178 	return prog;
1179 }
1180