xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp64.c (revision e6b9d8eddb1772d99a676a906d42865293934edd)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include "bpf_jit.h"
15 
16 #define RV_REG_TCC RV_REG_A6
17 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
18 
19 static const int regmap[] = {
20 	[BPF_REG_0] =	RV_REG_A5,
21 	[BPF_REG_1] =	RV_REG_A0,
22 	[BPF_REG_2] =	RV_REG_A1,
23 	[BPF_REG_3] =	RV_REG_A2,
24 	[BPF_REG_4] =	RV_REG_A3,
25 	[BPF_REG_5] =	RV_REG_A4,
26 	[BPF_REG_6] =	RV_REG_S1,
27 	[BPF_REG_7] =	RV_REG_S2,
28 	[BPF_REG_8] =	RV_REG_S3,
29 	[BPF_REG_9] =	RV_REG_S4,
30 	[BPF_REG_FP] =	RV_REG_S5,
31 	[BPF_REG_AX] =	RV_REG_T0,
32 };
33 
34 static const int pt_regmap[] = {
35 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
36 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
37 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
38 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
39 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
40 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
41 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
42 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
43 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
44 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
45 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
46 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
47 };
48 
49 enum {
50 	RV_CTX_F_SEEN_TAIL_CALL =	0,
51 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
52 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
53 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
54 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
55 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
56 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
57 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
58 };
59 
60 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
61 {
62 	u8 reg = regmap[bpf_reg];
63 
64 	switch (reg) {
65 	case RV_CTX_F_SEEN_S1:
66 	case RV_CTX_F_SEEN_S2:
67 	case RV_CTX_F_SEEN_S3:
68 	case RV_CTX_F_SEEN_S4:
69 	case RV_CTX_F_SEEN_S5:
70 	case RV_CTX_F_SEEN_S6:
71 		__set_bit(reg, &ctx->flags);
72 	}
73 	return reg;
74 };
75 
76 static bool seen_reg(int reg, struct rv_jit_context *ctx)
77 {
78 	switch (reg) {
79 	case RV_CTX_F_SEEN_CALL:
80 	case RV_CTX_F_SEEN_S1:
81 	case RV_CTX_F_SEEN_S2:
82 	case RV_CTX_F_SEEN_S3:
83 	case RV_CTX_F_SEEN_S4:
84 	case RV_CTX_F_SEEN_S5:
85 	case RV_CTX_F_SEEN_S6:
86 		return test_bit(reg, &ctx->flags);
87 	}
88 	return false;
89 }
90 
91 static void mark_fp(struct rv_jit_context *ctx)
92 {
93 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
94 }
95 
96 static void mark_call(struct rv_jit_context *ctx)
97 {
98 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
99 }
100 
101 static bool seen_call(struct rv_jit_context *ctx)
102 {
103 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
104 }
105 
106 static void mark_tail_call(struct rv_jit_context *ctx)
107 {
108 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
109 }
110 
111 static bool seen_tail_call(struct rv_jit_context *ctx)
112 {
113 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
114 }
115 
116 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
117 {
118 	mark_tail_call(ctx);
119 
120 	if (seen_call(ctx)) {
121 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
122 		return RV_REG_S6;
123 	}
124 	return RV_REG_A6;
125 }
126 
127 static bool is_32b_int(s64 val)
128 {
129 	return -(1L << 31) <= val && val < (1L << 31);
130 }
131 
132 static bool in_auipc_jalr_range(s64 val)
133 {
134 	/*
135 	 * auipc+jalr can reach any signed PC-relative offset in the range
136 	 * [-2^31 - 2^11, 2^31 - 2^11).
137 	 */
138 	return (-(1L << 31) - (1L << 11)) <= val &&
139 		val < ((1L << 31) - (1L << 11));
140 }
141 
142 /* Emit fixed-length instructions for address */
143 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
144 {
145 	u64 ip = (u64)(ctx->insns + ctx->ninsns);
146 	s64 off = addr - ip;
147 	s64 upper = (off + (1 << 11)) >> 12;
148 	s64 lower = off & 0xfff;
149 
150 	if (extra_pass && !in_auipc_jalr_range(off)) {
151 		pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
152 		return -ERANGE;
153 	}
154 
155 	emit(rv_auipc(rd, upper), ctx);
156 	emit(rv_addi(rd, rd, lower), ctx);
157 	return 0;
158 }
159 
160 /* Emit variable-length instructions for 32-bit and 64-bit imm */
161 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
162 {
163 	/* Note that the immediate from the add is sign-extended,
164 	 * which means that we need to compensate this by adding 2^12,
165 	 * when the 12th bit is set. A simpler way of doing this, and
166 	 * getting rid of the check, is to just add 2**11 before the
167 	 * shift. The "Loading a 32-Bit constant" example from the
168 	 * "Computer Organization and Design, RISC-V edition" book by
169 	 * Patterson/Hennessy highlights this fact.
170 	 *
171 	 * This also means that we need to process LSB to MSB.
172 	 */
173 	s64 upper = (val + (1 << 11)) >> 12;
174 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
175 	 * and addi are signed and RVC checks will perform signed comparisons.
176 	 */
177 	s64 lower = ((val & 0xfff) << 52) >> 52;
178 	int shift;
179 
180 	if (is_32b_int(val)) {
181 		if (upper)
182 			emit_lui(rd, upper, ctx);
183 
184 		if (!upper) {
185 			emit_li(rd, lower, ctx);
186 			return;
187 		}
188 
189 		emit_addiw(rd, rd, lower, ctx);
190 		return;
191 	}
192 
193 	shift = __ffs(upper);
194 	upper >>= shift;
195 	shift += 12;
196 
197 	emit_imm(rd, upper, ctx);
198 
199 	emit_slli(rd, rd, shift, ctx);
200 	if (lower)
201 		emit_addi(rd, rd, lower, ctx);
202 }
203 
204 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
205 {
206 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
207 
208 	if (seen_reg(RV_REG_RA, ctx)) {
209 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
210 		store_offset -= 8;
211 	}
212 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
213 	store_offset -= 8;
214 	if (seen_reg(RV_REG_S1, ctx)) {
215 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
216 		store_offset -= 8;
217 	}
218 	if (seen_reg(RV_REG_S2, ctx)) {
219 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
220 		store_offset -= 8;
221 	}
222 	if (seen_reg(RV_REG_S3, ctx)) {
223 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
224 		store_offset -= 8;
225 	}
226 	if (seen_reg(RV_REG_S4, ctx)) {
227 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
228 		store_offset -= 8;
229 	}
230 	if (seen_reg(RV_REG_S5, ctx)) {
231 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
232 		store_offset -= 8;
233 	}
234 	if (seen_reg(RV_REG_S6, ctx)) {
235 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
236 		store_offset -= 8;
237 	}
238 
239 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
240 	/* Set return value. */
241 	if (!is_tail_call)
242 		emit_mv(RV_REG_A0, RV_REG_A5, ctx);
243 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
244 		  is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
245 		  ctx);
246 }
247 
248 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
249 		     struct rv_jit_context *ctx)
250 {
251 	switch (cond) {
252 	case BPF_JEQ:
253 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
254 		return;
255 	case BPF_JGT:
256 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
257 		return;
258 	case BPF_JLT:
259 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
260 		return;
261 	case BPF_JGE:
262 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
263 		return;
264 	case BPF_JLE:
265 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
266 		return;
267 	case BPF_JNE:
268 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
269 		return;
270 	case BPF_JSGT:
271 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
272 		return;
273 	case BPF_JSLT:
274 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
275 		return;
276 	case BPF_JSGE:
277 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
278 		return;
279 	case BPF_JSLE:
280 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
281 	}
282 }
283 
284 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
285 			struct rv_jit_context *ctx)
286 {
287 	s64 upper, lower;
288 
289 	if (is_13b_int(rvoff)) {
290 		emit_bcc(cond, rd, rs, rvoff, ctx);
291 		return;
292 	}
293 
294 	/* Adjust for jal */
295 	rvoff -= 4;
296 
297 	/* Transform, e.g.:
298 	 *   bne rd,rs,foo
299 	 * to
300 	 *   beq rd,rs,<.L1>
301 	 *   (auipc foo)
302 	 *   jal(r) foo
303 	 * .L1
304 	 */
305 	cond = invert_bpf_cond(cond);
306 	if (is_21b_int(rvoff)) {
307 		emit_bcc(cond, rd, rs, 8, ctx);
308 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
309 		return;
310 	}
311 
312 	/* 32b No need for an additional rvoff adjustment, since we
313 	 * get that from the auipc at PC', where PC = PC' + 4.
314 	 */
315 	upper = (rvoff + (1 << 11)) >> 12;
316 	lower = rvoff & 0xfff;
317 
318 	emit_bcc(cond, rd, rs, 12, ctx);
319 	emit(rv_auipc(RV_REG_T1, upper), ctx);
320 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
321 }
322 
323 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
324 {
325 	emit_slli(reg, reg, 32, ctx);
326 	emit_srli(reg, reg, 32, ctx);
327 }
328 
329 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
330 {
331 	int tc_ninsn, off, start_insn = ctx->ninsns;
332 	u8 tcc = rv_tail_call_reg(ctx);
333 
334 	/* a0: &ctx
335 	 * a1: &array
336 	 * a2: index
337 	 *
338 	 * if (index >= array->map.max_entries)
339 	 *	goto out;
340 	 */
341 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
342 		   ctx->offset[0];
343 	emit_zext_32(RV_REG_A2, ctx);
344 
345 	off = offsetof(struct bpf_array, map.max_entries);
346 	if (is_12b_check(off, insn))
347 		return -1;
348 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
349 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
350 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
351 
352 	/* if (--TCC < 0)
353 	 *     goto out;
354 	 */
355 	emit_addi(RV_REG_TCC, tcc, -1, ctx);
356 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
357 	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
358 
359 	/* prog = array->ptrs[index];
360 	 * if (!prog)
361 	 *     goto out;
362 	 */
363 	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
364 	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
365 	off = offsetof(struct bpf_array, ptrs);
366 	if (is_12b_check(off, insn))
367 		return -1;
368 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
369 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
370 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
371 
372 	/* goto *(prog->bpf_func + 4); */
373 	off = offsetof(struct bpf_prog, bpf_func);
374 	if (is_12b_check(off, insn))
375 		return -1;
376 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
377 	__build_epilogue(true, ctx);
378 	return 0;
379 }
380 
381 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
382 		      struct rv_jit_context *ctx)
383 {
384 	u8 code = insn->code;
385 
386 	switch (code) {
387 	case BPF_JMP | BPF_JA:
388 	case BPF_JMP | BPF_CALL:
389 	case BPF_JMP | BPF_EXIT:
390 	case BPF_JMP | BPF_TAIL_CALL:
391 		break;
392 	default:
393 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
394 	}
395 
396 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
397 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
398 	    code & BPF_LDX || code & BPF_STX)
399 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
400 }
401 
402 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
403 {
404 	emit_mv(RV_REG_T2, *rd, ctx);
405 	emit_zext_32(RV_REG_T2, ctx);
406 	emit_mv(RV_REG_T1, *rs, ctx);
407 	emit_zext_32(RV_REG_T1, ctx);
408 	*rd = RV_REG_T2;
409 	*rs = RV_REG_T1;
410 }
411 
412 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
413 {
414 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
415 	emit_addiw(RV_REG_T1, *rs, 0, ctx);
416 	*rd = RV_REG_T2;
417 	*rs = RV_REG_T1;
418 }
419 
420 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
421 {
422 	emit_mv(RV_REG_T2, *rd, ctx);
423 	emit_zext_32(RV_REG_T2, ctx);
424 	emit_zext_32(RV_REG_T1, ctx);
425 	*rd = RV_REG_T2;
426 }
427 
428 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
429 {
430 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
431 	*rd = RV_REG_T2;
432 }
433 
434 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
435 			      struct rv_jit_context *ctx)
436 {
437 	s64 upper, lower;
438 
439 	if (rvoff && fixed_addr && is_21b_int(rvoff)) {
440 		emit(rv_jal(rd, rvoff >> 1), ctx);
441 		return 0;
442 	} else if (in_auipc_jalr_range(rvoff)) {
443 		upper = (rvoff + (1 << 11)) >> 12;
444 		lower = rvoff & 0xfff;
445 		emit(rv_auipc(RV_REG_T1, upper), ctx);
446 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
447 		return 0;
448 	}
449 
450 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
451 	return -ERANGE;
452 }
453 
454 static bool is_signed_bpf_cond(u8 cond)
455 {
456 	return cond == BPF_JSGT || cond == BPF_JSLT ||
457 		cond == BPF_JSGE || cond == BPF_JSLE;
458 }
459 
460 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
461 {
462 	s64 off = 0;
463 	u64 ip;
464 
465 	if (addr && ctx->insns) {
466 		ip = (u64)(long)(ctx->insns + ctx->ninsns);
467 		off = addr - ip;
468 	}
469 
470 	return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
471 }
472 
473 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
474 			struct rv_jit_context *ctx)
475 {
476 	u8 r0;
477 	int jmp_offset;
478 
479 	if (off) {
480 		if (is_12b_int(off)) {
481 			emit_addi(RV_REG_T1, rd, off, ctx);
482 		} else {
483 			emit_imm(RV_REG_T1, off, ctx);
484 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
485 		}
486 		rd = RV_REG_T1;
487 	}
488 
489 	switch (imm) {
490 	/* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
491 	case BPF_ADD:
492 		emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
493 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
494 		break;
495 	case BPF_AND:
496 		emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
497 		     rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
498 		break;
499 	case BPF_OR:
500 		emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
501 		     rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
502 		break;
503 	case BPF_XOR:
504 		emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
505 		     rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
506 		break;
507 	/* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
508 	case BPF_ADD | BPF_FETCH:
509 		emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
510 		     rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
511 		if (!is64)
512 			emit_zext_32(rs, ctx);
513 		break;
514 	case BPF_AND | BPF_FETCH:
515 		emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
516 		     rv_amoand_w(rs, rs, rd, 0, 0), ctx);
517 		if (!is64)
518 			emit_zext_32(rs, ctx);
519 		break;
520 	case BPF_OR | BPF_FETCH:
521 		emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
522 		     rv_amoor_w(rs, rs, rd, 0, 0), ctx);
523 		if (!is64)
524 			emit_zext_32(rs, ctx);
525 		break;
526 	case BPF_XOR | BPF_FETCH:
527 		emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
528 		     rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
529 		if (!is64)
530 			emit_zext_32(rs, ctx);
531 		break;
532 	/* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
533 	case BPF_XCHG:
534 		emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
535 		     rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
536 		if (!is64)
537 			emit_zext_32(rs, ctx);
538 		break;
539 	/* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
540 	case BPF_CMPXCHG:
541 		r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
542 		emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
543 		     rv_addiw(RV_REG_T2, r0, 0), ctx);
544 		emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
545 		     rv_lr_w(r0, 0, rd, 0, 0), ctx);
546 		jmp_offset = ninsns_rvoff(8);
547 		emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
548 		emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
549 		     rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
550 		jmp_offset = ninsns_rvoff(-6);
551 		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
552 		emit(rv_fence(0x3, 0x3), ctx);
553 		break;
554 	}
555 }
556 
557 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
558 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
559 
560 bool ex_handler_bpf(const struct exception_table_entry *ex,
561 		    struct pt_regs *regs)
562 {
563 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
564 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
565 
566 	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
567 	regs->epc = (unsigned long)&ex->fixup - offset;
568 
569 	return true;
570 }
571 
572 /* For accesses to BTF pointers, add an entry to the exception table */
573 static int add_exception_handler(const struct bpf_insn *insn,
574 				 struct rv_jit_context *ctx,
575 				 int dst_reg, int insn_len)
576 {
577 	struct exception_table_entry *ex;
578 	unsigned long pc;
579 	off_t offset;
580 
581 	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
582 		return 0;
583 
584 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
585 		return -EINVAL;
586 
587 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
588 		return -EINVAL;
589 
590 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
591 		return -EINVAL;
592 
593 	ex = &ctx->prog->aux->extable[ctx->nexentries];
594 	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
595 
596 	offset = pc - (long)&ex->insn;
597 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
598 		return -ERANGE;
599 	ex->insn = offset;
600 
601 	/*
602 	 * Since the extable follows the program, the fixup offset is always
603 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
604 	 * to keep things simple, and put the destination register in the upper
605 	 * bits. We don't need to worry about buildtime or runtime sort
606 	 * modifying the upper bits because the table is already sorted, and
607 	 * isn't part of the main exception table.
608 	 */
609 	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
610 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
611 		return -ERANGE;
612 
613 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
614 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
615 	ex->type = EX_TYPE_BPF;
616 
617 	ctx->nexentries++;
618 	return 0;
619 }
620 
621 static int gen_call_or_nops(void *target, void *ip, u32 *insns)
622 {
623 	s64 rvoff;
624 	int i, ret;
625 	struct rv_jit_context ctx;
626 
627 	ctx.ninsns = 0;
628 	ctx.insns = (u16 *)insns;
629 
630 	if (!target) {
631 		for (i = 0; i < 4; i++)
632 			emit(rv_nop(), &ctx);
633 		return 0;
634 	}
635 
636 	rvoff = (s64)(target - (ip + 4));
637 	emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
638 	ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
639 	if (ret)
640 		return ret;
641 	emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
642 
643 	return 0;
644 }
645 
646 static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
647 {
648 	s64 rvoff;
649 	struct rv_jit_context ctx;
650 
651 	ctx.ninsns = 0;
652 	ctx.insns = (u16 *)insns;
653 
654 	if (!target) {
655 		emit(rv_nop(), &ctx);
656 		emit(rv_nop(), &ctx);
657 		return 0;
658 	}
659 
660 	rvoff = (s64)(target - ip);
661 	return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
662 }
663 
664 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
665 		       void *old_addr, void *new_addr)
666 {
667 	u32 old_insns[4], new_insns[4];
668 	bool is_call = poke_type == BPF_MOD_CALL;
669 	int (*gen_insns)(void *target, void *ip, u32 *insns);
670 	int ninsns = is_call ? 4 : 2;
671 	int ret;
672 
673 	if (!is_bpf_text_address((unsigned long)ip))
674 		return -ENOTSUPP;
675 
676 	gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
677 
678 	ret = gen_insns(old_addr, ip, old_insns);
679 	if (ret)
680 		return ret;
681 
682 	if (memcmp(ip, old_insns, ninsns * 4))
683 		return -EFAULT;
684 
685 	ret = gen_insns(new_addr, ip, new_insns);
686 	if (ret)
687 		return ret;
688 
689 	cpus_read_lock();
690 	mutex_lock(&text_mutex);
691 	if (memcmp(ip, new_insns, ninsns * 4))
692 		ret = patch_text(ip, new_insns, ninsns);
693 	mutex_unlock(&text_mutex);
694 	cpus_read_unlock();
695 
696 	return ret;
697 }
698 
699 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
700 {
701 	int i;
702 
703 	for (i = 0; i < nregs; i++) {
704 		emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
705 		args_off -= 8;
706 	}
707 }
708 
709 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
710 {
711 	int i;
712 
713 	for (i = 0; i < nregs; i++) {
714 		emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
715 		args_off -= 8;
716 	}
717 }
718 
719 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
720 			   int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
721 {
722 	int ret, branch_off;
723 	struct bpf_prog *p = l->link.prog;
724 	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
725 
726 	if (l->cookie) {
727 		emit_imm(RV_REG_T1, l->cookie, ctx);
728 		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
729 	} else {
730 		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
731 	}
732 
733 	/* arg1: prog */
734 	emit_imm(RV_REG_A0, (const s64)p, ctx);
735 	/* arg2: &run_ctx */
736 	emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
737 	ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
738 	if (ret)
739 		return ret;
740 
741 	/* if (__bpf_prog_enter(prog) == 0)
742 	 *	goto skip_exec_of_prog;
743 	 */
744 	branch_off = ctx->ninsns;
745 	/* nop reserved for conditional jump */
746 	emit(rv_nop(), ctx);
747 
748 	/* store prog start time */
749 	emit_mv(RV_REG_S1, RV_REG_A0, ctx);
750 
751 	/* arg1: &args_off */
752 	emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
753 	if (!p->jited)
754 		/* arg2: progs[i]->insnsi for interpreter */
755 		emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
756 	ret = emit_call((const u64)p->bpf_func, true, ctx);
757 	if (ret)
758 		return ret;
759 
760 	if (save_ret)
761 		emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx);
762 
763 	/* update branch with beqz */
764 	if (ctx->insns) {
765 		int offset = ninsns_rvoff(ctx->ninsns - branch_off);
766 		u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
767 		*(u32 *)(ctx->insns + branch_off) = insn;
768 	}
769 
770 	/* arg1: prog */
771 	emit_imm(RV_REG_A0, (const s64)p, ctx);
772 	/* arg2: prog start time */
773 	emit_mv(RV_REG_A1, RV_REG_S1, ctx);
774 	/* arg3: &run_ctx */
775 	emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
776 	ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
777 
778 	return ret;
779 }
780 
781 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
782 					 const struct btf_func_model *m,
783 					 struct bpf_tramp_links *tlinks,
784 					 void *func_addr, u32 flags,
785 					 struct rv_jit_context *ctx)
786 {
787 	int i, ret, offset;
788 	int *branches_off = NULL;
789 	int stack_size = 0, nregs = m->nr_args;
790 	int retaddr_off, fp_off, retval_off, args_off;
791 	int nregs_off, ip_off, run_ctx_off, sreg_off;
792 	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
793 	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
794 	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
795 	void *orig_call = func_addr;
796 	bool save_ret;
797 	u32 insn;
798 
799 	/* Generated trampoline stack layout:
800 	 *
801 	 * FP - 8	    [ RA of parent func	] return address of parent
802 	 *					  function
803 	 * FP - retaddr_off [ RA of traced func	] return address of traced
804 	 *					  function
805 	 * FP - fp_off	    [ FP of parent func ]
806 	 *
807 	 * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
808 	 *					  BPF_TRAMP_F_RET_FENTRY_RET
809 	 *                  [ argN              ]
810 	 *                  [ ...               ]
811 	 * FP - args_off    [ arg1              ]
812 	 *
813 	 * FP - nregs_off   [ regs count        ]
814 	 *
815 	 * FP - ip_off      [ traced func	] BPF_TRAMP_F_IP_ARG
816 	 *
817 	 * FP - run_ctx_off [ bpf_tramp_run_ctx ]
818 	 *
819 	 * FP - sreg_off    [ callee saved reg	]
820 	 *
821 	 *		    [ pads              ] pads for 16 bytes alignment
822 	 */
823 
824 	if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
825 		return -ENOTSUPP;
826 
827 	/* extra regiters for struct arguments */
828 	for (i = 0; i < m->nr_args; i++)
829 		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
830 			nregs += round_up(m->arg_size[i], 8) / 8 - 1;
831 
832 	/* 8 arguments passed by registers */
833 	if (nregs > 8)
834 		return -ENOTSUPP;
835 
836 	/* room for parent function return address */
837 	stack_size += 8;
838 
839 	stack_size += 8;
840 	retaddr_off = stack_size;
841 
842 	stack_size += 8;
843 	fp_off = stack_size;
844 
845 	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
846 	if (save_ret) {
847 		stack_size += 8;
848 		retval_off = stack_size;
849 	}
850 
851 	stack_size += nregs * 8;
852 	args_off = stack_size;
853 
854 	stack_size += 8;
855 	nregs_off = stack_size;
856 
857 	if (flags & BPF_TRAMP_F_IP_ARG) {
858 		stack_size += 8;
859 		ip_off = stack_size;
860 	}
861 
862 	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
863 	run_ctx_off = stack_size;
864 
865 	stack_size += 8;
866 	sreg_off = stack_size;
867 
868 	stack_size = round_up(stack_size, 16);
869 
870 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
871 
872 	emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
873 	emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
874 
875 	emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
876 
877 	/* callee saved register S1 to pass start time */
878 	emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
879 
880 	/* store ip address of the traced function */
881 	if (flags & BPF_TRAMP_F_IP_ARG) {
882 		emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
883 		emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
884 	}
885 
886 	emit_li(RV_REG_T1, nregs, ctx);
887 	emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
888 
889 	store_args(nregs, args_off, ctx);
890 
891 	/* skip to actual body of traced function */
892 	if (flags & BPF_TRAMP_F_SKIP_FRAME)
893 		orig_call += 16;
894 
895 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
896 		emit_imm(RV_REG_A0, (const s64)im, ctx);
897 		ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
898 		if (ret)
899 			return ret;
900 	}
901 
902 	for (i = 0; i < fentry->nr_links; i++) {
903 		ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
904 				      flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
905 		if (ret)
906 			return ret;
907 	}
908 
909 	if (fmod_ret->nr_links) {
910 		branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
911 		if (!branches_off)
912 			return -ENOMEM;
913 
914 		/* cleanup to avoid garbage return value confusion */
915 		emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
916 		for (i = 0; i < fmod_ret->nr_links; i++) {
917 			ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
918 					      run_ctx_off, true, ctx);
919 			if (ret)
920 				goto out;
921 			emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
922 			branches_off[i] = ctx->ninsns;
923 			/* nop reserved for conditional jump */
924 			emit(rv_nop(), ctx);
925 		}
926 	}
927 
928 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
929 		restore_args(nregs, args_off, ctx);
930 		ret = emit_call((const u64)orig_call, true, ctx);
931 		if (ret)
932 			goto out;
933 		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
934 		im->ip_after_call = ctx->insns + ctx->ninsns;
935 		/* 2 nops reserved for auipc+jalr pair */
936 		emit(rv_nop(), ctx);
937 		emit(rv_nop(), ctx);
938 	}
939 
940 	/* update branches saved in invoke_bpf_mod_ret with bnez */
941 	for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
942 		offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
943 		insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
944 		*(u32 *)(ctx->insns + branches_off[i]) = insn;
945 	}
946 
947 	for (i = 0; i < fexit->nr_links; i++) {
948 		ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
949 				      run_ctx_off, false, ctx);
950 		if (ret)
951 			goto out;
952 	}
953 
954 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
955 		im->ip_epilogue = ctx->insns + ctx->ninsns;
956 		emit_imm(RV_REG_A0, (const s64)im, ctx);
957 		ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
958 		if (ret)
959 			goto out;
960 	}
961 
962 	if (flags & BPF_TRAMP_F_RESTORE_REGS)
963 		restore_args(nregs, args_off, ctx);
964 
965 	if (save_ret)
966 		emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
967 
968 	emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
969 
970 	if (flags & BPF_TRAMP_F_SKIP_FRAME)
971 		/* return address of parent function */
972 		emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
973 	else
974 		/* return address of traced function */
975 		emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
976 
977 	emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
978 	emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
979 
980 	emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
981 
982 	ret = ctx->ninsns;
983 out:
984 	kfree(branches_off);
985 	return ret;
986 }
987 
988 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
989 				void *image_end, const struct btf_func_model *m,
990 				u32 flags, struct bpf_tramp_links *tlinks,
991 				void *func_addr)
992 {
993 	int ret;
994 	struct rv_jit_context ctx;
995 
996 	ctx.ninsns = 0;
997 	ctx.insns = NULL;
998 	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
999 	if (ret < 0)
1000 		return ret;
1001 
1002 	if (ninsns_rvoff(ret) > (long)image_end - (long)image)
1003 		return -EFBIG;
1004 
1005 	ctx.ninsns = 0;
1006 	ctx.insns = image;
1007 	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1008 	if (ret < 0)
1009 		return ret;
1010 
1011 	bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1012 
1013 	return ninsns_rvoff(ret);
1014 }
1015 
1016 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1017 		      bool extra_pass)
1018 {
1019 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1020 		    BPF_CLASS(insn->code) == BPF_JMP;
1021 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1022 	struct bpf_prog_aux *aux = ctx->prog->aux;
1023 	u8 rd = -1, rs = -1, code = insn->code;
1024 	s16 off = insn->off;
1025 	s32 imm = insn->imm;
1026 
1027 	init_regs(&rd, &rs, insn, ctx);
1028 
1029 	switch (code) {
1030 	/* dst = src */
1031 	case BPF_ALU | BPF_MOV | BPF_X:
1032 	case BPF_ALU64 | BPF_MOV | BPF_X:
1033 		if (imm == 1) {
1034 			/* Special mov32 for zext */
1035 			emit_zext_32(rd, ctx);
1036 			break;
1037 		}
1038 		emit_mv(rd, rs, ctx);
1039 		if (!is64 && !aux->verifier_zext)
1040 			emit_zext_32(rd, ctx);
1041 		break;
1042 
1043 	/* dst = dst OP src */
1044 	case BPF_ALU | BPF_ADD | BPF_X:
1045 	case BPF_ALU64 | BPF_ADD | BPF_X:
1046 		emit_add(rd, rd, rs, ctx);
1047 		if (!is64 && !aux->verifier_zext)
1048 			emit_zext_32(rd, ctx);
1049 		break;
1050 	case BPF_ALU | BPF_SUB | BPF_X:
1051 	case BPF_ALU64 | BPF_SUB | BPF_X:
1052 		if (is64)
1053 			emit_sub(rd, rd, rs, ctx);
1054 		else
1055 			emit_subw(rd, rd, rs, ctx);
1056 
1057 		if (!is64 && !aux->verifier_zext)
1058 			emit_zext_32(rd, ctx);
1059 		break;
1060 	case BPF_ALU | BPF_AND | BPF_X:
1061 	case BPF_ALU64 | BPF_AND | BPF_X:
1062 		emit_and(rd, rd, rs, ctx);
1063 		if (!is64 && !aux->verifier_zext)
1064 			emit_zext_32(rd, ctx);
1065 		break;
1066 	case BPF_ALU | BPF_OR | BPF_X:
1067 	case BPF_ALU64 | BPF_OR | BPF_X:
1068 		emit_or(rd, rd, rs, ctx);
1069 		if (!is64 && !aux->verifier_zext)
1070 			emit_zext_32(rd, ctx);
1071 		break;
1072 	case BPF_ALU | BPF_XOR | BPF_X:
1073 	case BPF_ALU64 | BPF_XOR | BPF_X:
1074 		emit_xor(rd, rd, rs, ctx);
1075 		if (!is64 && !aux->verifier_zext)
1076 			emit_zext_32(rd, ctx);
1077 		break;
1078 	case BPF_ALU | BPF_MUL | BPF_X:
1079 	case BPF_ALU64 | BPF_MUL | BPF_X:
1080 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1081 		if (!is64 && !aux->verifier_zext)
1082 			emit_zext_32(rd, ctx);
1083 		break;
1084 	case BPF_ALU | BPF_DIV | BPF_X:
1085 	case BPF_ALU64 | BPF_DIV | BPF_X:
1086 		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1087 		if (!is64 && !aux->verifier_zext)
1088 			emit_zext_32(rd, ctx);
1089 		break;
1090 	case BPF_ALU | BPF_MOD | BPF_X:
1091 	case BPF_ALU64 | BPF_MOD | BPF_X:
1092 		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1093 		if (!is64 && !aux->verifier_zext)
1094 			emit_zext_32(rd, ctx);
1095 		break;
1096 	case BPF_ALU | BPF_LSH | BPF_X:
1097 	case BPF_ALU64 | BPF_LSH | BPF_X:
1098 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1099 		if (!is64 && !aux->verifier_zext)
1100 			emit_zext_32(rd, ctx);
1101 		break;
1102 	case BPF_ALU | BPF_RSH | BPF_X:
1103 	case BPF_ALU64 | BPF_RSH | BPF_X:
1104 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1105 		if (!is64 && !aux->verifier_zext)
1106 			emit_zext_32(rd, ctx);
1107 		break;
1108 	case BPF_ALU | BPF_ARSH | BPF_X:
1109 	case BPF_ALU64 | BPF_ARSH | BPF_X:
1110 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1111 		if (!is64 && !aux->verifier_zext)
1112 			emit_zext_32(rd, ctx);
1113 		break;
1114 
1115 	/* dst = -dst */
1116 	case BPF_ALU | BPF_NEG:
1117 	case BPF_ALU64 | BPF_NEG:
1118 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
1119 		if (!is64 && !aux->verifier_zext)
1120 			emit_zext_32(rd, ctx);
1121 		break;
1122 
1123 	/* dst = BSWAP##imm(dst) */
1124 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1125 		switch (imm) {
1126 		case 16:
1127 			emit_slli(rd, rd, 48, ctx);
1128 			emit_srli(rd, rd, 48, ctx);
1129 			break;
1130 		case 32:
1131 			if (!aux->verifier_zext)
1132 				emit_zext_32(rd, ctx);
1133 			break;
1134 		case 64:
1135 			/* Do nothing */
1136 			break;
1137 		}
1138 		break;
1139 
1140 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1141 		emit_li(RV_REG_T2, 0, ctx);
1142 
1143 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1144 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1145 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1146 		emit_srli(rd, rd, 8, ctx);
1147 		if (imm == 16)
1148 			goto out_be;
1149 
1150 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1151 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1152 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1153 		emit_srli(rd, rd, 8, ctx);
1154 
1155 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1156 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1157 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1158 		emit_srli(rd, rd, 8, ctx);
1159 		if (imm == 32)
1160 			goto out_be;
1161 
1162 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1163 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1164 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1165 		emit_srli(rd, rd, 8, ctx);
1166 
1167 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1168 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1169 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1170 		emit_srli(rd, rd, 8, ctx);
1171 
1172 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1173 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1174 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1175 		emit_srli(rd, rd, 8, ctx);
1176 
1177 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1178 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1179 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1180 		emit_srli(rd, rd, 8, ctx);
1181 out_be:
1182 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
1183 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1184 
1185 		emit_mv(rd, RV_REG_T2, ctx);
1186 		break;
1187 
1188 	/* dst = imm */
1189 	case BPF_ALU | BPF_MOV | BPF_K:
1190 	case BPF_ALU64 | BPF_MOV | BPF_K:
1191 		emit_imm(rd, imm, ctx);
1192 		if (!is64 && !aux->verifier_zext)
1193 			emit_zext_32(rd, ctx);
1194 		break;
1195 
1196 	/* dst = dst OP imm */
1197 	case BPF_ALU | BPF_ADD | BPF_K:
1198 	case BPF_ALU64 | BPF_ADD | BPF_K:
1199 		if (is_12b_int(imm)) {
1200 			emit_addi(rd, rd, imm, ctx);
1201 		} else {
1202 			emit_imm(RV_REG_T1, imm, ctx);
1203 			emit_add(rd, rd, RV_REG_T1, ctx);
1204 		}
1205 		if (!is64 && !aux->verifier_zext)
1206 			emit_zext_32(rd, ctx);
1207 		break;
1208 	case BPF_ALU | BPF_SUB | BPF_K:
1209 	case BPF_ALU64 | BPF_SUB | BPF_K:
1210 		if (is_12b_int(-imm)) {
1211 			emit_addi(rd, rd, -imm, ctx);
1212 		} else {
1213 			emit_imm(RV_REG_T1, imm, ctx);
1214 			emit_sub(rd, rd, RV_REG_T1, ctx);
1215 		}
1216 		if (!is64 && !aux->verifier_zext)
1217 			emit_zext_32(rd, ctx);
1218 		break;
1219 	case BPF_ALU | BPF_AND | BPF_K:
1220 	case BPF_ALU64 | BPF_AND | BPF_K:
1221 		if (is_12b_int(imm)) {
1222 			emit_andi(rd, rd, imm, ctx);
1223 		} else {
1224 			emit_imm(RV_REG_T1, imm, ctx);
1225 			emit_and(rd, rd, RV_REG_T1, ctx);
1226 		}
1227 		if (!is64 && !aux->verifier_zext)
1228 			emit_zext_32(rd, ctx);
1229 		break;
1230 	case BPF_ALU | BPF_OR | BPF_K:
1231 	case BPF_ALU64 | BPF_OR | BPF_K:
1232 		if (is_12b_int(imm)) {
1233 			emit(rv_ori(rd, rd, imm), ctx);
1234 		} else {
1235 			emit_imm(RV_REG_T1, imm, ctx);
1236 			emit_or(rd, rd, RV_REG_T1, ctx);
1237 		}
1238 		if (!is64 && !aux->verifier_zext)
1239 			emit_zext_32(rd, ctx);
1240 		break;
1241 	case BPF_ALU | BPF_XOR | BPF_K:
1242 	case BPF_ALU64 | BPF_XOR | BPF_K:
1243 		if (is_12b_int(imm)) {
1244 			emit(rv_xori(rd, rd, imm), ctx);
1245 		} else {
1246 			emit_imm(RV_REG_T1, imm, ctx);
1247 			emit_xor(rd, rd, RV_REG_T1, ctx);
1248 		}
1249 		if (!is64 && !aux->verifier_zext)
1250 			emit_zext_32(rd, ctx);
1251 		break;
1252 	case BPF_ALU | BPF_MUL | BPF_K:
1253 	case BPF_ALU64 | BPF_MUL | BPF_K:
1254 		emit_imm(RV_REG_T1, imm, ctx);
1255 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1256 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
1257 		if (!is64 && !aux->verifier_zext)
1258 			emit_zext_32(rd, ctx);
1259 		break;
1260 	case BPF_ALU | BPF_DIV | BPF_K:
1261 	case BPF_ALU64 | BPF_DIV | BPF_K:
1262 		emit_imm(RV_REG_T1, imm, ctx);
1263 		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1264 		     rv_divuw(rd, rd, RV_REG_T1), ctx);
1265 		if (!is64 && !aux->verifier_zext)
1266 			emit_zext_32(rd, ctx);
1267 		break;
1268 	case BPF_ALU | BPF_MOD | BPF_K:
1269 	case BPF_ALU64 | BPF_MOD | BPF_K:
1270 		emit_imm(RV_REG_T1, imm, ctx);
1271 		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1272 		     rv_remuw(rd, rd, RV_REG_T1), ctx);
1273 		if (!is64 && !aux->verifier_zext)
1274 			emit_zext_32(rd, ctx);
1275 		break;
1276 	case BPF_ALU | BPF_LSH | BPF_K:
1277 	case BPF_ALU64 | BPF_LSH | BPF_K:
1278 		emit_slli(rd, rd, imm, ctx);
1279 
1280 		if (!is64 && !aux->verifier_zext)
1281 			emit_zext_32(rd, ctx);
1282 		break;
1283 	case BPF_ALU | BPF_RSH | BPF_K:
1284 	case BPF_ALU64 | BPF_RSH | BPF_K:
1285 		if (is64)
1286 			emit_srli(rd, rd, imm, ctx);
1287 		else
1288 			emit(rv_srliw(rd, rd, imm), ctx);
1289 
1290 		if (!is64 && !aux->verifier_zext)
1291 			emit_zext_32(rd, ctx);
1292 		break;
1293 	case BPF_ALU | BPF_ARSH | BPF_K:
1294 	case BPF_ALU64 | BPF_ARSH | BPF_K:
1295 		if (is64)
1296 			emit_srai(rd, rd, imm, ctx);
1297 		else
1298 			emit(rv_sraiw(rd, rd, imm), ctx);
1299 
1300 		if (!is64 && !aux->verifier_zext)
1301 			emit_zext_32(rd, ctx);
1302 		break;
1303 
1304 	/* JUMP off */
1305 	case BPF_JMP | BPF_JA:
1306 		rvoff = rv_offset(i, off, ctx);
1307 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1308 		if (ret)
1309 			return ret;
1310 		break;
1311 
1312 	/* IF (dst COND src) JUMP off */
1313 	case BPF_JMP | BPF_JEQ | BPF_X:
1314 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1315 	case BPF_JMP | BPF_JGT | BPF_X:
1316 	case BPF_JMP32 | BPF_JGT | BPF_X:
1317 	case BPF_JMP | BPF_JLT | BPF_X:
1318 	case BPF_JMP32 | BPF_JLT | BPF_X:
1319 	case BPF_JMP | BPF_JGE | BPF_X:
1320 	case BPF_JMP32 | BPF_JGE | BPF_X:
1321 	case BPF_JMP | BPF_JLE | BPF_X:
1322 	case BPF_JMP32 | BPF_JLE | BPF_X:
1323 	case BPF_JMP | BPF_JNE | BPF_X:
1324 	case BPF_JMP32 | BPF_JNE | BPF_X:
1325 	case BPF_JMP | BPF_JSGT | BPF_X:
1326 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1327 	case BPF_JMP | BPF_JSLT | BPF_X:
1328 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1329 	case BPF_JMP | BPF_JSGE | BPF_X:
1330 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1331 	case BPF_JMP | BPF_JSLE | BPF_X:
1332 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1333 	case BPF_JMP | BPF_JSET | BPF_X:
1334 	case BPF_JMP32 | BPF_JSET | BPF_X:
1335 		rvoff = rv_offset(i, off, ctx);
1336 		if (!is64) {
1337 			s = ctx->ninsns;
1338 			if (is_signed_bpf_cond(BPF_OP(code)))
1339 				emit_sext_32_rd_rs(&rd, &rs, ctx);
1340 			else
1341 				emit_zext_32_rd_rs(&rd, &rs, ctx);
1342 			e = ctx->ninsns;
1343 
1344 			/* Adjust for extra insns */
1345 			rvoff -= ninsns_rvoff(e - s);
1346 		}
1347 
1348 		if (BPF_OP(code) == BPF_JSET) {
1349 			/* Adjust for and */
1350 			rvoff -= 4;
1351 			emit_and(RV_REG_T1, rd, rs, ctx);
1352 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1353 				    ctx);
1354 		} else {
1355 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1356 		}
1357 		break;
1358 
1359 	/* IF (dst COND imm) JUMP off */
1360 	case BPF_JMP | BPF_JEQ | BPF_K:
1361 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1362 	case BPF_JMP | BPF_JGT | BPF_K:
1363 	case BPF_JMP32 | BPF_JGT | BPF_K:
1364 	case BPF_JMP | BPF_JLT | BPF_K:
1365 	case BPF_JMP32 | BPF_JLT | BPF_K:
1366 	case BPF_JMP | BPF_JGE | BPF_K:
1367 	case BPF_JMP32 | BPF_JGE | BPF_K:
1368 	case BPF_JMP | BPF_JLE | BPF_K:
1369 	case BPF_JMP32 | BPF_JLE | BPF_K:
1370 	case BPF_JMP | BPF_JNE | BPF_K:
1371 	case BPF_JMP32 | BPF_JNE | BPF_K:
1372 	case BPF_JMP | BPF_JSGT | BPF_K:
1373 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1374 	case BPF_JMP | BPF_JSLT | BPF_K:
1375 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1376 	case BPF_JMP | BPF_JSGE | BPF_K:
1377 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1378 	case BPF_JMP | BPF_JSLE | BPF_K:
1379 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1380 		rvoff = rv_offset(i, off, ctx);
1381 		s = ctx->ninsns;
1382 		if (imm) {
1383 			emit_imm(RV_REG_T1, imm, ctx);
1384 			rs = RV_REG_T1;
1385 		} else {
1386 			/* If imm is 0, simply use zero register. */
1387 			rs = RV_REG_ZERO;
1388 		}
1389 		if (!is64) {
1390 			if (is_signed_bpf_cond(BPF_OP(code)))
1391 				emit_sext_32_rd(&rd, ctx);
1392 			else
1393 				emit_zext_32_rd_t1(&rd, ctx);
1394 		}
1395 		e = ctx->ninsns;
1396 
1397 		/* Adjust for extra insns */
1398 		rvoff -= ninsns_rvoff(e - s);
1399 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1400 		break;
1401 
1402 	case BPF_JMP | BPF_JSET | BPF_K:
1403 	case BPF_JMP32 | BPF_JSET | BPF_K:
1404 		rvoff = rv_offset(i, off, ctx);
1405 		s = ctx->ninsns;
1406 		if (is_12b_int(imm)) {
1407 			emit_andi(RV_REG_T1, rd, imm, ctx);
1408 		} else {
1409 			emit_imm(RV_REG_T1, imm, ctx);
1410 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1411 		}
1412 		/* For jset32, we should clear the upper 32 bits of t1, but
1413 		 * sign-extension is sufficient here and saves one instruction,
1414 		 * as t1 is used only in comparison against zero.
1415 		 */
1416 		if (!is64 && imm < 0)
1417 			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1418 		e = ctx->ninsns;
1419 		rvoff -= ninsns_rvoff(e - s);
1420 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1421 		break;
1422 
1423 	/* function call */
1424 	case BPF_JMP | BPF_CALL:
1425 	{
1426 		bool fixed_addr;
1427 		u64 addr;
1428 
1429 		mark_call(ctx);
1430 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1431 					    &addr, &fixed_addr);
1432 		if (ret < 0)
1433 			return ret;
1434 
1435 		ret = emit_call(addr, fixed_addr, ctx);
1436 		if (ret)
1437 			return ret;
1438 
1439 		emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1440 		break;
1441 	}
1442 	/* tail call */
1443 	case BPF_JMP | BPF_TAIL_CALL:
1444 		if (emit_bpf_tail_call(i, ctx))
1445 			return -1;
1446 		break;
1447 
1448 	/* function return */
1449 	case BPF_JMP | BPF_EXIT:
1450 		if (i == ctx->prog->len - 1)
1451 			break;
1452 
1453 		rvoff = epilogue_offset(ctx);
1454 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1455 		if (ret)
1456 			return ret;
1457 		break;
1458 
1459 	/* dst = imm64 */
1460 	case BPF_LD | BPF_IMM | BPF_DW:
1461 	{
1462 		struct bpf_insn insn1 = insn[1];
1463 		u64 imm64;
1464 
1465 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1466 		if (bpf_pseudo_func(insn)) {
1467 			/* fixed-length insns for extra jit pass */
1468 			ret = emit_addr(rd, imm64, extra_pass, ctx);
1469 			if (ret)
1470 				return ret;
1471 		} else {
1472 			emit_imm(rd, imm64, ctx);
1473 		}
1474 
1475 		return 1;
1476 	}
1477 
1478 	/* LDX: dst = *(size *)(src + off) */
1479 	case BPF_LDX | BPF_MEM | BPF_B:
1480 	case BPF_LDX | BPF_MEM | BPF_H:
1481 	case BPF_LDX | BPF_MEM | BPF_W:
1482 	case BPF_LDX | BPF_MEM | BPF_DW:
1483 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1484 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1485 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1486 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1487 	{
1488 		int insn_len, insns_start;
1489 
1490 		switch (BPF_SIZE(code)) {
1491 		case BPF_B:
1492 			if (is_12b_int(off)) {
1493 				insns_start = ctx->ninsns;
1494 				emit(rv_lbu(rd, off, rs), ctx);
1495 				insn_len = ctx->ninsns - insns_start;
1496 				break;
1497 			}
1498 
1499 			emit_imm(RV_REG_T1, off, ctx);
1500 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1501 			insns_start = ctx->ninsns;
1502 			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1503 			insn_len = ctx->ninsns - insns_start;
1504 			if (insn_is_zext(&insn[1]))
1505 				return 1;
1506 			break;
1507 		case BPF_H:
1508 			if (is_12b_int(off)) {
1509 				insns_start = ctx->ninsns;
1510 				emit(rv_lhu(rd, off, rs), ctx);
1511 				insn_len = ctx->ninsns - insns_start;
1512 				break;
1513 			}
1514 
1515 			emit_imm(RV_REG_T1, off, ctx);
1516 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1517 			insns_start = ctx->ninsns;
1518 			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1519 			insn_len = ctx->ninsns - insns_start;
1520 			if (insn_is_zext(&insn[1]))
1521 				return 1;
1522 			break;
1523 		case BPF_W:
1524 			if (is_12b_int(off)) {
1525 				insns_start = ctx->ninsns;
1526 				emit(rv_lwu(rd, off, rs), ctx);
1527 				insn_len = ctx->ninsns - insns_start;
1528 				break;
1529 			}
1530 
1531 			emit_imm(RV_REG_T1, off, ctx);
1532 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1533 			insns_start = ctx->ninsns;
1534 			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1535 			insn_len = ctx->ninsns - insns_start;
1536 			if (insn_is_zext(&insn[1]))
1537 				return 1;
1538 			break;
1539 		case BPF_DW:
1540 			if (is_12b_int(off)) {
1541 				insns_start = ctx->ninsns;
1542 				emit_ld(rd, off, rs, ctx);
1543 				insn_len = ctx->ninsns - insns_start;
1544 				break;
1545 			}
1546 
1547 			emit_imm(RV_REG_T1, off, ctx);
1548 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1549 			insns_start = ctx->ninsns;
1550 			emit_ld(rd, 0, RV_REG_T1, ctx);
1551 			insn_len = ctx->ninsns - insns_start;
1552 			break;
1553 		}
1554 
1555 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1556 		if (ret)
1557 			return ret;
1558 		break;
1559 	}
1560 	/* speculation barrier */
1561 	case BPF_ST | BPF_NOSPEC:
1562 		break;
1563 
1564 	/* ST: *(size *)(dst + off) = imm */
1565 	case BPF_ST | BPF_MEM | BPF_B:
1566 		emit_imm(RV_REG_T1, imm, ctx);
1567 		if (is_12b_int(off)) {
1568 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1569 			break;
1570 		}
1571 
1572 		emit_imm(RV_REG_T2, off, ctx);
1573 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1574 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1575 		break;
1576 
1577 	case BPF_ST | BPF_MEM | BPF_H:
1578 		emit_imm(RV_REG_T1, imm, ctx);
1579 		if (is_12b_int(off)) {
1580 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1581 			break;
1582 		}
1583 
1584 		emit_imm(RV_REG_T2, off, ctx);
1585 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1586 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1587 		break;
1588 	case BPF_ST | BPF_MEM | BPF_W:
1589 		emit_imm(RV_REG_T1, imm, ctx);
1590 		if (is_12b_int(off)) {
1591 			emit_sw(rd, off, RV_REG_T1, ctx);
1592 			break;
1593 		}
1594 
1595 		emit_imm(RV_REG_T2, off, ctx);
1596 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1597 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1598 		break;
1599 	case BPF_ST | BPF_MEM | BPF_DW:
1600 		emit_imm(RV_REG_T1, imm, ctx);
1601 		if (is_12b_int(off)) {
1602 			emit_sd(rd, off, RV_REG_T1, ctx);
1603 			break;
1604 		}
1605 
1606 		emit_imm(RV_REG_T2, off, ctx);
1607 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1608 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1609 		break;
1610 
1611 	/* STX: *(size *)(dst + off) = src */
1612 	case BPF_STX | BPF_MEM | BPF_B:
1613 		if (is_12b_int(off)) {
1614 			emit(rv_sb(rd, off, rs), ctx);
1615 			break;
1616 		}
1617 
1618 		emit_imm(RV_REG_T1, off, ctx);
1619 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1620 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1621 		break;
1622 	case BPF_STX | BPF_MEM | BPF_H:
1623 		if (is_12b_int(off)) {
1624 			emit(rv_sh(rd, off, rs), ctx);
1625 			break;
1626 		}
1627 
1628 		emit_imm(RV_REG_T1, off, ctx);
1629 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1630 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1631 		break;
1632 	case BPF_STX | BPF_MEM | BPF_W:
1633 		if (is_12b_int(off)) {
1634 			emit_sw(rd, off, rs, ctx);
1635 			break;
1636 		}
1637 
1638 		emit_imm(RV_REG_T1, off, ctx);
1639 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1640 		emit_sw(RV_REG_T1, 0, rs, ctx);
1641 		break;
1642 	case BPF_STX | BPF_MEM | BPF_DW:
1643 		if (is_12b_int(off)) {
1644 			emit_sd(rd, off, rs, ctx);
1645 			break;
1646 		}
1647 
1648 		emit_imm(RV_REG_T1, off, ctx);
1649 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1650 		emit_sd(RV_REG_T1, 0, rs, ctx);
1651 		break;
1652 	case BPF_STX | BPF_ATOMIC | BPF_W:
1653 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1654 		emit_atomic(rd, rs, off, imm,
1655 			    BPF_SIZE(code) == BPF_DW, ctx);
1656 		break;
1657 	default:
1658 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1659 		return -EINVAL;
1660 	}
1661 
1662 	return 0;
1663 }
1664 
1665 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1666 {
1667 	int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1668 
1669 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1670 	if (bpf_stack_adjust)
1671 		mark_fp(ctx);
1672 
1673 	if (seen_reg(RV_REG_RA, ctx))
1674 		stack_adjust += 8;
1675 	stack_adjust += 8; /* RV_REG_FP */
1676 	if (seen_reg(RV_REG_S1, ctx))
1677 		stack_adjust += 8;
1678 	if (seen_reg(RV_REG_S2, ctx))
1679 		stack_adjust += 8;
1680 	if (seen_reg(RV_REG_S3, ctx))
1681 		stack_adjust += 8;
1682 	if (seen_reg(RV_REG_S4, ctx))
1683 		stack_adjust += 8;
1684 	if (seen_reg(RV_REG_S5, ctx))
1685 		stack_adjust += 8;
1686 	if (seen_reg(RV_REG_S6, ctx))
1687 		stack_adjust += 8;
1688 
1689 	stack_adjust = round_up(stack_adjust, 16);
1690 	stack_adjust += bpf_stack_adjust;
1691 
1692 	store_offset = stack_adjust - 8;
1693 
1694 	/* reserve 4 nop insns */
1695 	for (i = 0; i < 4; i++)
1696 		emit(rv_nop(), ctx);
1697 
1698 	/* First instruction is always setting the tail-call-counter
1699 	 * (TCC) register. This instruction is skipped for tail calls.
1700 	 * Force using a 4-byte (non-compressed) instruction.
1701 	 */
1702 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1703 
1704 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1705 
1706 	if (seen_reg(RV_REG_RA, ctx)) {
1707 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1708 		store_offset -= 8;
1709 	}
1710 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1711 	store_offset -= 8;
1712 	if (seen_reg(RV_REG_S1, ctx)) {
1713 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1714 		store_offset -= 8;
1715 	}
1716 	if (seen_reg(RV_REG_S2, ctx)) {
1717 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1718 		store_offset -= 8;
1719 	}
1720 	if (seen_reg(RV_REG_S3, ctx)) {
1721 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1722 		store_offset -= 8;
1723 	}
1724 	if (seen_reg(RV_REG_S4, ctx)) {
1725 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1726 		store_offset -= 8;
1727 	}
1728 	if (seen_reg(RV_REG_S5, ctx)) {
1729 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1730 		store_offset -= 8;
1731 	}
1732 	if (seen_reg(RV_REG_S6, ctx)) {
1733 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1734 		store_offset -= 8;
1735 	}
1736 
1737 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1738 
1739 	if (bpf_stack_adjust)
1740 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1741 
1742 	/* Program contains calls and tail calls, so RV_REG_TCC need
1743 	 * to be saved across calls.
1744 	 */
1745 	if (seen_tail_call(ctx) && seen_call(ctx))
1746 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1747 
1748 	ctx->stack_size = stack_adjust;
1749 }
1750 
1751 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1752 {
1753 	__build_epilogue(false, ctx);
1754 }
1755 
1756 bool bpf_jit_supports_kfunc_call(void)
1757 {
1758 	return true;
1759 }
1760