xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp64.c (revision 2f0754f2)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include "bpf_jit.h"
12 
13 #define RV_REG_TCC RV_REG_A6
14 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
15 
16 static const int regmap[] = {
17 	[BPF_REG_0] =	RV_REG_A5,
18 	[BPF_REG_1] =	RV_REG_A0,
19 	[BPF_REG_2] =	RV_REG_A1,
20 	[BPF_REG_3] =	RV_REG_A2,
21 	[BPF_REG_4] =	RV_REG_A3,
22 	[BPF_REG_5] =	RV_REG_A4,
23 	[BPF_REG_6] =	RV_REG_S1,
24 	[BPF_REG_7] =	RV_REG_S2,
25 	[BPF_REG_8] =	RV_REG_S3,
26 	[BPF_REG_9] =	RV_REG_S4,
27 	[BPF_REG_FP] =	RV_REG_S5,
28 	[BPF_REG_AX] =	RV_REG_T0,
29 };
30 
31 static const int pt_regmap[] = {
32 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
33 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
34 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
35 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
36 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
37 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
38 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
39 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
40 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
41 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
42 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
43 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
44 };
45 
46 enum {
47 	RV_CTX_F_SEEN_TAIL_CALL =	0,
48 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
49 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
50 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
51 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
52 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
53 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
54 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
55 };
56 
57 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
58 {
59 	u8 reg = regmap[bpf_reg];
60 
61 	switch (reg) {
62 	case RV_CTX_F_SEEN_S1:
63 	case RV_CTX_F_SEEN_S2:
64 	case RV_CTX_F_SEEN_S3:
65 	case RV_CTX_F_SEEN_S4:
66 	case RV_CTX_F_SEEN_S5:
67 	case RV_CTX_F_SEEN_S6:
68 		__set_bit(reg, &ctx->flags);
69 	}
70 	return reg;
71 };
72 
73 static bool seen_reg(int reg, struct rv_jit_context *ctx)
74 {
75 	switch (reg) {
76 	case RV_CTX_F_SEEN_CALL:
77 	case RV_CTX_F_SEEN_S1:
78 	case RV_CTX_F_SEEN_S2:
79 	case RV_CTX_F_SEEN_S3:
80 	case RV_CTX_F_SEEN_S4:
81 	case RV_CTX_F_SEEN_S5:
82 	case RV_CTX_F_SEEN_S6:
83 		return test_bit(reg, &ctx->flags);
84 	}
85 	return false;
86 }
87 
88 static void mark_fp(struct rv_jit_context *ctx)
89 {
90 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
91 }
92 
93 static void mark_call(struct rv_jit_context *ctx)
94 {
95 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
96 }
97 
98 static bool seen_call(struct rv_jit_context *ctx)
99 {
100 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102 
103 static void mark_tail_call(struct rv_jit_context *ctx)
104 {
105 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
106 }
107 
108 static bool seen_tail_call(struct rv_jit_context *ctx)
109 {
110 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112 
113 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
114 {
115 	mark_tail_call(ctx);
116 
117 	if (seen_call(ctx)) {
118 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
119 		return RV_REG_S6;
120 	}
121 	return RV_REG_A6;
122 }
123 
124 static bool is_32b_int(s64 val)
125 {
126 	return -(1L << 31) <= val && val < (1L << 31);
127 }
128 
129 static bool in_auipc_jalr_range(s64 val)
130 {
131 	/*
132 	 * auipc+jalr can reach any signed PC-relative offset in the range
133 	 * [-2^31 - 2^11, 2^31 - 2^11).
134 	 */
135 	return (-(1L << 31) - (1L << 11)) <= val &&
136 		val < ((1L << 31) - (1L << 11));
137 }
138 
139 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
140 {
141 	/* Note that the immediate from the add is sign-extended,
142 	 * which means that we need to compensate this by adding 2^12,
143 	 * when the 12th bit is set. A simpler way of doing this, and
144 	 * getting rid of the check, is to just add 2**11 before the
145 	 * shift. The "Loading a 32-Bit constant" example from the
146 	 * "Computer Organization and Design, RISC-V edition" book by
147 	 * Patterson/Hennessy highlights this fact.
148 	 *
149 	 * This also means that we need to process LSB to MSB.
150 	 */
151 	s64 upper = (val + (1 << 11)) >> 12;
152 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
153 	 * and addi are signed and RVC checks will perform signed comparisons.
154 	 */
155 	s64 lower = ((val & 0xfff) << 52) >> 52;
156 	int shift;
157 
158 	if (is_32b_int(val)) {
159 		if (upper)
160 			emit_lui(rd, upper, ctx);
161 
162 		if (!upper) {
163 			emit_li(rd, lower, ctx);
164 			return;
165 		}
166 
167 		emit_addiw(rd, rd, lower, ctx);
168 		return;
169 	}
170 
171 	shift = __ffs(upper);
172 	upper >>= shift;
173 	shift += 12;
174 
175 	emit_imm(rd, upper, ctx);
176 
177 	emit_slli(rd, rd, shift, ctx);
178 	if (lower)
179 		emit_addi(rd, rd, lower, ctx);
180 }
181 
182 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
183 {
184 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
185 
186 	if (seen_reg(RV_REG_RA, ctx)) {
187 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
188 		store_offset -= 8;
189 	}
190 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
191 	store_offset -= 8;
192 	if (seen_reg(RV_REG_S1, ctx)) {
193 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
194 		store_offset -= 8;
195 	}
196 	if (seen_reg(RV_REG_S2, ctx)) {
197 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
198 		store_offset -= 8;
199 	}
200 	if (seen_reg(RV_REG_S3, ctx)) {
201 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
202 		store_offset -= 8;
203 	}
204 	if (seen_reg(RV_REG_S4, ctx)) {
205 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
206 		store_offset -= 8;
207 	}
208 	if (seen_reg(RV_REG_S5, ctx)) {
209 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
210 		store_offset -= 8;
211 	}
212 	if (seen_reg(RV_REG_S6, ctx)) {
213 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
214 		store_offset -= 8;
215 	}
216 
217 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
218 	/* Set return value. */
219 	if (!is_tail_call)
220 		emit_mv(RV_REG_A0, RV_REG_A5, ctx);
221 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
222 		  is_tail_call ? 4 : 0, /* skip TCC init */
223 		  ctx);
224 }
225 
226 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
227 		     struct rv_jit_context *ctx)
228 {
229 	switch (cond) {
230 	case BPF_JEQ:
231 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
232 		return;
233 	case BPF_JGT:
234 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
235 		return;
236 	case BPF_JLT:
237 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
238 		return;
239 	case BPF_JGE:
240 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
241 		return;
242 	case BPF_JLE:
243 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
244 		return;
245 	case BPF_JNE:
246 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
247 		return;
248 	case BPF_JSGT:
249 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
250 		return;
251 	case BPF_JSLT:
252 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
253 		return;
254 	case BPF_JSGE:
255 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
256 		return;
257 	case BPF_JSLE:
258 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
259 	}
260 }
261 
262 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
263 			struct rv_jit_context *ctx)
264 {
265 	s64 upper, lower;
266 
267 	if (is_13b_int(rvoff)) {
268 		emit_bcc(cond, rd, rs, rvoff, ctx);
269 		return;
270 	}
271 
272 	/* Adjust for jal */
273 	rvoff -= 4;
274 
275 	/* Transform, e.g.:
276 	 *   bne rd,rs,foo
277 	 * to
278 	 *   beq rd,rs,<.L1>
279 	 *   (auipc foo)
280 	 *   jal(r) foo
281 	 * .L1
282 	 */
283 	cond = invert_bpf_cond(cond);
284 	if (is_21b_int(rvoff)) {
285 		emit_bcc(cond, rd, rs, 8, ctx);
286 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
287 		return;
288 	}
289 
290 	/* 32b No need for an additional rvoff adjustment, since we
291 	 * get that from the auipc at PC', where PC = PC' + 4.
292 	 */
293 	upper = (rvoff + (1 << 11)) >> 12;
294 	lower = rvoff & 0xfff;
295 
296 	emit_bcc(cond, rd, rs, 12, ctx);
297 	emit(rv_auipc(RV_REG_T1, upper), ctx);
298 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
299 }
300 
301 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
302 {
303 	emit_slli(reg, reg, 32, ctx);
304 	emit_srli(reg, reg, 32, ctx);
305 }
306 
307 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
308 {
309 	int tc_ninsn, off, start_insn = ctx->ninsns;
310 	u8 tcc = rv_tail_call_reg(ctx);
311 
312 	/* a0: &ctx
313 	 * a1: &array
314 	 * a2: index
315 	 *
316 	 * if (index >= array->map.max_entries)
317 	 *	goto out;
318 	 */
319 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
320 		   ctx->offset[0];
321 	emit_zext_32(RV_REG_A2, ctx);
322 
323 	off = offsetof(struct bpf_array, map.max_entries);
324 	if (is_12b_check(off, insn))
325 		return -1;
326 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
327 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
328 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
329 
330 	/* if (--TCC < 0)
331 	 *     goto out;
332 	 */
333 	emit_addi(RV_REG_TCC, tcc, -1, ctx);
334 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
335 	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
336 
337 	/* prog = array->ptrs[index];
338 	 * if (!prog)
339 	 *     goto out;
340 	 */
341 	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
342 	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
343 	off = offsetof(struct bpf_array, ptrs);
344 	if (is_12b_check(off, insn))
345 		return -1;
346 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
347 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
348 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
349 
350 	/* goto *(prog->bpf_func + 4); */
351 	off = offsetof(struct bpf_prog, bpf_func);
352 	if (is_12b_check(off, insn))
353 		return -1;
354 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
355 	__build_epilogue(true, ctx);
356 	return 0;
357 }
358 
359 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
360 		      struct rv_jit_context *ctx)
361 {
362 	u8 code = insn->code;
363 
364 	switch (code) {
365 	case BPF_JMP | BPF_JA:
366 	case BPF_JMP | BPF_CALL:
367 	case BPF_JMP | BPF_EXIT:
368 	case BPF_JMP | BPF_TAIL_CALL:
369 		break;
370 	default:
371 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
372 	}
373 
374 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
375 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
376 	    code & BPF_LDX || code & BPF_STX)
377 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
378 }
379 
380 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
381 {
382 	emit_mv(RV_REG_T2, *rd, ctx);
383 	emit_zext_32(RV_REG_T2, ctx);
384 	emit_mv(RV_REG_T1, *rs, ctx);
385 	emit_zext_32(RV_REG_T1, ctx);
386 	*rd = RV_REG_T2;
387 	*rs = RV_REG_T1;
388 }
389 
390 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
391 {
392 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
393 	emit_addiw(RV_REG_T1, *rs, 0, ctx);
394 	*rd = RV_REG_T2;
395 	*rs = RV_REG_T1;
396 }
397 
398 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
399 {
400 	emit_mv(RV_REG_T2, *rd, ctx);
401 	emit_zext_32(RV_REG_T2, ctx);
402 	emit_zext_32(RV_REG_T1, ctx);
403 	*rd = RV_REG_T2;
404 }
405 
406 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
407 {
408 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
409 	*rd = RV_REG_T2;
410 }
411 
412 static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
413 			      struct rv_jit_context *ctx)
414 {
415 	s64 upper, lower;
416 
417 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
418 		emit(rv_jal(rd, rvoff >> 1), ctx);
419 		return 0;
420 	} else if (in_auipc_jalr_range(rvoff)) {
421 		upper = (rvoff + (1 << 11)) >> 12;
422 		lower = rvoff & 0xfff;
423 		emit(rv_auipc(RV_REG_T1, upper), ctx);
424 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
425 		return 0;
426 	}
427 
428 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
429 	return -ERANGE;
430 }
431 
432 static bool is_signed_bpf_cond(u8 cond)
433 {
434 	return cond == BPF_JSGT || cond == BPF_JSLT ||
435 		cond == BPF_JSGE || cond == BPF_JSLE;
436 }
437 
438 static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
439 {
440 	s64 off = 0;
441 	u64 ip;
442 	u8 rd;
443 	int ret;
444 
445 	if (addr && ctx->insns) {
446 		ip = (u64)(long)(ctx->insns + ctx->ninsns);
447 		off = addr - ip;
448 	}
449 
450 	ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
451 	if (ret)
452 		return ret;
453 	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
454 	emit_mv(rd, RV_REG_A0, ctx);
455 	return 0;
456 }
457 
458 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
459 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
460 
461 bool ex_handler_bpf(const struct exception_table_entry *ex,
462 		    struct pt_regs *regs)
463 {
464 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
465 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
466 
467 	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
468 	regs->epc = (unsigned long)&ex->fixup - offset;
469 
470 	return true;
471 }
472 
473 /* For accesses to BTF pointers, add an entry to the exception table */
474 static int add_exception_handler(const struct bpf_insn *insn,
475 				 struct rv_jit_context *ctx,
476 				 int dst_reg, int insn_len)
477 {
478 	struct exception_table_entry *ex;
479 	unsigned long pc;
480 	off_t offset;
481 
482 	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
483 		return 0;
484 
485 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
486 		return -EINVAL;
487 
488 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
489 		return -EINVAL;
490 
491 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
492 		return -EINVAL;
493 
494 	ex = &ctx->prog->aux->extable[ctx->nexentries];
495 	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
496 
497 	offset = pc - (long)&ex->insn;
498 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
499 		return -ERANGE;
500 	ex->insn = offset;
501 
502 	/*
503 	 * Since the extable follows the program, the fixup offset is always
504 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
505 	 * to keep things simple, and put the destination register in the upper
506 	 * bits. We don't need to worry about buildtime or runtime sort
507 	 * modifying the upper bits because the table is already sorted, and
508 	 * isn't part of the main exception table.
509 	 */
510 	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
511 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
512 		return -ERANGE;
513 
514 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
515 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
516 	ex->type = EX_TYPE_BPF;
517 
518 	ctx->nexentries++;
519 	return 0;
520 }
521 
522 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
523 		      bool extra_pass)
524 {
525 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
526 		    BPF_CLASS(insn->code) == BPF_JMP;
527 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
528 	struct bpf_prog_aux *aux = ctx->prog->aux;
529 	u8 rd = -1, rs = -1, code = insn->code;
530 	s16 off = insn->off;
531 	s32 imm = insn->imm;
532 
533 	init_regs(&rd, &rs, insn, ctx);
534 
535 	switch (code) {
536 	/* dst = src */
537 	case BPF_ALU | BPF_MOV | BPF_X:
538 	case BPF_ALU64 | BPF_MOV | BPF_X:
539 		if (imm == 1) {
540 			/* Special mov32 for zext */
541 			emit_zext_32(rd, ctx);
542 			break;
543 		}
544 		emit_mv(rd, rs, ctx);
545 		if (!is64 && !aux->verifier_zext)
546 			emit_zext_32(rd, ctx);
547 		break;
548 
549 	/* dst = dst OP src */
550 	case BPF_ALU | BPF_ADD | BPF_X:
551 	case BPF_ALU64 | BPF_ADD | BPF_X:
552 		emit_add(rd, rd, rs, ctx);
553 		if (!is64 && !aux->verifier_zext)
554 			emit_zext_32(rd, ctx);
555 		break;
556 	case BPF_ALU | BPF_SUB | BPF_X:
557 	case BPF_ALU64 | BPF_SUB | BPF_X:
558 		if (is64)
559 			emit_sub(rd, rd, rs, ctx);
560 		else
561 			emit_subw(rd, rd, rs, ctx);
562 
563 		if (!is64 && !aux->verifier_zext)
564 			emit_zext_32(rd, ctx);
565 		break;
566 	case BPF_ALU | BPF_AND | BPF_X:
567 	case BPF_ALU64 | BPF_AND | BPF_X:
568 		emit_and(rd, rd, rs, ctx);
569 		if (!is64 && !aux->verifier_zext)
570 			emit_zext_32(rd, ctx);
571 		break;
572 	case BPF_ALU | BPF_OR | BPF_X:
573 	case BPF_ALU64 | BPF_OR | BPF_X:
574 		emit_or(rd, rd, rs, ctx);
575 		if (!is64 && !aux->verifier_zext)
576 			emit_zext_32(rd, ctx);
577 		break;
578 	case BPF_ALU | BPF_XOR | BPF_X:
579 	case BPF_ALU64 | BPF_XOR | BPF_X:
580 		emit_xor(rd, rd, rs, ctx);
581 		if (!is64 && !aux->verifier_zext)
582 			emit_zext_32(rd, ctx);
583 		break;
584 	case BPF_ALU | BPF_MUL | BPF_X:
585 	case BPF_ALU64 | BPF_MUL | BPF_X:
586 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
587 		if (!is64 && !aux->verifier_zext)
588 			emit_zext_32(rd, ctx);
589 		break;
590 	case BPF_ALU | BPF_DIV | BPF_X:
591 	case BPF_ALU64 | BPF_DIV | BPF_X:
592 		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
593 		if (!is64 && !aux->verifier_zext)
594 			emit_zext_32(rd, ctx);
595 		break;
596 	case BPF_ALU | BPF_MOD | BPF_X:
597 	case BPF_ALU64 | BPF_MOD | BPF_X:
598 		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
599 		if (!is64 && !aux->verifier_zext)
600 			emit_zext_32(rd, ctx);
601 		break;
602 	case BPF_ALU | BPF_LSH | BPF_X:
603 	case BPF_ALU64 | BPF_LSH | BPF_X:
604 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
605 		if (!is64 && !aux->verifier_zext)
606 			emit_zext_32(rd, ctx);
607 		break;
608 	case BPF_ALU | BPF_RSH | BPF_X:
609 	case BPF_ALU64 | BPF_RSH | BPF_X:
610 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
611 		if (!is64 && !aux->verifier_zext)
612 			emit_zext_32(rd, ctx);
613 		break;
614 	case BPF_ALU | BPF_ARSH | BPF_X:
615 	case BPF_ALU64 | BPF_ARSH | BPF_X:
616 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
617 		if (!is64 && !aux->verifier_zext)
618 			emit_zext_32(rd, ctx);
619 		break;
620 
621 	/* dst = -dst */
622 	case BPF_ALU | BPF_NEG:
623 	case BPF_ALU64 | BPF_NEG:
624 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
625 		if (!is64 && !aux->verifier_zext)
626 			emit_zext_32(rd, ctx);
627 		break;
628 
629 	/* dst = BSWAP##imm(dst) */
630 	case BPF_ALU | BPF_END | BPF_FROM_LE:
631 		switch (imm) {
632 		case 16:
633 			emit_slli(rd, rd, 48, ctx);
634 			emit_srli(rd, rd, 48, ctx);
635 			break;
636 		case 32:
637 			if (!aux->verifier_zext)
638 				emit_zext_32(rd, ctx);
639 			break;
640 		case 64:
641 			/* Do nothing */
642 			break;
643 		}
644 		break;
645 
646 	case BPF_ALU | BPF_END | BPF_FROM_BE:
647 		emit_li(RV_REG_T2, 0, ctx);
648 
649 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
650 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
651 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
652 		emit_srli(rd, rd, 8, ctx);
653 		if (imm == 16)
654 			goto out_be;
655 
656 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
657 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
658 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
659 		emit_srli(rd, rd, 8, ctx);
660 
661 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
662 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
663 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
664 		emit_srli(rd, rd, 8, ctx);
665 		if (imm == 32)
666 			goto out_be;
667 
668 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
669 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
670 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
671 		emit_srli(rd, rd, 8, ctx);
672 
673 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
674 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
675 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
676 		emit_srli(rd, rd, 8, ctx);
677 
678 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
679 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
680 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
681 		emit_srli(rd, rd, 8, ctx);
682 
683 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
684 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
685 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
686 		emit_srli(rd, rd, 8, ctx);
687 out_be:
688 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
689 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
690 
691 		emit_mv(rd, RV_REG_T2, ctx);
692 		break;
693 
694 	/* dst = imm */
695 	case BPF_ALU | BPF_MOV | BPF_K:
696 	case BPF_ALU64 | BPF_MOV | BPF_K:
697 		emit_imm(rd, imm, ctx);
698 		if (!is64 && !aux->verifier_zext)
699 			emit_zext_32(rd, ctx);
700 		break;
701 
702 	/* dst = dst OP imm */
703 	case BPF_ALU | BPF_ADD | BPF_K:
704 	case BPF_ALU64 | BPF_ADD | BPF_K:
705 		if (is_12b_int(imm)) {
706 			emit_addi(rd, rd, imm, ctx);
707 		} else {
708 			emit_imm(RV_REG_T1, imm, ctx);
709 			emit_add(rd, rd, RV_REG_T1, ctx);
710 		}
711 		if (!is64 && !aux->verifier_zext)
712 			emit_zext_32(rd, ctx);
713 		break;
714 	case BPF_ALU | BPF_SUB | BPF_K:
715 	case BPF_ALU64 | BPF_SUB | BPF_K:
716 		if (is_12b_int(-imm)) {
717 			emit_addi(rd, rd, -imm, ctx);
718 		} else {
719 			emit_imm(RV_REG_T1, imm, ctx);
720 			emit_sub(rd, rd, RV_REG_T1, ctx);
721 		}
722 		if (!is64 && !aux->verifier_zext)
723 			emit_zext_32(rd, ctx);
724 		break;
725 	case BPF_ALU | BPF_AND | BPF_K:
726 	case BPF_ALU64 | BPF_AND | BPF_K:
727 		if (is_12b_int(imm)) {
728 			emit_andi(rd, rd, imm, ctx);
729 		} else {
730 			emit_imm(RV_REG_T1, imm, ctx);
731 			emit_and(rd, rd, RV_REG_T1, ctx);
732 		}
733 		if (!is64 && !aux->verifier_zext)
734 			emit_zext_32(rd, ctx);
735 		break;
736 	case BPF_ALU | BPF_OR | BPF_K:
737 	case BPF_ALU64 | BPF_OR | BPF_K:
738 		if (is_12b_int(imm)) {
739 			emit(rv_ori(rd, rd, imm), ctx);
740 		} else {
741 			emit_imm(RV_REG_T1, imm, ctx);
742 			emit_or(rd, rd, RV_REG_T1, ctx);
743 		}
744 		if (!is64 && !aux->verifier_zext)
745 			emit_zext_32(rd, ctx);
746 		break;
747 	case BPF_ALU | BPF_XOR | BPF_K:
748 	case BPF_ALU64 | BPF_XOR | BPF_K:
749 		if (is_12b_int(imm)) {
750 			emit(rv_xori(rd, rd, imm), ctx);
751 		} else {
752 			emit_imm(RV_REG_T1, imm, ctx);
753 			emit_xor(rd, rd, RV_REG_T1, ctx);
754 		}
755 		if (!is64 && !aux->verifier_zext)
756 			emit_zext_32(rd, ctx);
757 		break;
758 	case BPF_ALU | BPF_MUL | BPF_K:
759 	case BPF_ALU64 | BPF_MUL | BPF_K:
760 		emit_imm(RV_REG_T1, imm, ctx);
761 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
762 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
763 		if (!is64 && !aux->verifier_zext)
764 			emit_zext_32(rd, ctx);
765 		break;
766 	case BPF_ALU | BPF_DIV | BPF_K:
767 	case BPF_ALU64 | BPF_DIV | BPF_K:
768 		emit_imm(RV_REG_T1, imm, ctx);
769 		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
770 		     rv_divuw(rd, rd, RV_REG_T1), ctx);
771 		if (!is64 && !aux->verifier_zext)
772 			emit_zext_32(rd, ctx);
773 		break;
774 	case BPF_ALU | BPF_MOD | BPF_K:
775 	case BPF_ALU64 | BPF_MOD | BPF_K:
776 		emit_imm(RV_REG_T1, imm, ctx);
777 		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
778 		     rv_remuw(rd, rd, RV_REG_T1), ctx);
779 		if (!is64 && !aux->verifier_zext)
780 			emit_zext_32(rd, ctx);
781 		break;
782 	case BPF_ALU | BPF_LSH | BPF_K:
783 	case BPF_ALU64 | BPF_LSH | BPF_K:
784 		emit_slli(rd, rd, imm, ctx);
785 
786 		if (!is64 && !aux->verifier_zext)
787 			emit_zext_32(rd, ctx);
788 		break;
789 	case BPF_ALU | BPF_RSH | BPF_K:
790 	case BPF_ALU64 | BPF_RSH | BPF_K:
791 		if (is64)
792 			emit_srli(rd, rd, imm, ctx);
793 		else
794 			emit(rv_srliw(rd, rd, imm), ctx);
795 
796 		if (!is64 && !aux->verifier_zext)
797 			emit_zext_32(rd, ctx);
798 		break;
799 	case BPF_ALU | BPF_ARSH | BPF_K:
800 	case BPF_ALU64 | BPF_ARSH | BPF_K:
801 		if (is64)
802 			emit_srai(rd, rd, imm, ctx);
803 		else
804 			emit(rv_sraiw(rd, rd, imm), ctx);
805 
806 		if (!is64 && !aux->verifier_zext)
807 			emit_zext_32(rd, ctx);
808 		break;
809 
810 	/* JUMP off */
811 	case BPF_JMP | BPF_JA:
812 		rvoff = rv_offset(i, off, ctx);
813 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
814 		if (ret)
815 			return ret;
816 		break;
817 
818 	/* IF (dst COND src) JUMP off */
819 	case BPF_JMP | BPF_JEQ | BPF_X:
820 	case BPF_JMP32 | BPF_JEQ | BPF_X:
821 	case BPF_JMP | BPF_JGT | BPF_X:
822 	case BPF_JMP32 | BPF_JGT | BPF_X:
823 	case BPF_JMP | BPF_JLT | BPF_X:
824 	case BPF_JMP32 | BPF_JLT | BPF_X:
825 	case BPF_JMP | BPF_JGE | BPF_X:
826 	case BPF_JMP32 | BPF_JGE | BPF_X:
827 	case BPF_JMP | BPF_JLE | BPF_X:
828 	case BPF_JMP32 | BPF_JLE | BPF_X:
829 	case BPF_JMP | BPF_JNE | BPF_X:
830 	case BPF_JMP32 | BPF_JNE | BPF_X:
831 	case BPF_JMP | BPF_JSGT | BPF_X:
832 	case BPF_JMP32 | BPF_JSGT | BPF_X:
833 	case BPF_JMP | BPF_JSLT | BPF_X:
834 	case BPF_JMP32 | BPF_JSLT | BPF_X:
835 	case BPF_JMP | BPF_JSGE | BPF_X:
836 	case BPF_JMP32 | BPF_JSGE | BPF_X:
837 	case BPF_JMP | BPF_JSLE | BPF_X:
838 	case BPF_JMP32 | BPF_JSLE | BPF_X:
839 	case BPF_JMP | BPF_JSET | BPF_X:
840 	case BPF_JMP32 | BPF_JSET | BPF_X:
841 		rvoff = rv_offset(i, off, ctx);
842 		if (!is64) {
843 			s = ctx->ninsns;
844 			if (is_signed_bpf_cond(BPF_OP(code)))
845 				emit_sext_32_rd_rs(&rd, &rs, ctx);
846 			else
847 				emit_zext_32_rd_rs(&rd, &rs, ctx);
848 			e = ctx->ninsns;
849 
850 			/* Adjust for extra insns */
851 			rvoff -= ninsns_rvoff(e - s);
852 		}
853 
854 		if (BPF_OP(code) == BPF_JSET) {
855 			/* Adjust for and */
856 			rvoff -= 4;
857 			emit_and(RV_REG_T1, rd, rs, ctx);
858 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
859 				    ctx);
860 		} else {
861 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
862 		}
863 		break;
864 
865 	/* IF (dst COND imm) JUMP off */
866 	case BPF_JMP | BPF_JEQ | BPF_K:
867 	case BPF_JMP32 | BPF_JEQ | BPF_K:
868 	case BPF_JMP | BPF_JGT | BPF_K:
869 	case BPF_JMP32 | BPF_JGT | BPF_K:
870 	case BPF_JMP | BPF_JLT | BPF_K:
871 	case BPF_JMP32 | BPF_JLT | BPF_K:
872 	case BPF_JMP | BPF_JGE | BPF_K:
873 	case BPF_JMP32 | BPF_JGE | BPF_K:
874 	case BPF_JMP | BPF_JLE | BPF_K:
875 	case BPF_JMP32 | BPF_JLE | BPF_K:
876 	case BPF_JMP | BPF_JNE | BPF_K:
877 	case BPF_JMP32 | BPF_JNE | BPF_K:
878 	case BPF_JMP | BPF_JSGT | BPF_K:
879 	case BPF_JMP32 | BPF_JSGT | BPF_K:
880 	case BPF_JMP | BPF_JSLT | BPF_K:
881 	case BPF_JMP32 | BPF_JSLT | BPF_K:
882 	case BPF_JMP | BPF_JSGE | BPF_K:
883 	case BPF_JMP32 | BPF_JSGE | BPF_K:
884 	case BPF_JMP | BPF_JSLE | BPF_K:
885 	case BPF_JMP32 | BPF_JSLE | BPF_K:
886 		rvoff = rv_offset(i, off, ctx);
887 		s = ctx->ninsns;
888 		if (imm) {
889 			emit_imm(RV_REG_T1, imm, ctx);
890 			rs = RV_REG_T1;
891 		} else {
892 			/* If imm is 0, simply use zero register. */
893 			rs = RV_REG_ZERO;
894 		}
895 		if (!is64) {
896 			if (is_signed_bpf_cond(BPF_OP(code)))
897 				emit_sext_32_rd(&rd, ctx);
898 			else
899 				emit_zext_32_rd_t1(&rd, ctx);
900 		}
901 		e = ctx->ninsns;
902 
903 		/* Adjust for extra insns */
904 		rvoff -= ninsns_rvoff(e - s);
905 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
906 		break;
907 
908 	case BPF_JMP | BPF_JSET | BPF_K:
909 	case BPF_JMP32 | BPF_JSET | BPF_K:
910 		rvoff = rv_offset(i, off, ctx);
911 		s = ctx->ninsns;
912 		if (is_12b_int(imm)) {
913 			emit_andi(RV_REG_T1, rd, imm, ctx);
914 		} else {
915 			emit_imm(RV_REG_T1, imm, ctx);
916 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
917 		}
918 		/* For jset32, we should clear the upper 32 bits of t1, but
919 		 * sign-extension is sufficient here and saves one instruction,
920 		 * as t1 is used only in comparison against zero.
921 		 */
922 		if (!is64 && imm < 0)
923 			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
924 		e = ctx->ninsns;
925 		rvoff -= ninsns_rvoff(e - s);
926 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
927 		break;
928 
929 	/* function call */
930 	case BPF_JMP | BPF_CALL:
931 	{
932 		bool fixed;
933 		u64 addr;
934 
935 		mark_call(ctx);
936 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
937 					    &fixed);
938 		if (ret < 0)
939 			return ret;
940 		ret = emit_call(fixed, addr, ctx);
941 		if (ret)
942 			return ret;
943 		break;
944 	}
945 	/* tail call */
946 	case BPF_JMP | BPF_TAIL_CALL:
947 		if (emit_bpf_tail_call(i, ctx))
948 			return -1;
949 		break;
950 
951 	/* function return */
952 	case BPF_JMP | BPF_EXIT:
953 		if (i == ctx->prog->len - 1)
954 			break;
955 
956 		rvoff = epilogue_offset(ctx);
957 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
958 		if (ret)
959 			return ret;
960 		break;
961 
962 	/* dst = imm64 */
963 	case BPF_LD | BPF_IMM | BPF_DW:
964 	{
965 		struct bpf_insn insn1 = insn[1];
966 		u64 imm64;
967 
968 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
969 		emit_imm(rd, imm64, ctx);
970 		return 1;
971 	}
972 
973 	/* LDX: dst = *(size *)(src + off) */
974 	case BPF_LDX | BPF_MEM | BPF_B:
975 	case BPF_LDX | BPF_MEM | BPF_H:
976 	case BPF_LDX | BPF_MEM | BPF_W:
977 	case BPF_LDX | BPF_MEM | BPF_DW:
978 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
979 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
980 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
981 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
982 	{
983 		int insn_len, insns_start;
984 
985 		switch (BPF_SIZE(code)) {
986 		case BPF_B:
987 			if (is_12b_int(off)) {
988 				insns_start = ctx->ninsns;
989 				emit(rv_lbu(rd, off, rs), ctx);
990 				insn_len = ctx->ninsns - insns_start;
991 				break;
992 			}
993 
994 			emit_imm(RV_REG_T1, off, ctx);
995 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
996 			insns_start = ctx->ninsns;
997 			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
998 			insn_len = ctx->ninsns - insns_start;
999 			if (insn_is_zext(&insn[1]))
1000 				return 1;
1001 			break;
1002 		case BPF_H:
1003 			if (is_12b_int(off)) {
1004 				insns_start = ctx->ninsns;
1005 				emit(rv_lhu(rd, off, rs), ctx);
1006 				insn_len = ctx->ninsns - insns_start;
1007 				break;
1008 			}
1009 
1010 			emit_imm(RV_REG_T1, off, ctx);
1011 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1012 			insns_start = ctx->ninsns;
1013 			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1014 			insn_len = ctx->ninsns - insns_start;
1015 			if (insn_is_zext(&insn[1]))
1016 				return 1;
1017 			break;
1018 		case BPF_W:
1019 			if (is_12b_int(off)) {
1020 				insns_start = ctx->ninsns;
1021 				emit(rv_lwu(rd, off, rs), ctx);
1022 				insn_len = ctx->ninsns - insns_start;
1023 				break;
1024 			}
1025 
1026 			emit_imm(RV_REG_T1, off, ctx);
1027 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1028 			insns_start = ctx->ninsns;
1029 			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1030 			insn_len = ctx->ninsns - insns_start;
1031 			if (insn_is_zext(&insn[1]))
1032 				return 1;
1033 			break;
1034 		case BPF_DW:
1035 			if (is_12b_int(off)) {
1036 				insns_start = ctx->ninsns;
1037 				emit_ld(rd, off, rs, ctx);
1038 				insn_len = ctx->ninsns - insns_start;
1039 				break;
1040 			}
1041 
1042 			emit_imm(RV_REG_T1, off, ctx);
1043 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1044 			insns_start = ctx->ninsns;
1045 			emit_ld(rd, 0, RV_REG_T1, ctx);
1046 			insn_len = ctx->ninsns - insns_start;
1047 			break;
1048 		}
1049 
1050 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1051 		if (ret)
1052 			return ret;
1053 		break;
1054 	}
1055 	/* speculation barrier */
1056 	case BPF_ST | BPF_NOSPEC:
1057 		break;
1058 
1059 	/* ST: *(size *)(dst + off) = imm */
1060 	case BPF_ST | BPF_MEM | BPF_B:
1061 		emit_imm(RV_REG_T1, imm, ctx);
1062 		if (is_12b_int(off)) {
1063 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1064 			break;
1065 		}
1066 
1067 		emit_imm(RV_REG_T2, off, ctx);
1068 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1069 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1070 		break;
1071 
1072 	case BPF_ST | BPF_MEM | BPF_H:
1073 		emit_imm(RV_REG_T1, imm, ctx);
1074 		if (is_12b_int(off)) {
1075 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1076 			break;
1077 		}
1078 
1079 		emit_imm(RV_REG_T2, off, ctx);
1080 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1081 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1082 		break;
1083 	case BPF_ST | BPF_MEM | BPF_W:
1084 		emit_imm(RV_REG_T1, imm, ctx);
1085 		if (is_12b_int(off)) {
1086 			emit_sw(rd, off, RV_REG_T1, ctx);
1087 			break;
1088 		}
1089 
1090 		emit_imm(RV_REG_T2, off, ctx);
1091 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1092 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1093 		break;
1094 	case BPF_ST | BPF_MEM | BPF_DW:
1095 		emit_imm(RV_REG_T1, imm, ctx);
1096 		if (is_12b_int(off)) {
1097 			emit_sd(rd, off, RV_REG_T1, ctx);
1098 			break;
1099 		}
1100 
1101 		emit_imm(RV_REG_T2, off, ctx);
1102 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1103 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1104 		break;
1105 
1106 	/* STX: *(size *)(dst + off) = src */
1107 	case BPF_STX | BPF_MEM | BPF_B:
1108 		if (is_12b_int(off)) {
1109 			emit(rv_sb(rd, off, rs), ctx);
1110 			break;
1111 		}
1112 
1113 		emit_imm(RV_REG_T1, off, ctx);
1114 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1115 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1116 		break;
1117 	case BPF_STX | BPF_MEM | BPF_H:
1118 		if (is_12b_int(off)) {
1119 			emit(rv_sh(rd, off, rs), ctx);
1120 			break;
1121 		}
1122 
1123 		emit_imm(RV_REG_T1, off, ctx);
1124 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1125 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1126 		break;
1127 	case BPF_STX | BPF_MEM | BPF_W:
1128 		if (is_12b_int(off)) {
1129 			emit_sw(rd, off, rs, ctx);
1130 			break;
1131 		}
1132 
1133 		emit_imm(RV_REG_T1, off, ctx);
1134 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1135 		emit_sw(RV_REG_T1, 0, rs, ctx);
1136 		break;
1137 	case BPF_STX | BPF_MEM | BPF_DW:
1138 		if (is_12b_int(off)) {
1139 			emit_sd(rd, off, rs, ctx);
1140 			break;
1141 		}
1142 
1143 		emit_imm(RV_REG_T1, off, ctx);
1144 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1145 		emit_sd(RV_REG_T1, 0, rs, ctx);
1146 		break;
1147 	case BPF_STX | BPF_ATOMIC | BPF_W:
1148 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1149 		if (insn->imm != BPF_ADD) {
1150 			pr_err("bpf-jit: not supported: atomic operation %02x ***\n",
1151 			       insn->imm);
1152 			return -EINVAL;
1153 		}
1154 
1155 		/* atomic_add: lock *(u32 *)(dst + off) += src
1156 		 * atomic_add: lock *(u64 *)(dst + off) += src
1157 		 */
1158 
1159 		if (off) {
1160 			if (is_12b_int(off)) {
1161 				emit_addi(RV_REG_T1, rd, off, ctx);
1162 			} else {
1163 				emit_imm(RV_REG_T1, off, ctx);
1164 				emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1165 			}
1166 
1167 			rd = RV_REG_T1;
1168 		}
1169 
1170 		emit(BPF_SIZE(code) == BPF_W ?
1171 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0) :
1172 		     rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0), ctx);
1173 		break;
1174 	default:
1175 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1176 		return -EINVAL;
1177 	}
1178 
1179 	return 0;
1180 }
1181 
1182 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1183 {
1184 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
1185 
1186 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1187 	if (bpf_stack_adjust)
1188 		mark_fp(ctx);
1189 
1190 	if (seen_reg(RV_REG_RA, ctx))
1191 		stack_adjust += 8;
1192 	stack_adjust += 8; /* RV_REG_FP */
1193 	if (seen_reg(RV_REG_S1, ctx))
1194 		stack_adjust += 8;
1195 	if (seen_reg(RV_REG_S2, ctx))
1196 		stack_adjust += 8;
1197 	if (seen_reg(RV_REG_S3, ctx))
1198 		stack_adjust += 8;
1199 	if (seen_reg(RV_REG_S4, ctx))
1200 		stack_adjust += 8;
1201 	if (seen_reg(RV_REG_S5, ctx))
1202 		stack_adjust += 8;
1203 	if (seen_reg(RV_REG_S6, ctx))
1204 		stack_adjust += 8;
1205 
1206 	stack_adjust = round_up(stack_adjust, 16);
1207 	stack_adjust += bpf_stack_adjust;
1208 
1209 	store_offset = stack_adjust - 8;
1210 
1211 	/* First instruction is always setting the tail-call-counter
1212 	 * (TCC) register. This instruction is skipped for tail calls.
1213 	 * Force using a 4-byte (non-compressed) instruction.
1214 	 */
1215 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1216 
1217 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1218 
1219 	if (seen_reg(RV_REG_RA, ctx)) {
1220 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1221 		store_offset -= 8;
1222 	}
1223 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1224 	store_offset -= 8;
1225 	if (seen_reg(RV_REG_S1, ctx)) {
1226 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1227 		store_offset -= 8;
1228 	}
1229 	if (seen_reg(RV_REG_S2, ctx)) {
1230 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1231 		store_offset -= 8;
1232 	}
1233 	if (seen_reg(RV_REG_S3, ctx)) {
1234 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1235 		store_offset -= 8;
1236 	}
1237 	if (seen_reg(RV_REG_S4, ctx)) {
1238 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1239 		store_offset -= 8;
1240 	}
1241 	if (seen_reg(RV_REG_S5, ctx)) {
1242 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1243 		store_offset -= 8;
1244 	}
1245 	if (seen_reg(RV_REG_S6, ctx)) {
1246 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1247 		store_offset -= 8;
1248 	}
1249 
1250 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1251 
1252 	if (bpf_stack_adjust)
1253 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1254 
1255 	/* Program contains calls and tail calls, so RV_REG_TCC need
1256 	 * to be saved across calls.
1257 	 */
1258 	if (seen_tail_call(ctx) && seen_call(ctx))
1259 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1260 
1261 	ctx->stack_size = stack_adjust;
1262 }
1263 
1264 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1265 {
1266 	__build_epilogue(false, ctx);
1267 }
1268