xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp64.c (revision 63705da3)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include "bpf_jit.h"
12 
13 #define RV_REG_TCC RV_REG_A6
14 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
15 
16 static const int regmap[] = {
17 	[BPF_REG_0] =	RV_REG_A5,
18 	[BPF_REG_1] =	RV_REG_A0,
19 	[BPF_REG_2] =	RV_REG_A1,
20 	[BPF_REG_3] =	RV_REG_A2,
21 	[BPF_REG_4] =	RV_REG_A3,
22 	[BPF_REG_5] =	RV_REG_A4,
23 	[BPF_REG_6] =	RV_REG_S1,
24 	[BPF_REG_7] =	RV_REG_S2,
25 	[BPF_REG_8] =	RV_REG_S3,
26 	[BPF_REG_9] =	RV_REG_S4,
27 	[BPF_REG_FP] =	RV_REG_S5,
28 	[BPF_REG_AX] =	RV_REG_T0,
29 };
30 
31 static const int pt_regmap[] = {
32 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
33 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
34 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
35 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
36 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
37 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
38 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
39 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
40 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
41 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
42 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
43 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
44 };
45 
46 enum {
47 	RV_CTX_F_SEEN_TAIL_CALL =	0,
48 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
49 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
50 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
51 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
52 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
53 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
54 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
55 };
56 
57 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
58 {
59 	u8 reg = regmap[bpf_reg];
60 
61 	switch (reg) {
62 	case RV_CTX_F_SEEN_S1:
63 	case RV_CTX_F_SEEN_S2:
64 	case RV_CTX_F_SEEN_S3:
65 	case RV_CTX_F_SEEN_S4:
66 	case RV_CTX_F_SEEN_S5:
67 	case RV_CTX_F_SEEN_S6:
68 		__set_bit(reg, &ctx->flags);
69 	}
70 	return reg;
71 };
72 
73 static bool seen_reg(int reg, struct rv_jit_context *ctx)
74 {
75 	switch (reg) {
76 	case RV_CTX_F_SEEN_CALL:
77 	case RV_CTX_F_SEEN_S1:
78 	case RV_CTX_F_SEEN_S2:
79 	case RV_CTX_F_SEEN_S3:
80 	case RV_CTX_F_SEEN_S4:
81 	case RV_CTX_F_SEEN_S5:
82 	case RV_CTX_F_SEEN_S6:
83 		return test_bit(reg, &ctx->flags);
84 	}
85 	return false;
86 }
87 
88 static void mark_fp(struct rv_jit_context *ctx)
89 {
90 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
91 }
92 
93 static void mark_call(struct rv_jit_context *ctx)
94 {
95 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
96 }
97 
98 static bool seen_call(struct rv_jit_context *ctx)
99 {
100 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102 
103 static void mark_tail_call(struct rv_jit_context *ctx)
104 {
105 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
106 }
107 
108 static bool seen_tail_call(struct rv_jit_context *ctx)
109 {
110 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112 
113 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
114 {
115 	mark_tail_call(ctx);
116 
117 	if (seen_call(ctx)) {
118 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
119 		return RV_REG_S6;
120 	}
121 	return RV_REG_A6;
122 }
123 
124 static bool is_32b_int(s64 val)
125 {
126 	return -(1L << 31) <= val && val < (1L << 31);
127 }
128 
129 static bool in_auipc_jalr_range(s64 val)
130 {
131 	/*
132 	 * auipc+jalr can reach any signed PC-relative offset in the range
133 	 * [-2^31 - 2^11, 2^31 - 2^11).
134 	 */
135 	return (-(1L << 31) - (1L << 11)) <= val &&
136 		val < ((1L << 31) - (1L << 11));
137 }
138 
139 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
140 {
141 	/* Note that the immediate from the add is sign-extended,
142 	 * which means that we need to compensate this by adding 2^12,
143 	 * when the 12th bit is set. A simpler way of doing this, and
144 	 * getting rid of the check, is to just add 2**11 before the
145 	 * shift. The "Loading a 32-Bit constant" example from the
146 	 * "Computer Organization and Design, RISC-V edition" book by
147 	 * Patterson/Hennessy highlights this fact.
148 	 *
149 	 * This also means that we need to process LSB to MSB.
150 	 */
151 	s64 upper = (val + (1 << 11)) >> 12;
152 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
153 	 * and addi are signed and RVC checks will perform signed comparisons.
154 	 */
155 	s64 lower = ((val & 0xfff) << 52) >> 52;
156 	int shift;
157 
158 	if (is_32b_int(val)) {
159 		if (upper)
160 			emit_lui(rd, upper, ctx);
161 
162 		if (!upper) {
163 			emit_li(rd, lower, ctx);
164 			return;
165 		}
166 
167 		emit_addiw(rd, rd, lower, ctx);
168 		return;
169 	}
170 
171 	shift = __ffs(upper);
172 	upper >>= shift;
173 	shift += 12;
174 
175 	emit_imm(rd, upper, ctx);
176 
177 	emit_slli(rd, rd, shift, ctx);
178 	if (lower)
179 		emit_addi(rd, rd, lower, ctx);
180 }
181 
182 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
183 {
184 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
185 
186 	if (seen_reg(RV_REG_RA, ctx)) {
187 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
188 		store_offset -= 8;
189 	}
190 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
191 	store_offset -= 8;
192 	if (seen_reg(RV_REG_S1, ctx)) {
193 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
194 		store_offset -= 8;
195 	}
196 	if (seen_reg(RV_REG_S2, ctx)) {
197 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
198 		store_offset -= 8;
199 	}
200 	if (seen_reg(RV_REG_S3, ctx)) {
201 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
202 		store_offset -= 8;
203 	}
204 	if (seen_reg(RV_REG_S4, ctx)) {
205 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
206 		store_offset -= 8;
207 	}
208 	if (seen_reg(RV_REG_S5, ctx)) {
209 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
210 		store_offset -= 8;
211 	}
212 	if (seen_reg(RV_REG_S6, ctx)) {
213 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
214 		store_offset -= 8;
215 	}
216 
217 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
218 	/* Set return value. */
219 	if (!is_tail_call)
220 		emit_mv(RV_REG_A0, RV_REG_A5, ctx);
221 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
222 		  is_tail_call ? 4 : 0, /* skip TCC init */
223 		  ctx);
224 }
225 
226 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
227 		     struct rv_jit_context *ctx)
228 {
229 	switch (cond) {
230 	case BPF_JEQ:
231 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
232 		return;
233 	case BPF_JGT:
234 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
235 		return;
236 	case BPF_JLT:
237 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
238 		return;
239 	case BPF_JGE:
240 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
241 		return;
242 	case BPF_JLE:
243 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
244 		return;
245 	case BPF_JNE:
246 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
247 		return;
248 	case BPF_JSGT:
249 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
250 		return;
251 	case BPF_JSLT:
252 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
253 		return;
254 	case BPF_JSGE:
255 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
256 		return;
257 	case BPF_JSLE:
258 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
259 	}
260 }
261 
262 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
263 			struct rv_jit_context *ctx)
264 {
265 	s64 upper, lower;
266 
267 	if (is_13b_int(rvoff)) {
268 		emit_bcc(cond, rd, rs, rvoff, ctx);
269 		return;
270 	}
271 
272 	/* Adjust for jal */
273 	rvoff -= 4;
274 
275 	/* Transform, e.g.:
276 	 *   bne rd,rs,foo
277 	 * to
278 	 *   beq rd,rs,<.L1>
279 	 *   (auipc foo)
280 	 *   jal(r) foo
281 	 * .L1
282 	 */
283 	cond = invert_bpf_cond(cond);
284 	if (is_21b_int(rvoff)) {
285 		emit_bcc(cond, rd, rs, 8, ctx);
286 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
287 		return;
288 	}
289 
290 	/* 32b No need for an additional rvoff adjustment, since we
291 	 * get that from the auipc at PC', where PC = PC' + 4.
292 	 */
293 	upper = (rvoff + (1 << 11)) >> 12;
294 	lower = rvoff & 0xfff;
295 
296 	emit_bcc(cond, rd, rs, 12, ctx);
297 	emit(rv_auipc(RV_REG_T1, upper), ctx);
298 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
299 }
300 
301 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
302 {
303 	emit_slli(reg, reg, 32, ctx);
304 	emit_srli(reg, reg, 32, ctx);
305 }
306 
307 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
308 {
309 	int tc_ninsn, off, start_insn = ctx->ninsns;
310 	u8 tcc = rv_tail_call_reg(ctx);
311 
312 	/* a0: &ctx
313 	 * a1: &array
314 	 * a2: index
315 	 *
316 	 * if (index >= array->map.max_entries)
317 	 *	goto out;
318 	 */
319 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
320 		   ctx->offset[0];
321 	emit_zext_32(RV_REG_A2, ctx);
322 
323 	off = offsetof(struct bpf_array, map.max_entries);
324 	if (is_12b_check(off, insn))
325 		return -1;
326 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
327 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
328 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
329 
330 	/* if (TCC-- < 0)
331 	 *     goto out;
332 	 */
333 	emit_addi(RV_REG_T1, tcc, -1, ctx);
334 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
335 	emit_branch(BPF_JSLT, tcc, RV_REG_ZERO, off, ctx);
336 
337 	/* prog = array->ptrs[index];
338 	 * if (!prog)
339 	 *     goto out;
340 	 */
341 	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
342 	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
343 	off = offsetof(struct bpf_array, ptrs);
344 	if (is_12b_check(off, insn))
345 		return -1;
346 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
347 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
348 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
349 
350 	/* goto *(prog->bpf_func + 4); */
351 	off = offsetof(struct bpf_prog, bpf_func);
352 	if (is_12b_check(off, insn))
353 		return -1;
354 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
355 	emit_mv(RV_REG_TCC, RV_REG_T1, ctx);
356 	__build_epilogue(true, ctx);
357 	return 0;
358 }
359 
360 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
361 		      struct rv_jit_context *ctx)
362 {
363 	u8 code = insn->code;
364 
365 	switch (code) {
366 	case BPF_JMP | BPF_JA:
367 	case BPF_JMP | BPF_CALL:
368 	case BPF_JMP | BPF_EXIT:
369 	case BPF_JMP | BPF_TAIL_CALL:
370 		break;
371 	default:
372 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
373 	}
374 
375 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
376 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
377 	    code & BPF_LDX || code & BPF_STX)
378 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
379 }
380 
381 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
382 {
383 	emit_mv(RV_REG_T2, *rd, ctx);
384 	emit_zext_32(RV_REG_T2, ctx);
385 	emit_mv(RV_REG_T1, *rs, ctx);
386 	emit_zext_32(RV_REG_T1, ctx);
387 	*rd = RV_REG_T2;
388 	*rs = RV_REG_T1;
389 }
390 
391 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
392 {
393 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
394 	emit_addiw(RV_REG_T1, *rs, 0, ctx);
395 	*rd = RV_REG_T2;
396 	*rs = RV_REG_T1;
397 }
398 
399 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
400 {
401 	emit_mv(RV_REG_T2, *rd, ctx);
402 	emit_zext_32(RV_REG_T2, ctx);
403 	emit_zext_32(RV_REG_T1, ctx);
404 	*rd = RV_REG_T2;
405 }
406 
407 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
408 {
409 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
410 	*rd = RV_REG_T2;
411 }
412 
413 static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
414 			      struct rv_jit_context *ctx)
415 {
416 	s64 upper, lower;
417 
418 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
419 		emit(rv_jal(rd, rvoff >> 1), ctx);
420 		return 0;
421 	} else if (in_auipc_jalr_range(rvoff)) {
422 		upper = (rvoff + (1 << 11)) >> 12;
423 		lower = rvoff & 0xfff;
424 		emit(rv_auipc(RV_REG_T1, upper), ctx);
425 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
426 		return 0;
427 	}
428 
429 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
430 	return -ERANGE;
431 }
432 
433 static bool is_signed_bpf_cond(u8 cond)
434 {
435 	return cond == BPF_JSGT || cond == BPF_JSLT ||
436 		cond == BPF_JSGE || cond == BPF_JSLE;
437 }
438 
439 static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
440 {
441 	s64 off = 0;
442 	u64 ip;
443 	u8 rd;
444 	int ret;
445 
446 	if (addr && ctx->insns) {
447 		ip = (u64)(long)(ctx->insns + ctx->ninsns);
448 		off = addr - ip;
449 	}
450 
451 	ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
452 	if (ret)
453 		return ret;
454 	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
455 	emit_mv(rd, RV_REG_A0, ctx);
456 	return 0;
457 }
458 
459 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
460 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
461 
462 int rv_bpf_fixup_exception(const struct exception_table_entry *ex,
463 				struct pt_regs *regs);
464 int rv_bpf_fixup_exception(const struct exception_table_entry *ex,
465 				struct pt_regs *regs)
466 {
467 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
468 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
469 
470 	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
471 	regs->epc = (unsigned long)&ex->fixup - offset;
472 
473 	return 1;
474 }
475 
476 /* For accesses to BTF pointers, add an entry to the exception table */
477 static int add_exception_handler(const struct bpf_insn *insn,
478 				 struct rv_jit_context *ctx,
479 				 int dst_reg, int insn_len)
480 {
481 	struct exception_table_entry *ex;
482 	unsigned long pc;
483 	off_t offset;
484 
485 	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
486 		return 0;
487 
488 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
489 		return -EINVAL;
490 
491 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
492 		return -EINVAL;
493 
494 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
495 		return -EINVAL;
496 
497 	ex = &ctx->prog->aux->extable[ctx->nexentries];
498 	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
499 
500 	offset = pc - (long)&ex->insn;
501 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
502 		return -ERANGE;
503 	ex->insn = pc;
504 
505 	/*
506 	 * Since the extable follows the program, the fixup offset is always
507 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
508 	 * to keep things simple, and put the destination register in the upper
509 	 * bits. We don't need to worry about buildtime or runtime sort
510 	 * modifying the upper bits because the table is already sorted, and
511 	 * isn't part of the main exception table.
512 	 */
513 	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
514 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
515 		return -ERANGE;
516 
517 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
518 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
519 
520 	ctx->nexentries++;
521 	return 0;
522 }
523 
524 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
525 		      bool extra_pass)
526 {
527 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
528 		    BPF_CLASS(insn->code) == BPF_JMP;
529 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
530 	struct bpf_prog_aux *aux = ctx->prog->aux;
531 	u8 rd = -1, rs = -1, code = insn->code;
532 	s16 off = insn->off;
533 	s32 imm = insn->imm;
534 
535 	init_regs(&rd, &rs, insn, ctx);
536 
537 	switch (code) {
538 	/* dst = src */
539 	case BPF_ALU | BPF_MOV | BPF_X:
540 	case BPF_ALU64 | BPF_MOV | BPF_X:
541 		if (imm == 1) {
542 			/* Special mov32 for zext */
543 			emit_zext_32(rd, ctx);
544 			break;
545 		}
546 		emit_mv(rd, rs, ctx);
547 		if (!is64 && !aux->verifier_zext)
548 			emit_zext_32(rd, ctx);
549 		break;
550 
551 	/* dst = dst OP src */
552 	case BPF_ALU | BPF_ADD | BPF_X:
553 	case BPF_ALU64 | BPF_ADD | BPF_X:
554 		emit_add(rd, rd, rs, ctx);
555 		if (!is64 && !aux->verifier_zext)
556 			emit_zext_32(rd, ctx);
557 		break;
558 	case BPF_ALU | BPF_SUB | BPF_X:
559 	case BPF_ALU64 | BPF_SUB | BPF_X:
560 		if (is64)
561 			emit_sub(rd, rd, rs, ctx);
562 		else
563 			emit_subw(rd, rd, rs, ctx);
564 
565 		if (!is64 && !aux->verifier_zext)
566 			emit_zext_32(rd, ctx);
567 		break;
568 	case BPF_ALU | BPF_AND | BPF_X:
569 	case BPF_ALU64 | BPF_AND | BPF_X:
570 		emit_and(rd, rd, rs, ctx);
571 		if (!is64 && !aux->verifier_zext)
572 			emit_zext_32(rd, ctx);
573 		break;
574 	case BPF_ALU | BPF_OR | BPF_X:
575 	case BPF_ALU64 | BPF_OR | BPF_X:
576 		emit_or(rd, rd, rs, ctx);
577 		if (!is64 && !aux->verifier_zext)
578 			emit_zext_32(rd, ctx);
579 		break;
580 	case BPF_ALU | BPF_XOR | BPF_X:
581 	case BPF_ALU64 | BPF_XOR | BPF_X:
582 		emit_xor(rd, rd, rs, ctx);
583 		if (!is64 && !aux->verifier_zext)
584 			emit_zext_32(rd, ctx);
585 		break;
586 	case BPF_ALU | BPF_MUL | BPF_X:
587 	case BPF_ALU64 | BPF_MUL | BPF_X:
588 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
589 		if (!is64 && !aux->verifier_zext)
590 			emit_zext_32(rd, ctx);
591 		break;
592 	case BPF_ALU | BPF_DIV | BPF_X:
593 	case BPF_ALU64 | BPF_DIV | BPF_X:
594 		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
595 		if (!is64 && !aux->verifier_zext)
596 			emit_zext_32(rd, ctx);
597 		break;
598 	case BPF_ALU | BPF_MOD | BPF_X:
599 	case BPF_ALU64 | BPF_MOD | BPF_X:
600 		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
601 		if (!is64 && !aux->verifier_zext)
602 			emit_zext_32(rd, ctx);
603 		break;
604 	case BPF_ALU | BPF_LSH | BPF_X:
605 	case BPF_ALU64 | BPF_LSH | BPF_X:
606 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
607 		if (!is64 && !aux->verifier_zext)
608 			emit_zext_32(rd, ctx);
609 		break;
610 	case BPF_ALU | BPF_RSH | BPF_X:
611 	case BPF_ALU64 | BPF_RSH | BPF_X:
612 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
613 		if (!is64 && !aux->verifier_zext)
614 			emit_zext_32(rd, ctx);
615 		break;
616 	case BPF_ALU | BPF_ARSH | BPF_X:
617 	case BPF_ALU64 | BPF_ARSH | BPF_X:
618 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
619 		if (!is64 && !aux->verifier_zext)
620 			emit_zext_32(rd, ctx);
621 		break;
622 
623 	/* dst = -dst */
624 	case BPF_ALU | BPF_NEG:
625 	case BPF_ALU64 | BPF_NEG:
626 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
627 		if (!is64 && !aux->verifier_zext)
628 			emit_zext_32(rd, ctx);
629 		break;
630 
631 	/* dst = BSWAP##imm(dst) */
632 	case BPF_ALU | BPF_END | BPF_FROM_LE:
633 		switch (imm) {
634 		case 16:
635 			emit_slli(rd, rd, 48, ctx);
636 			emit_srli(rd, rd, 48, ctx);
637 			break;
638 		case 32:
639 			if (!aux->verifier_zext)
640 				emit_zext_32(rd, ctx);
641 			break;
642 		case 64:
643 			/* Do nothing */
644 			break;
645 		}
646 		break;
647 
648 	case BPF_ALU | BPF_END | BPF_FROM_BE:
649 		emit_li(RV_REG_T2, 0, ctx);
650 
651 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
652 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
653 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
654 		emit_srli(rd, rd, 8, ctx);
655 		if (imm == 16)
656 			goto out_be;
657 
658 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
659 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
660 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
661 		emit_srli(rd, rd, 8, ctx);
662 
663 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
664 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
665 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
666 		emit_srli(rd, rd, 8, ctx);
667 		if (imm == 32)
668 			goto out_be;
669 
670 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
671 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
672 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
673 		emit_srli(rd, rd, 8, ctx);
674 
675 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
676 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
677 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
678 		emit_srli(rd, rd, 8, ctx);
679 
680 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
681 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
682 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
683 		emit_srli(rd, rd, 8, ctx);
684 
685 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
686 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
687 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
688 		emit_srli(rd, rd, 8, ctx);
689 out_be:
690 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
691 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
692 
693 		emit_mv(rd, RV_REG_T2, ctx);
694 		break;
695 
696 	/* dst = imm */
697 	case BPF_ALU | BPF_MOV | BPF_K:
698 	case BPF_ALU64 | BPF_MOV | BPF_K:
699 		emit_imm(rd, imm, ctx);
700 		if (!is64 && !aux->verifier_zext)
701 			emit_zext_32(rd, ctx);
702 		break;
703 
704 	/* dst = dst OP imm */
705 	case BPF_ALU | BPF_ADD | BPF_K:
706 	case BPF_ALU64 | BPF_ADD | BPF_K:
707 		if (is_12b_int(imm)) {
708 			emit_addi(rd, rd, imm, ctx);
709 		} else {
710 			emit_imm(RV_REG_T1, imm, ctx);
711 			emit_add(rd, rd, RV_REG_T1, ctx);
712 		}
713 		if (!is64 && !aux->verifier_zext)
714 			emit_zext_32(rd, ctx);
715 		break;
716 	case BPF_ALU | BPF_SUB | BPF_K:
717 	case BPF_ALU64 | BPF_SUB | BPF_K:
718 		if (is_12b_int(-imm)) {
719 			emit_addi(rd, rd, -imm, ctx);
720 		} else {
721 			emit_imm(RV_REG_T1, imm, ctx);
722 			emit_sub(rd, rd, RV_REG_T1, ctx);
723 		}
724 		if (!is64 && !aux->verifier_zext)
725 			emit_zext_32(rd, ctx);
726 		break;
727 	case BPF_ALU | BPF_AND | BPF_K:
728 	case BPF_ALU64 | BPF_AND | BPF_K:
729 		if (is_12b_int(imm)) {
730 			emit_andi(rd, rd, imm, ctx);
731 		} else {
732 			emit_imm(RV_REG_T1, imm, ctx);
733 			emit_and(rd, rd, RV_REG_T1, ctx);
734 		}
735 		if (!is64 && !aux->verifier_zext)
736 			emit_zext_32(rd, ctx);
737 		break;
738 	case BPF_ALU | BPF_OR | BPF_K:
739 	case BPF_ALU64 | BPF_OR | BPF_K:
740 		if (is_12b_int(imm)) {
741 			emit(rv_ori(rd, rd, imm), ctx);
742 		} else {
743 			emit_imm(RV_REG_T1, imm, ctx);
744 			emit_or(rd, rd, RV_REG_T1, ctx);
745 		}
746 		if (!is64 && !aux->verifier_zext)
747 			emit_zext_32(rd, ctx);
748 		break;
749 	case BPF_ALU | BPF_XOR | BPF_K:
750 	case BPF_ALU64 | BPF_XOR | BPF_K:
751 		if (is_12b_int(imm)) {
752 			emit(rv_xori(rd, rd, imm), ctx);
753 		} else {
754 			emit_imm(RV_REG_T1, imm, ctx);
755 			emit_xor(rd, rd, RV_REG_T1, ctx);
756 		}
757 		if (!is64 && !aux->verifier_zext)
758 			emit_zext_32(rd, ctx);
759 		break;
760 	case BPF_ALU | BPF_MUL | BPF_K:
761 	case BPF_ALU64 | BPF_MUL | BPF_K:
762 		emit_imm(RV_REG_T1, imm, ctx);
763 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
764 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
765 		if (!is64 && !aux->verifier_zext)
766 			emit_zext_32(rd, ctx);
767 		break;
768 	case BPF_ALU | BPF_DIV | BPF_K:
769 	case BPF_ALU64 | BPF_DIV | BPF_K:
770 		emit_imm(RV_REG_T1, imm, ctx);
771 		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
772 		     rv_divuw(rd, rd, RV_REG_T1), ctx);
773 		if (!is64 && !aux->verifier_zext)
774 			emit_zext_32(rd, ctx);
775 		break;
776 	case BPF_ALU | BPF_MOD | BPF_K:
777 	case BPF_ALU64 | BPF_MOD | BPF_K:
778 		emit_imm(RV_REG_T1, imm, ctx);
779 		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
780 		     rv_remuw(rd, rd, RV_REG_T1), ctx);
781 		if (!is64 && !aux->verifier_zext)
782 			emit_zext_32(rd, ctx);
783 		break;
784 	case BPF_ALU | BPF_LSH | BPF_K:
785 	case BPF_ALU64 | BPF_LSH | BPF_K:
786 		emit_slli(rd, rd, imm, ctx);
787 
788 		if (!is64 && !aux->verifier_zext)
789 			emit_zext_32(rd, ctx);
790 		break;
791 	case BPF_ALU | BPF_RSH | BPF_K:
792 	case BPF_ALU64 | BPF_RSH | BPF_K:
793 		if (is64)
794 			emit_srli(rd, rd, imm, ctx);
795 		else
796 			emit(rv_srliw(rd, rd, imm), ctx);
797 
798 		if (!is64 && !aux->verifier_zext)
799 			emit_zext_32(rd, ctx);
800 		break;
801 	case BPF_ALU | BPF_ARSH | BPF_K:
802 	case BPF_ALU64 | BPF_ARSH | BPF_K:
803 		if (is64)
804 			emit_srai(rd, rd, imm, ctx);
805 		else
806 			emit(rv_sraiw(rd, rd, imm), ctx);
807 
808 		if (!is64 && !aux->verifier_zext)
809 			emit_zext_32(rd, ctx);
810 		break;
811 
812 	/* JUMP off */
813 	case BPF_JMP | BPF_JA:
814 		rvoff = rv_offset(i, off, ctx);
815 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
816 		if (ret)
817 			return ret;
818 		break;
819 
820 	/* IF (dst COND src) JUMP off */
821 	case BPF_JMP | BPF_JEQ | BPF_X:
822 	case BPF_JMP32 | BPF_JEQ | BPF_X:
823 	case BPF_JMP | BPF_JGT | BPF_X:
824 	case BPF_JMP32 | BPF_JGT | BPF_X:
825 	case BPF_JMP | BPF_JLT | BPF_X:
826 	case BPF_JMP32 | BPF_JLT | BPF_X:
827 	case BPF_JMP | BPF_JGE | BPF_X:
828 	case BPF_JMP32 | BPF_JGE | BPF_X:
829 	case BPF_JMP | BPF_JLE | BPF_X:
830 	case BPF_JMP32 | BPF_JLE | BPF_X:
831 	case BPF_JMP | BPF_JNE | BPF_X:
832 	case BPF_JMP32 | BPF_JNE | BPF_X:
833 	case BPF_JMP | BPF_JSGT | BPF_X:
834 	case BPF_JMP32 | BPF_JSGT | BPF_X:
835 	case BPF_JMP | BPF_JSLT | BPF_X:
836 	case BPF_JMP32 | BPF_JSLT | BPF_X:
837 	case BPF_JMP | BPF_JSGE | BPF_X:
838 	case BPF_JMP32 | BPF_JSGE | BPF_X:
839 	case BPF_JMP | BPF_JSLE | BPF_X:
840 	case BPF_JMP32 | BPF_JSLE | BPF_X:
841 	case BPF_JMP | BPF_JSET | BPF_X:
842 	case BPF_JMP32 | BPF_JSET | BPF_X:
843 		rvoff = rv_offset(i, off, ctx);
844 		if (!is64) {
845 			s = ctx->ninsns;
846 			if (is_signed_bpf_cond(BPF_OP(code)))
847 				emit_sext_32_rd_rs(&rd, &rs, ctx);
848 			else
849 				emit_zext_32_rd_rs(&rd, &rs, ctx);
850 			e = ctx->ninsns;
851 
852 			/* Adjust for extra insns */
853 			rvoff -= ninsns_rvoff(e - s);
854 		}
855 
856 		if (BPF_OP(code) == BPF_JSET) {
857 			/* Adjust for and */
858 			rvoff -= 4;
859 			emit_and(RV_REG_T1, rd, rs, ctx);
860 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
861 				    ctx);
862 		} else {
863 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
864 		}
865 		break;
866 
867 	/* IF (dst COND imm) JUMP off */
868 	case BPF_JMP | BPF_JEQ | BPF_K:
869 	case BPF_JMP32 | BPF_JEQ | BPF_K:
870 	case BPF_JMP | BPF_JGT | BPF_K:
871 	case BPF_JMP32 | BPF_JGT | BPF_K:
872 	case BPF_JMP | BPF_JLT | BPF_K:
873 	case BPF_JMP32 | BPF_JLT | BPF_K:
874 	case BPF_JMP | BPF_JGE | BPF_K:
875 	case BPF_JMP32 | BPF_JGE | BPF_K:
876 	case BPF_JMP | BPF_JLE | BPF_K:
877 	case BPF_JMP32 | BPF_JLE | BPF_K:
878 	case BPF_JMP | BPF_JNE | BPF_K:
879 	case BPF_JMP32 | BPF_JNE | BPF_K:
880 	case BPF_JMP | BPF_JSGT | BPF_K:
881 	case BPF_JMP32 | BPF_JSGT | BPF_K:
882 	case BPF_JMP | BPF_JSLT | BPF_K:
883 	case BPF_JMP32 | BPF_JSLT | BPF_K:
884 	case BPF_JMP | BPF_JSGE | BPF_K:
885 	case BPF_JMP32 | BPF_JSGE | BPF_K:
886 	case BPF_JMP | BPF_JSLE | BPF_K:
887 	case BPF_JMP32 | BPF_JSLE | BPF_K:
888 		rvoff = rv_offset(i, off, ctx);
889 		s = ctx->ninsns;
890 		if (imm) {
891 			emit_imm(RV_REG_T1, imm, ctx);
892 			rs = RV_REG_T1;
893 		} else {
894 			/* If imm is 0, simply use zero register. */
895 			rs = RV_REG_ZERO;
896 		}
897 		if (!is64) {
898 			if (is_signed_bpf_cond(BPF_OP(code)))
899 				emit_sext_32_rd(&rd, ctx);
900 			else
901 				emit_zext_32_rd_t1(&rd, ctx);
902 		}
903 		e = ctx->ninsns;
904 
905 		/* Adjust for extra insns */
906 		rvoff -= ninsns_rvoff(e - s);
907 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
908 		break;
909 
910 	case BPF_JMP | BPF_JSET | BPF_K:
911 	case BPF_JMP32 | BPF_JSET | BPF_K:
912 		rvoff = rv_offset(i, off, ctx);
913 		s = ctx->ninsns;
914 		if (is_12b_int(imm)) {
915 			emit_andi(RV_REG_T1, rd, imm, ctx);
916 		} else {
917 			emit_imm(RV_REG_T1, imm, ctx);
918 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
919 		}
920 		/* For jset32, we should clear the upper 32 bits of t1, but
921 		 * sign-extension is sufficient here and saves one instruction,
922 		 * as t1 is used only in comparison against zero.
923 		 */
924 		if (!is64 && imm < 0)
925 			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
926 		e = ctx->ninsns;
927 		rvoff -= ninsns_rvoff(e - s);
928 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
929 		break;
930 
931 	/* function call */
932 	case BPF_JMP | BPF_CALL:
933 	{
934 		bool fixed;
935 		u64 addr;
936 
937 		mark_call(ctx);
938 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
939 					    &fixed);
940 		if (ret < 0)
941 			return ret;
942 		ret = emit_call(fixed, addr, ctx);
943 		if (ret)
944 			return ret;
945 		break;
946 	}
947 	/* tail call */
948 	case BPF_JMP | BPF_TAIL_CALL:
949 		if (emit_bpf_tail_call(i, ctx))
950 			return -1;
951 		break;
952 
953 	/* function return */
954 	case BPF_JMP | BPF_EXIT:
955 		if (i == ctx->prog->len - 1)
956 			break;
957 
958 		rvoff = epilogue_offset(ctx);
959 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
960 		if (ret)
961 			return ret;
962 		break;
963 
964 	/* dst = imm64 */
965 	case BPF_LD | BPF_IMM | BPF_DW:
966 	{
967 		struct bpf_insn insn1 = insn[1];
968 		u64 imm64;
969 
970 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
971 		emit_imm(rd, imm64, ctx);
972 		return 1;
973 	}
974 
975 	/* LDX: dst = *(size *)(src + off) */
976 	case BPF_LDX | BPF_MEM | BPF_B:
977 	case BPF_LDX | BPF_MEM | BPF_H:
978 	case BPF_LDX | BPF_MEM | BPF_W:
979 	case BPF_LDX | BPF_MEM | BPF_DW:
980 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
981 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
982 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
983 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
984 	{
985 		int insn_len, insns_start;
986 
987 		switch (BPF_SIZE(code)) {
988 		case BPF_B:
989 			if (is_12b_int(off)) {
990 				insns_start = ctx->ninsns;
991 				emit(rv_lbu(rd, off, rs), ctx);
992 				insn_len = ctx->ninsns - insns_start;
993 				break;
994 			}
995 
996 			emit_imm(RV_REG_T1, off, ctx);
997 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
998 			insns_start = ctx->ninsns;
999 			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1000 			insn_len = ctx->ninsns - insns_start;
1001 			if (insn_is_zext(&insn[1]))
1002 				return 1;
1003 			break;
1004 		case BPF_H:
1005 			if (is_12b_int(off)) {
1006 				insns_start = ctx->ninsns;
1007 				emit(rv_lhu(rd, off, rs), ctx);
1008 				insn_len = ctx->ninsns - insns_start;
1009 				break;
1010 			}
1011 
1012 			emit_imm(RV_REG_T1, off, ctx);
1013 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1014 			insns_start = ctx->ninsns;
1015 			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1016 			insn_len = ctx->ninsns - insns_start;
1017 			if (insn_is_zext(&insn[1]))
1018 				return 1;
1019 			break;
1020 		case BPF_W:
1021 			if (is_12b_int(off)) {
1022 				insns_start = ctx->ninsns;
1023 				emit(rv_lwu(rd, off, rs), ctx);
1024 				insn_len = ctx->ninsns - insns_start;
1025 				break;
1026 			}
1027 
1028 			emit_imm(RV_REG_T1, off, ctx);
1029 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1030 			insns_start = ctx->ninsns;
1031 			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1032 			insn_len = ctx->ninsns - insns_start;
1033 			if (insn_is_zext(&insn[1]))
1034 				return 1;
1035 			break;
1036 		case BPF_DW:
1037 			if (is_12b_int(off)) {
1038 				insns_start = ctx->ninsns;
1039 				emit_ld(rd, off, rs, ctx);
1040 				insn_len = ctx->ninsns - insns_start;
1041 				break;
1042 			}
1043 
1044 			emit_imm(RV_REG_T1, off, ctx);
1045 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1046 			insns_start = ctx->ninsns;
1047 			emit_ld(rd, 0, RV_REG_T1, ctx);
1048 			insn_len = ctx->ninsns - insns_start;
1049 			break;
1050 		}
1051 
1052 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1053 		if (ret)
1054 			return ret;
1055 		break;
1056 	}
1057 	/* speculation barrier */
1058 	case BPF_ST | BPF_NOSPEC:
1059 		break;
1060 
1061 	/* ST: *(size *)(dst + off) = imm */
1062 	case BPF_ST | BPF_MEM | BPF_B:
1063 		emit_imm(RV_REG_T1, imm, ctx);
1064 		if (is_12b_int(off)) {
1065 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1066 			break;
1067 		}
1068 
1069 		emit_imm(RV_REG_T2, off, ctx);
1070 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1071 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1072 		break;
1073 
1074 	case BPF_ST | BPF_MEM | BPF_H:
1075 		emit_imm(RV_REG_T1, imm, ctx);
1076 		if (is_12b_int(off)) {
1077 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1078 			break;
1079 		}
1080 
1081 		emit_imm(RV_REG_T2, off, ctx);
1082 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1083 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1084 		break;
1085 	case BPF_ST | BPF_MEM | BPF_W:
1086 		emit_imm(RV_REG_T1, imm, ctx);
1087 		if (is_12b_int(off)) {
1088 			emit_sw(rd, off, RV_REG_T1, ctx);
1089 			break;
1090 		}
1091 
1092 		emit_imm(RV_REG_T2, off, ctx);
1093 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1094 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1095 		break;
1096 	case BPF_ST | BPF_MEM | BPF_DW:
1097 		emit_imm(RV_REG_T1, imm, ctx);
1098 		if (is_12b_int(off)) {
1099 			emit_sd(rd, off, RV_REG_T1, ctx);
1100 			break;
1101 		}
1102 
1103 		emit_imm(RV_REG_T2, off, ctx);
1104 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1105 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1106 		break;
1107 
1108 	/* STX: *(size *)(dst + off) = src */
1109 	case BPF_STX | BPF_MEM | BPF_B:
1110 		if (is_12b_int(off)) {
1111 			emit(rv_sb(rd, off, rs), ctx);
1112 			break;
1113 		}
1114 
1115 		emit_imm(RV_REG_T1, off, ctx);
1116 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1117 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1118 		break;
1119 	case BPF_STX | BPF_MEM | BPF_H:
1120 		if (is_12b_int(off)) {
1121 			emit(rv_sh(rd, off, rs), ctx);
1122 			break;
1123 		}
1124 
1125 		emit_imm(RV_REG_T1, off, ctx);
1126 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1127 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1128 		break;
1129 	case BPF_STX | BPF_MEM | BPF_W:
1130 		if (is_12b_int(off)) {
1131 			emit_sw(rd, off, rs, ctx);
1132 			break;
1133 		}
1134 
1135 		emit_imm(RV_REG_T1, off, ctx);
1136 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1137 		emit_sw(RV_REG_T1, 0, rs, ctx);
1138 		break;
1139 	case BPF_STX | BPF_MEM | BPF_DW:
1140 		if (is_12b_int(off)) {
1141 			emit_sd(rd, off, rs, ctx);
1142 			break;
1143 		}
1144 
1145 		emit_imm(RV_REG_T1, off, ctx);
1146 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1147 		emit_sd(RV_REG_T1, 0, rs, ctx);
1148 		break;
1149 	case BPF_STX | BPF_ATOMIC | BPF_W:
1150 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1151 		if (insn->imm != BPF_ADD) {
1152 			pr_err("bpf-jit: not supported: atomic operation %02x ***\n",
1153 			       insn->imm);
1154 			return -EINVAL;
1155 		}
1156 
1157 		/* atomic_add: lock *(u32 *)(dst + off) += src
1158 		 * atomic_add: lock *(u64 *)(dst + off) += src
1159 		 */
1160 
1161 		if (off) {
1162 			if (is_12b_int(off)) {
1163 				emit_addi(RV_REG_T1, rd, off, ctx);
1164 			} else {
1165 				emit_imm(RV_REG_T1, off, ctx);
1166 				emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1167 			}
1168 
1169 			rd = RV_REG_T1;
1170 		}
1171 
1172 		emit(BPF_SIZE(code) == BPF_W ?
1173 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0) :
1174 		     rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0), ctx);
1175 		break;
1176 	default:
1177 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1178 		return -EINVAL;
1179 	}
1180 
1181 	return 0;
1182 }
1183 
1184 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1185 {
1186 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
1187 
1188 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1189 	if (bpf_stack_adjust)
1190 		mark_fp(ctx);
1191 
1192 	if (seen_reg(RV_REG_RA, ctx))
1193 		stack_adjust += 8;
1194 	stack_adjust += 8; /* RV_REG_FP */
1195 	if (seen_reg(RV_REG_S1, ctx))
1196 		stack_adjust += 8;
1197 	if (seen_reg(RV_REG_S2, ctx))
1198 		stack_adjust += 8;
1199 	if (seen_reg(RV_REG_S3, ctx))
1200 		stack_adjust += 8;
1201 	if (seen_reg(RV_REG_S4, ctx))
1202 		stack_adjust += 8;
1203 	if (seen_reg(RV_REG_S5, ctx))
1204 		stack_adjust += 8;
1205 	if (seen_reg(RV_REG_S6, ctx))
1206 		stack_adjust += 8;
1207 
1208 	stack_adjust = round_up(stack_adjust, 16);
1209 	stack_adjust += bpf_stack_adjust;
1210 
1211 	store_offset = stack_adjust - 8;
1212 
1213 	/* First instruction is always setting the tail-call-counter
1214 	 * (TCC) register. This instruction is skipped for tail calls.
1215 	 * Force using a 4-byte (non-compressed) instruction.
1216 	 */
1217 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1218 
1219 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1220 
1221 	if (seen_reg(RV_REG_RA, ctx)) {
1222 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1223 		store_offset -= 8;
1224 	}
1225 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1226 	store_offset -= 8;
1227 	if (seen_reg(RV_REG_S1, ctx)) {
1228 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1229 		store_offset -= 8;
1230 	}
1231 	if (seen_reg(RV_REG_S2, ctx)) {
1232 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1233 		store_offset -= 8;
1234 	}
1235 	if (seen_reg(RV_REG_S3, ctx)) {
1236 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1237 		store_offset -= 8;
1238 	}
1239 	if (seen_reg(RV_REG_S4, ctx)) {
1240 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1241 		store_offset -= 8;
1242 	}
1243 	if (seen_reg(RV_REG_S5, ctx)) {
1244 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1245 		store_offset -= 8;
1246 	}
1247 	if (seen_reg(RV_REG_S6, ctx)) {
1248 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1249 		store_offset -= 8;
1250 	}
1251 
1252 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1253 
1254 	if (bpf_stack_adjust)
1255 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1256 
1257 	/* Program contains calls and tail calls, so RV_REG_TCC need
1258 	 * to be saved across calls.
1259 	 */
1260 	if (seen_tail_call(ctx) && seen_call(ctx))
1261 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1262 
1263 	ctx->stack_size = stack_adjust;
1264 }
1265 
1266 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1267 {
1268 	__build_epilogue(false, ctx);
1269 }
1270