xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp64.c (revision 43ee1e3f)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include "bpf_jit.h"
12 
13 #define RV_REG_TCC RV_REG_A6
14 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
15 
16 static const int regmap[] = {
17 	[BPF_REG_0] =	RV_REG_A5,
18 	[BPF_REG_1] =	RV_REG_A0,
19 	[BPF_REG_2] =	RV_REG_A1,
20 	[BPF_REG_3] =	RV_REG_A2,
21 	[BPF_REG_4] =	RV_REG_A3,
22 	[BPF_REG_5] =	RV_REG_A4,
23 	[BPF_REG_6] =	RV_REG_S1,
24 	[BPF_REG_7] =	RV_REG_S2,
25 	[BPF_REG_8] =	RV_REG_S3,
26 	[BPF_REG_9] =	RV_REG_S4,
27 	[BPF_REG_FP] =	RV_REG_S5,
28 	[BPF_REG_AX] =	RV_REG_T0,
29 };
30 
31 static const int pt_regmap[] = {
32 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
33 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
34 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
35 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
36 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
37 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
38 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
39 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
40 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
41 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
42 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
43 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
44 };
45 
46 enum {
47 	RV_CTX_F_SEEN_TAIL_CALL =	0,
48 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
49 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
50 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
51 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
52 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
53 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
54 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
55 };
56 
57 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
58 {
59 	u8 reg = regmap[bpf_reg];
60 
61 	switch (reg) {
62 	case RV_CTX_F_SEEN_S1:
63 	case RV_CTX_F_SEEN_S2:
64 	case RV_CTX_F_SEEN_S3:
65 	case RV_CTX_F_SEEN_S4:
66 	case RV_CTX_F_SEEN_S5:
67 	case RV_CTX_F_SEEN_S6:
68 		__set_bit(reg, &ctx->flags);
69 	}
70 	return reg;
71 };
72 
73 static bool seen_reg(int reg, struct rv_jit_context *ctx)
74 {
75 	switch (reg) {
76 	case RV_CTX_F_SEEN_CALL:
77 	case RV_CTX_F_SEEN_S1:
78 	case RV_CTX_F_SEEN_S2:
79 	case RV_CTX_F_SEEN_S3:
80 	case RV_CTX_F_SEEN_S4:
81 	case RV_CTX_F_SEEN_S5:
82 	case RV_CTX_F_SEEN_S6:
83 		return test_bit(reg, &ctx->flags);
84 	}
85 	return false;
86 }
87 
88 static void mark_fp(struct rv_jit_context *ctx)
89 {
90 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
91 }
92 
93 static void mark_call(struct rv_jit_context *ctx)
94 {
95 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
96 }
97 
98 static bool seen_call(struct rv_jit_context *ctx)
99 {
100 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102 
103 static void mark_tail_call(struct rv_jit_context *ctx)
104 {
105 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
106 }
107 
108 static bool seen_tail_call(struct rv_jit_context *ctx)
109 {
110 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112 
113 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
114 {
115 	mark_tail_call(ctx);
116 
117 	if (seen_call(ctx)) {
118 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
119 		return RV_REG_S6;
120 	}
121 	return RV_REG_A6;
122 }
123 
124 static bool is_32b_int(s64 val)
125 {
126 	return -(1L << 31) <= val && val < (1L << 31);
127 }
128 
129 static bool in_auipc_jalr_range(s64 val)
130 {
131 	/*
132 	 * auipc+jalr can reach any signed PC-relative offset in the range
133 	 * [-2^31 - 2^11, 2^31 - 2^11).
134 	 */
135 	return (-(1L << 31) - (1L << 11)) <= val &&
136 		val < ((1L << 31) - (1L << 11));
137 }
138 
139 /* Emit fixed-length instructions for address */
140 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
141 {
142 	u64 ip = (u64)(ctx->insns + ctx->ninsns);
143 	s64 off = addr - ip;
144 	s64 upper = (off + (1 << 11)) >> 12;
145 	s64 lower = off & 0xfff;
146 
147 	if (extra_pass && !in_auipc_jalr_range(off)) {
148 		pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
149 		return -ERANGE;
150 	}
151 
152 	emit(rv_auipc(rd, upper), ctx);
153 	emit(rv_addi(rd, rd, lower), ctx);
154 	return 0;
155 }
156 
157 /* Emit variable-length instructions for 32-bit and 64-bit imm */
158 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
159 {
160 	/* Note that the immediate from the add is sign-extended,
161 	 * which means that we need to compensate this by adding 2^12,
162 	 * when the 12th bit is set. A simpler way of doing this, and
163 	 * getting rid of the check, is to just add 2**11 before the
164 	 * shift. The "Loading a 32-Bit constant" example from the
165 	 * "Computer Organization and Design, RISC-V edition" book by
166 	 * Patterson/Hennessy highlights this fact.
167 	 *
168 	 * This also means that we need to process LSB to MSB.
169 	 */
170 	s64 upper = (val + (1 << 11)) >> 12;
171 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
172 	 * and addi are signed and RVC checks will perform signed comparisons.
173 	 */
174 	s64 lower = ((val & 0xfff) << 52) >> 52;
175 	int shift;
176 
177 	if (is_32b_int(val)) {
178 		if (upper)
179 			emit_lui(rd, upper, ctx);
180 
181 		if (!upper) {
182 			emit_li(rd, lower, ctx);
183 			return;
184 		}
185 
186 		emit_addiw(rd, rd, lower, ctx);
187 		return;
188 	}
189 
190 	shift = __ffs(upper);
191 	upper >>= shift;
192 	shift += 12;
193 
194 	emit_imm(rd, upper, ctx);
195 
196 	emit_slli(rd, rd, shift, ctx);
197 	if (lower)
198 		emit_addi(rd, rd, lower, ctx);
199 }
200 
201 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
202 {
203 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
204 
205 	if (seen_reg(RV_REG_RA, ctx)) {
206 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
207 		store_offset -= 8;
208 	}
209 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
210 	store_offset -= 8;
211 	if (seen_reg(RV_REG_S1, ctx)) {
212 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
213 		store_offset -= 8;
214 	}
215 	if (seen_reg(RV_REG_S2, ctx)) {
216 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
217 		store_offset -= 8;
218 	}
219 	if (seen_reg(RV_REG_S3, ctx)) {
220 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
221 		store_offset -= 8;
222 	}
223 	if (seen_reg(RV_REG_S4, ctx)) {
224 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
225 		store_offset -= 8;
226 	}
227 	if (seen_reg(RV_REG_S5, ctx)) {
228 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
229 		store_offset -= 8;
230 	}
231 	if (seen_reg(RV_REG_S6, ctx)) {
232 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
233 		store_offset -= 8;
234 	}
235 
236 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
237 	/* Set return value. */
238 	if (!is_tail_call)
239 		emit_mv(RV_REG_A0, RV_REG_A5, ctx);
240 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
241 		  is_tail_call ? 4 : 0, /* skip TCC init */
242 		  ctx);
243 }
244 
245 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
246 		     struct rv_jit_context *ctx)
247 {
248 	switch (cond) {
249 	case BPF_JEQ:
250 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
251 		return;
252 	case BPF_JGT:
253 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
254 		return;
255 	case BPF_JLT:
256 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
257 		return;
258 	case BPF_JGE:
259 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
260 		return;
261 	case BPF_JLE:
262 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
263 		return;
264 	case BPF_JNE:
265 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
266 		return;
267 	case BPF_JSGT:
268 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
269 		return;
270 	case BPF_JSLT:
271 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
272 		return;
273 	case BPF_JSGE:
274 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
275 		return;
276 	case BPF_JSLE:
277 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
278 	}
279 }
280 
281 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
282 			struct rv_jit_context *ctx)
283 {
284 	s64 upper, lower;
285 
286 	if (is_13b_int(rvoff)) {
287 		emit_bcc(cond, rd, rs, rvoff, ctx);
288 		return;
289 	}
290 
291 	/* Adjust for jal */
292 	rvoff -= 4;
293 
294 	/* Transform, e.g.:
295 	 *   bne rd,rs,foo
296 	 * to
297 	 *   beq rd,rs,<.L1>
298 	 *   (auipc foo)
299 	 *   jal(r) foo
300 	 * .L1
301 	 */
302 	cond = invert_bpf_cond(cond);
303 	if (is_21b_int(rvoff)) {
304 		emit_bcc(cond, rd, rs, 8, ctx);
305 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
306 		return;
307 	}
308 
309 	/* 32b No need for an additional rvoff adjustment, since we
310 	 * get that from the auipc at PC', where PC = PC' + 4.
311 	 */
312 	upper = (rvoff + (1 << 11)) >> 12;
313 	lower = rvoff & 0xfff;
314 
315 	emit_bcc(cond, rd, rs, 12, ctx);
316 	emit(rv_auipc(RV_REG_T1, upper), ctx);
317 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
318 }
319 
320 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
321 {
322 	emit_slli(reg, reg, 32, ctx);
323 	emit_srli(reg, reg, 32, ctx);
324 }
325 
326 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
327 {
328 	int tc_ninsn, off, start_insn = ctx->ninsns;
329 	u8 tcc = rv_tail_call_reg(ctx);
330 
331 	/* a0: &ctx
332 	 * a1: &array
333 	 * a2: index
334 	 *
335 	 * if (index >= array->map.max_entries)
336 	 *	goto out;
337 	 */
338 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
339 		   ctx->offset[0];
340 	emit_zext_32(RV_REG_A2, ctx);
341 
342 	off = offsetof(struct bpf_array, map.max_entries);
343 	if (is_12b_check(off, insn))
344 		return -1;
345 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
346 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
347 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
348 
349 	/* if (--TCC < 0)
350 	 *     goto out;
351 	 */
352 	emit_addi(RV_REG_TCC, tcc, -1, ctx);
353 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
354 	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
355 
356 	/* prog = array->ptrs[index];
357 	 * if (!prog)
358 	 *     goto out;
359 	 */
360 	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
361 	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
362 	off = offsetof(struct bpf_array, ptrs);
363 	if (is_12b_check(off, insn))
364 		return -1;
365 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
366 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
367 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
368 
369 	/* goto *(prog->bpf_func + 4); */
370 	off = offsetof(struct bpf_prog, bpf_func);
371 	if (is_12b_check(off, insn))
372 		return -1;
373 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
374 	__build_epilogue(true, ctx);
375 	return 0;
376 }
377 
378 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
379 		      struct rv_jit_context *ctx)
380 {
381 	u8 code = insn->code;
382 
383 	switch (code) {
384 	case BPF_JMP | BPF_JA:
385 	case BPF_JMP | BPF_CALL:
386 	case BPF_JMP | BPF_EXIT:
387 	case BPF_JMP | BPF_TAIL_CALL:
388 		break;
389 	default:
390 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
391 	}
392 
393 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
394 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
395 	    code & BPF_LDX || code & BPF_STX)
396 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
397 }
398 
399 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
400 {
401 	emit_mv(RV_REG_T2, *rd, ctx);
402 	emit_zext_32(RV_REG_T2, ctx);
403 	emit_mv(RV_REG_T1, *rs, ctx);
404 	emit_zext_32(RV_REG_T1, ctx);
405 	*rd = RV_REG_T2;
406 	*rs = RV_REG_T1;
407 }
408 
409 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
410 {
411 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
412 	emit_addiw(RV_REG_T1, *rs, 0, ctx);
413 	*rd = RV_REG_T2;
414 	*rs = RV_REG_T1;
415 }
416 
417 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
418 {
419 	emit_mv(RV_REG_T2, *rd, ctx);
420 	emit_zext_32(RV_REG_T2, ctx);
421 	emit_zext_32(RV_REG_T1, ctx);
422 	*rd = RV_REG_T2;
423 }
424 
425 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
426 {
427 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
428 	*rd = RV_REG_T2;
429 }
430 
431 static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
432 			      struct rv_jit_context *ctx)
433 {
434 	s64 upper, lower;
435 
436 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
437 		emit(rv_jal(rd, rvoff >> 1), ctx);
438 		return 0;
439 	} else if (in_auipc_jalr_range(rvoff)) {
440 		upper = (rvoff + (1 << 11)) >> 12;
441 		lower = rvoff & 0xfff;
442 		emit(rv_auipc(RV_REG_T1, upper), ctx);
443 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
444 		return 0;
445 	}
446 
447 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
448 	return -ERANGE;
449 }
450 
451 static bool is_signed_bpf_cond(u8 cond)
452 {
453 	return cond == BPF_JSGT || cond == BPF_JSLT ||
454 		cond == BPF_JSGE || cond == BPF_JSLE;
455 }
456 
457 static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
458 {
459 	s64 off = 0;
460 	u64 ip;
461 	u8 rd;
462 	int ret;
463 
464 	if (addr && ctx->insns) {
465 		ip = (u64)(long)(ctx->insns + ctx->ninsns);
466 		off = addr - ip;
467 	}
468 
469 	ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
470 	if (ret)
471 		return ret;
472 	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
473 	emit_mv(rd, RV_REG_A0, ctx);
474 	return 0;
475 }
476 
477 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
478 			struct rv_jit_context *ctx)
479 {
480 	u8 r0;
481 	int jmp_offset;
482 
483 	if (off) {
484 		if (is_12b_int(off)) {
485 			emit_addi(RV_REG_T1, rd, off, ctx);
486 		} else {
487 			emit_imm(RV_REG_T1, off, ctx);
488 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
489 		}
490 		rd = RV_REG_T1;
491 	}
492 
493 	switch (imm) {
494 	/* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
495 	case BPF_ADD:
496 		emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
497 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
498 		break;
499 	case BPF_AND:
500 		emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
501 		     rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
502 		break;
503 	case BPF_OR:
504 		emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
505 		     rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
506 		break;
507 	case BPF_XOR:
508 		emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
509 		     rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
510 		break;
511 	/* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
512 	case BPF_ADD | BPF_FETCH:
513 		emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
514 		     rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
515 		if (!is64)
516 			emit_zext_32(rs, ctx);
517 		break;
518 	case BPF_AND | BPF_FETCH:
519 		emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
520 		     rv_amoand_w(rs, rs, rd, 0, 0), ctx);
521 		if (!is64)
522 			emit_zext_32(rs, ctx);
523 		break;
524 	case BPF_OR | BPF_FETCH:
525 		emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
526 		     rv_amoor_w(rs, rs, rd, 0, 0), ctx);
527 		if (!is64)
528 			emit_zext_32(rs, ctx);
529 		break;
530 	case BPF_XOR | BPF_FETCH:
531 		emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
532 		     rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
533 		if (!is64)
534 			emit_zext_32(rs, ctx);
535 		break;
536 	/* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
537 	case BPF_XCHG:
538 		emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
539 		     rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
540 		if (!is64)
541 			emit_zext_32(rs, ctx);
542 		break;
543 	/* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
544 	case BPF_CMPXCHG:
545 		r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
546 		emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
547 		     rv_addiw(RV_REG_T2, r0, 0), ctx);
548 		emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
549 		     rv_lr_w(r0, 0, rd, 0, 0), ctx);
550 		jmp_offset = ninsns_rvoff(8);
551 		emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
552 		emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
553 		     rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
554 		jmp_offset = ninsns_rvoff(-6);
555 		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
556 		emit(rv_fence(0x3, 0x3), ctx);
557 		break;
558 	}
559 }
560 
561 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
562 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
563 
564 bool ex_handler_bpf(const struct exception_table_entry *ex,
565 		    struct pt_regs *regs)
566 {
567 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
568 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
569 
570 	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
571 	regs->epc = (unsigned long)&ex->fixup - offset;
572 
573 	return true;
574 }
575 
576 /* For accesses to BTF pointers, add an entry to the exception table */
577 static int add_exception_handler(const struct bpf_insn *insn,
578 				 struct rv_jit_context *ctx,
579 				 int dst_reg, int insn_len)
580 {
581 	struct exception_table_entry *ex;
582 	unsigned long pc;
583 	off_t offset;
584 
585 	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
586 		return 0;
587 
588 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
589 		return -EINVAL;
590 
591 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
592 		return -EINVAL;
593 
594 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
595 		return -EINVAL;
596 
597 	ex = &ctx->prog->aux->extable[ctx->nexentries];
598 	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
599 
600 	offset = pc - (long)&ex->insn;
601 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
602 		return -ERANGE;
603 	ex->insn = offset;
604 
605 	/*
606 	 * Since the extable follows the program, the fixup offset is always
607 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
608 	 * to keep things simple, and put the destination register in the upper
609 	 * bits. We don't need to worry about buildtime or runtime sort
610 	 * modifying the upper bits because the table is already sorted, and
611 	 * isn't part of the main exception table.
612 	 */
613 	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
614 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
615 		return -ERANGE;
616 
617 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
618 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
619 	ex->type = EX_TYPE_BPF;
620 
621 	ctx->nexentries++;
622 	return 0;
623 }
624 
625 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
626 		      bool extra_pass)
627 {
628 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
629 		    BPF_CLASS(insn->code) == BPF_JMP;
630 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
631 	struct bpf_prog_aux *aux = ctx->prog->aux;
632 	u8 rd = -1, rs = -1, code = insn->code;
633 	s16 off = insn->off;
634 	s32 imm = insn->imm;
635 
636 	init_regs(&rd, &rs, insn, ctx);
637 
638 	switch (code) {
639 	/* dst = src */
640 	case BPF_ALU | BPF_MOV | BPF_X:
641 	case BPF_ALU64 | BPF_MOV | BPF_X:
642 		if (imm == 1) {
643 			/* Special mov32 for zext */
644 			emit_zext_32(rd, ctx);
645 			break;
646 		}
647 		emit_mv(rd, rs, ctx);
648 		if (!is64 && !aux->verifier_zext)
649 			emit_zext_32(rd, ctx);
650 		break;
651 
652 	/* dst = dst OP src */
653 	case BPF_ALU | BPF_ADD | BPF_X:
654 	case BPF_ALU64 | BPF_ADD | BPF_X:
655 		emit_add(rd, rd, rs, ctx);
656 		if (!is64 && !aux->verifier_zext)
657 			emit_zext_32(rd, ctx);
658 		break;
659 	case BPF_ALU | BPF_SUB | BPF_X:
660 	case BPF_ALU64 | BPF_SUB | BPF_X:
661 		if (is64)
662 			emit_sub(rd, rd, rs, ctx);
663 		else
664 			emit_subw(rd, rd, rs, ctx);
665 
666 		if (!is64 && !aux->verifier_zext)
667 			emit_zext_32(rd, ctx);
668 		break;
669 	case BPF_ALU | BPF_AND | BPF_X:
670 	case BPF_ALU64 | BPF_AND | BPF_X:
671 		emit_and(rd, rd, rs, ctx);
672 		if (!is64 && !aux->verifier_zext)
673 			emit_zext_32(rd, ctx);
674 		break;
675 	case BPF_ALU | BPF_OR | BPF_X:
676 	case BPF_ALU64 | BPF_OR | BPF_X:
677 		emit_or(rd, rd, rs, ctx);
678 		if (!is64 && !aux->verifier_zext)
679 			emit_zext_32(rd, ctx);
680 		break;
681 	case BPF_ALU | BPF_XOR | BPF_X:
682 	case BPF_ALU64 | BPF_XOR | BPF_X:
683 		emit_xor(rd, rd, rs, ctx);
684 		if (!is64 && !aux->verifier_zext)
685 			emit_zext_32(rd, ctx);
686 		break;
687 	case BPF_ALU | BPF_MUL | BPF_X:
688 	case BPF_ALU64 | BPF_MUL | BPF_X:
689 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
690 		if (!is64 && !aux->verifier_zext)
691 			emit_zext_32(rd, ctx);
692 		break;
693 	case BPF_ALU | BPF_DIV | BPF_X:
694 	case BPF_ALU64 | BPF_DIV | BPF_X:
695 		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
696 		if (!is64 && !aux->verifier_zext)
697 			emit_zext_32(rd, ctx);
698 		break;
699 	case BPF_ALU | BPF_MOD | BPF_X:
700 	case BPF_ALU64 | BPF_MOD | BPF_X:
701 		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
702 		if (!is64 && !aux->verifier_zext)
703 			emit_zext_32(rd, ctx);
704 		break;
705 	case BPF_ALU | BPF_LSH | BPF_X:
706 	case BPF_ALU64 | BPF_LSH | BPF_X:
707 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
708 		if (!is64 && !aux->verifier_zext)
709 			emit_zext_32(rd, ctx);
710 		break;
711 	case BPF_ALU | BPF_RSH | BPF_X:
712 	case BPF_ALU64 | BPF_RSH | BPF_X:
713 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
714 		if (!is64 && !aux->verifier_zext)
715 			emit_zext_32(rd, ctx);
716 		break;
717 	case BPF_ALU | BPF_ARSH | BPF_X:
718 	case BPF_ALU64 | BPF_ARSH | BPF_X:
719 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
720 		if (!is64 && !aux->verifier_zext)
721 			emit_zext_32(rd, ctx);
722 		break;
723 
724 	/* dst = -dst */
725 	case BPF_ALU | BPF_NEG:
726 	case BPF_ALU64 | BPF_NEG:
727 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
728 		if (!is64 && !aux->verifier_zext)
729 			emit_zext_32(rd, ctx);
730 		break;
731 
732 	/* dst = BSWAP##imm(dst) */
733 	case BPF_ALU | BPF_END | BPF_FROM_LE:
734 		switch (imm) {
735 		case 16:
736 			emit_slli(rd, rd, 48, ctx);
737 			emit_srli(rd, rd, 48, ctx);
738 			break;
739 		case 32:
740 			if (!aux->verifier_zext)
741 				emit_zext_32(rd, ctx);
742 			break;
743 		case 64:
744 			/* Do nothing */
745 			break;
746 		}
747 		break;
748 
749 	case BPF_ALU | BPF_END | BPF_FROM_BE:
750 		emit_li(RV_REG_T2, 0, ctx);
751 
752 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
753 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
754 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
755 		emit_srli(rd, rd, 8, ctx);
756 		if (imm == 16)
757 			goto out_be;
758 
759 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
760 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
761 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
762 		emit_srli(rd, rd, 8, ctx);
763 
764 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
765 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
766 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
767 		emit_srli(rd, rd, 8, ctx);
768 		if (imm == 32)
769 			goto out_be;
770 
771 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
772 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
773 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
774 		emit_srli(rd, rd, 8, ctx);
775 
776 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
777 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
778 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
779 		emit_srli(rd, rd, 8, ctx);
780 
781 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
782 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
783 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
784 		emit_srli(rd, rd, 8, ctx);
785 
786 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
787 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
788 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
789 		emit_srli(rd, rd, 8, ctx);
790 out_be:
791 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
792 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
793 
794 		emit_mv(rd, RV_REG_T2, ctx);
795 		break;
796 
797 	/* dst = imm */
798 	case BPF_ALU | BPF_MOV | BPF_K:
799 	case BPF_ALU64 | BPF_MOV | BPF_K:
800 		emit_imm(rd, imm, ctx);
801 		if (!is64 && !aux->verifier_zext)
802 			emit_zext_32(rd, ctx);
803 		break;
804 
805 	/* dst = dst OP imm */
806 	case BPF_ALU | BPF_ADD | BPF_K:
807 	case BPF_ALU64 | BPF_ADD | BPF_K:
808 		if (is_12b_int(imm)) {
809 			emit_addi(rd, rd, imm, ctx);
810 		} else {
811 			emit_imm(RV_REG_T1, imm, ctx);
812 			emit_add(rd, rd, RV_REG_T1, ctx);
813 		}
814 		if (!is64 && !aux->verifier_zext)
815 			emit_zext_32(rd, ctx);
816 		break;
817 	case BPF_ALU | BPF_SUB | BPF_K:
818 	case BPF_ALU64 | BPF_SUB | BPF_K:
819 		if (is_12b_int(-imm)) {
820 			emit_addi(rd, rd, -imm, ctx);
821 		} else {
822 			emit_imm(RV_REG_T1, imm, ctx);
823 			emit_sub(rd, rd, RV_REG_T1, ctx);
824 		}
825 		if (!is64 && !aux->verifier_zext)
826 			emit_zext_32(rd, ctx);
827 		break;
828 	case BPF_ALU | BPF_AND | BPF_K:
829 	case BPF_ALU64 | BPF_AND | BPF_K:
830 		if (is_12b_int(imm)) {
831 			emit_andi(rd, rd, imm, ctx);
832 		} else {
833 			emit_imm(RV_REG_T1, imm, ctx);
834 			emit_and(rd, rd, RV_REG_T1, ctx);
835 		}
836 		if (!is64 && !aux->verifier_zext)
837 			emit_zext_32(rd, ctx);
838 		break;
839 	case BPF_ALU | BPF_OR | BPF_K:
840 	case BPF_ALU64 | BPF_OR | BPF_K:
841 		if (is_12b_int(imm)) {
842 			emit(rv_ori(rd, rd, imm), ctx);
843 		} else {
844 			emit_imm(RV_REG_T1, imm, ctx);
845 			emit_or(rd, rd, RV_REG_T1, ctx);
846 		}
847 		if (!is64 && !aux->verifier_zext)
848 			emit_zext_32(rd, ctx);
849 		break;
850 	case BPF_ALU | BPF_XOR | BPF_K:
851 	case BPF_ALU64 | BPF_XOR | BPF_K:
852 		if (is_12b_int(imm)) {
853 			emit(rv_xori(rd, rd, imm), ctx);
854 		} else {
855 			emit_imm(RV_REG_T1, imm, ctx);
856 			emit_xor(rd, rd, RV_REG_T1, ctx);
857 		}
858 		if (!is64 && !aux->verifier_zext)
859 			emit_zext_32(rd, ctx);
860 		break;
861 	case BPF_ALU | BPF_MUL | BPF_K:
862 	case BPF_ALU64 | BPF_MUL | BPF_K:
863 		emit_imm(RV_REG_T1, imm, ctx);
864 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
865 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
866 		if (!is64 && !aux->verifier_zext)
867 			emit_zext_32(rd, ctx);
868 		break;
869 	case BPF_ALU | BPF_DIV | BPF_K:
870 	case BPF_ALU64 | BPF_DIV | BPF_K:
871 		emit_imm(RV_REG_T1, imm, ctx);
872 		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
873 		     rv_divuw(rd, rd, RV_REG_T1), ctx);
874 		if (!is64 && !aux->verifier_zext)
875 			emit_zext_32(rd, ctx);
876 		break;
877 	case BPF_ALU | BPF_MOD | BPF_K:
878 	case BPF_ALU64 | BPF_MOD | BPF_K:
879 		emit_imm(RV_REG_T1, imm, ctx);
880 		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
881 		     rv_remuw(rd, rd, RV_REG_T1), ctx);
882 		if (!is64 && !aux->verifier_zext)
883 			emit_zext_32(rd, ctx);
884 		break;
885 	case BPF_ALU | BPF_LSH | BPF_K:
886 	case BPF_ALU64 | BPF_LSH | BPF_K:
887 		emit_slli(rd, rd, imm, ctx);
888 
889 		if (!is64 && !aux->verifier_zext)
890 			emit_zext_32(rd, ctx);
891 		break;
892 	case BPF_ALU | BPF_RSH | BPF_K:
893 	case BPF_ALU64 | BPF_RSH | BPF_K:
894 		if (is64)
895 			emit_srli(rd, rd, imm, ctx);
896 		else
897 			emit(rv_srliw(rd, rd, imm), ctx);
898 
899 		if (!is64 && !aux->verifier_zext)
900 			emit_zext_32(rd, ctx);
901 		break;
902 	case BPF_ALU | BPF_ARSH | BPF_K:
903 	case BPF_ALU64 | BPF_ARSH | BPF_K:
904 		if (is64)
905 			emit_srai(rd, rd, imm, ctx);
906 		else
907 			emit(rv_sraiw(rd, rd, imm), ctx);
908 
909 		if (!is64 && !aux->verifier_zext)
910 			emit_zext_32(rd, ctx);
911 		break;
912 
913 	/* JUMP off */
914 	case BPF_JMP | BPF_JA:
915 		rvoff = rv_offset(i, off, ctx);
916 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
917 		if (ret)
918 			return ret;
919 		break;
920 
921 	/* IF (dst COND src) JUMP off */
922 	case BPF_JMP | BPF_JEQ | BPF_X:
923 	case BPF_JMP32 | BPF_JEQ | BPF_X:
924 	case BPF_JMP | BPF_JGT | BPF_X:
925 	case BPF_JMP32 | BPF_JGT | BPF_X:
926 	case BPF_JMP | BPF_JLT | BPF_X:
927 	case BPF_JMP32 | BPF_JLT | BPF_X:
928 	case BPF_JMP | BPF_JGE | BPF_X:
929 	case BPF_JMP32 | BPF_JGE | BPF_X:
930 	case BPF_JMP | BPF_JLE | BPF_X:
931 	case BPF_JMP32 | BPF_JLE | BPF_X:
932 	case BPF_JMP | BPF_JNE | BPF_X:
933 	case BPF_JMP32 | BPF_JNE | BPF_X:
934 	case BPF_JMP | BPF_JSGT | BPF_X:
935 	case BPF_JMP32 | BPF_JSGT | BPF_X:
936 	case BPF_JMP | BPF_JSLT | BPF_X:
937 	case BPF_JMP32 | BPF_JSLT | BPF_X:
938 	case BPF_JMP | BPF_JSGE | BPF_X:
939 	case BPF_JMP32 | BPF_JSGE | BPF_X:
940 	case BPF_JMP | BPF_JSLE | BPF_X:
941 	case BPF_JMP32 | BPF_JSLE | BPF_X:
942 	case BPF_JMP | BPF_JSET | BPF_X:
943 	case BPF_JMP32 | BPF_JSET | BPF_X:
944 		rvoff = rv_offset(i, off, ctx);
945 		if (!is64) {
946 			s = ctx->ninsns;
947 			if (is_signed_bpf_cond(BPF_OP(code)))
948 				emit_sext_32_rd_rs(&rd, &rs, ctx);
949 			else
950 				emit_zext_32_rd_rs(&rd, &rs, ctx);
951 			e = ctx->ninsns;
952 
953 			/* Adjust for extra insns */
954 			rvoff -= ninsns_rvoff(e - s);
955 		}
956 
957 		if (BPF_OP(code) == BPF_JSET) {
958 			/* Adjust for and */
959 			rvoff -= 4;
960 			emit_and(RV_REG_T1, rd, rs, ctx);
961 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
962 				    ctx);
963 		} else {
964 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
965 		}
966 		break;
967 
968 	/* IF (dst COND imm) JUMP off */
969 	case BPF_JMP | BPF_JEQ | BPF_K:
970 	case BPF_JMP32 | BPF_JEQ | BPF_K:
971 	case BPF_JMP | BPF_JGT | BPF_K:
972 	case BPF_JMP32 | BPF_JGT | BPF_K:
973 	case BPF_JMP | BPF_JLT | BPF_K:
974 	case BPF_JMP32 | BPF_JLT | BPF_K:
975 	case BPF_JMP | BPF_JGE | BPF_K:
976 	case BPF_JMP32 | BPF_JGE | BPF_K:
977 	case BPF_JMP | BPF_JLE | BPF_K:
978 	case BPF_JMP32 | BPF_JLE | BPF_K:
979 	case BPF_JMP | BPF_JNE | BPF_K:
980 	case BPF_JMP32 | BPF_JNE | BPF_K:
981 	case BPF_JMP | BPF_JSGT | BPF_K:
982 	case BPF_JMP32 | BPF_JSGT | BPF_K:
983 	case BPF_JMP | BPF_JSLT | BPF_K:
984 	case BPF_JMP32 | BPF_JSLT | BPF_K:
985 	case BPF_JMP | BPF_JSGE | BPF_K:
986 	case BPF_JMP32 | BPF_JSGE | BPF_K:
987 	case BPF_JMP | BPF_JSLE | BPF_K:
988 	case BPF_JMP32 | BPF_JSLE | BPF_K:
989 		rvoff = rv_offset(i, off, ctx);
990 		s = ctx->ninsns;
991 		if (imm) {
992 			emit_imm(RV_REG_T1, imm, ctx);
993 			rs = RV_REG_T1;
994 		} else {
995 			/* If imm is 0, simply use zero register. */
996 			rs = RV_REG_ZERO;
997 		}
998 		if (!is64) {
999 			if (is_signed_bpf_cond(BPF_OP(code)))
1000 				emit_sext_32_rd(&rd, ctx);
1001 			else
1002 				emit_zext_32_rd_t1(&rd, ctx);
1003 		}
1004 		e = ctx->ninsns;
1005 
1006 		/* Adjust for extra insns */
1007 		rvoff -= ninsns_rvoff(e - s);
1008 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1009 		break;
1010 
1011 	case BPF_JMP | BPF_JSET | BPF_K:
1012 	case BPF_JMP32 | BPF_JSET | BPF_K:
1013 		rvoff = rv_offset(i, off, ctx);
1014 		s = ctx->ninsns;
1015 		if (is_12b_int(imm)) {
1016 			emit_andi(RV_REG_T1, rd, imm, ctx);
1017 		} else {
1018 			emit_imm(RV_REG_T1, imm, ctx);
1019 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1020 		}
1021 		/* For jset32, we should clear the upper 32 bits of t1, but
1022 		 * sign-extension is sufficient here and saves one instruction,
1023 		 * as t1 is used only in comparison against zero.
1024 		 */
1025 		if (!is64 && imm < 0)
1026 			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1027 		e = ctx->ninsns;
1028 		rvoff -= ninsns_rvoff(e - s);
1029 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1030 		break;
1031 
1032 	/* function call */
1033 	case BPF_JMP | BPF_CALL:
1034 	{
1035 		bool fixed;
1036 		u64 addr;
1037 
1038 		mark_call(ctx);
1039 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
1040 					    &fixed);
1041 		if (ret < 0)
1042 			return ret;
1043 		ret = emit_call(fixed, addr, ctx);
1044 		if (ret)
1045 			return ret;
1046 		break;
1047 	}
1048 	/* tail call */
1049 	case BPF_JMP | BPF_TAIL_CALL:
1050 		if (emit_bpf_tail_call(i, ctx))
1051 			return -1;
1052 		break;
1053 
1054 	/* function return */
1055 	case BPF_JMP | BPF_EXIT:
1056 		if (i == ctx->prog->len - 1)
1057 			break;
1058 
1059 		rvoff = epilogue_offset(ctx);
1060 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1061 		if (ret)
1062 			return ret;
1063 		break;
1064 
1065 	/* dst = imm64 */
1066 	case BPF_LD | BPF_IMM | BPF_DW:
1067 	{
1068 		struct bpf_insn insn1 = insn[1];
1069 		u64 imm64;
1070 
1071 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1072 		if (bpf_pseudo_func(insn)) {
1073 			/* fixed-length insns for extra jit pass */
1074 			ret = emit_addr(rd, imm64, extra_pass, ctx);
1075 			if (ret)
1076 				return ret;
1077 		} else {
1078 			emit_imm(rd, imm64, ctx);
1079 		}
1080 
1081 		return 1;
1082 	}
1083 
1084 	/* LDX: dst = *(size *)(src + off) */
1085 	case BPF_LDX | BPF_MEM | BPF_B:
1086 	case BPF_LDX | BPF_MEM | BPF_H:
1087 	case BPF_LDX | BPF_MEM | BPF_W:
1088 	case BPF_LDX | BPF_MEM | BPF_DW:
1089 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1090 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1091 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1092 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1093 	{
1094 		int insn_len, insns_start;
1095 
1096 		switch (BPF_SIZE(code)) {
1097 		case BPF_B:
1098 			if (is_12b_int(off)) {
1099 				insns_start = ctx->ninsns;
1100 				emit(rv_lbu(rd, off, rs), ctx);
1101 				insn_len = ctx->ninsns - insns_start;
1102 				break;
1103 			}
1104 
1105 			emit_imm(RV_REG_T1, off, ctx);
1106 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1107 			insns_start = ctx->ninsns;
1108 			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1109 			insn_len = ctx->ninsns - insns_start;
1110 			if (insn_is_zext(&insn[1]))
1111 				return 1;
1112 			break;
1113 		case BPF_H:
1114 			if (is_12b_int(off)) {
1115 				insns_start = ctx->ninsns;
1116 				emit(rv_lhu(rd, off, rs), ctx);
1117 				insn_len = ctx->ninsns - insns_start;
1118 				break;
1119 			}
1120 
1121 			emit_imm(RV_REG_T1, off, ctx);
1122 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1123 			insns_start = ctx->ninsns;
1124 			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1125 			insn_len = ctx->ninsns - insns_start;
1126 			if (insn_is_zext(&insn[1]))
1127 				return 1;
1128 			break;
1129 		case BPF_W:
1130 			if (is_12b_int(off)) {
1131 				insns_start = ctx->ninsns;
1132 				emit(rv_lwu(rd, off, rs), ctx);
1133 				insn_len = ctx->ninsns - insns_start;
1134 				break;
1135 			}
1136 
1137 			emit_imm(RV_REG_T1, off, ctx);
1138 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1139 			insns_start = ctx->ninsns;
1140 			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1141 			insn_len = ctx->ninsns - insns_start;
1142 			if (insn_is_zext(&insn[1]))
1143 				return 1;
1144 			break;
1145 		case BPF_DW:
1146 			if (is_12b_int(off)) {
1147 				insns_start = ctx->ninsns;
1148 				emit_ld(rd, off, rs, ctx);
1149 				insn_len = ctx->ninsns - insns_start;
1150 				break;
1151 			}
1152 
1153 			emit_imm(RV_REG_T1, off, ctx);
1154 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1155 			insns_start = ctx->ninsns;
1156 			emit_ld(rd, 0, RV_REG_T1, ctx);
1157 			insn_len = ctx->ninsns - insns_start;
1158 			break;
1159 		}
1160 
1161 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1162 		if (ret)
1163 			return ret;
1164 		break;
1165 	}
1166 	/* speculation barrier */
1167 	case BPF_ST | BPF_NOSPEC:
1168 		break;
1169 
1170 	/* ST: *(size *)(dst + off) = imm */
1171 	case BPF_ST | BPF_MEM | BPF_B:
1172 		emit_imm(RV_REG_T1, imm, ctx);
1173 		if (is_12b_int(off)) {
1174 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1175 			break;
1176 		}
1177 
1178 		emit_imm(RV_REG_T2, off, ctx);
1179 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1180 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1181 		break;
1182 
1183 	case BPF_ST | BPF_MEM | BPF_H:
1184 		emit_imm(RV_REG_T1, imm, ctx);
1185 		if (is_12b_int(off)) {
1186 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1187 			break;
1188 		}
1189 
1190 		emit_imm(RV_REG_T2, off, ctx);
1191 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1192 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1193 		break;
1194 	case BPF_ST | BPF_MEM | BPF_W:
1195 		emit_imm(RV_REG_T1, imm, ctx);
1196 		if (is_12b_int(off)) {
1197 			emit_sw(rd, off, RV_REG_T1, ctx);
1198 			break;
1199 		}
1200 
1201 		emit_imm(RV_REG_T2, off, ctx);
1202 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1203 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1204 		break;
1205 	case BPF_ST | BPF_MEM | BPF_DW:
1206 		emit_imm(RV_REG_T1, imm, ctx);
1207 		if (is_12b_int(off)) {
1208 			emit_sd(rd, off, RV_REG_T1, ctx);
1209 			break;
1210 		}
1211 
1212 		emit_imm(RV_REG_T2, off, ctx);
1213 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1214 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1215 		break;
1216 
1217 	/* STX: *(size *)(dst + off) = src */
1218 	case BPF_STX | BPF_MEM | BPF_B:
1219 		if (is_12b_int(off)) {
1220 			emit(rv_sb(rd, off, rs), ctx);
1221 			break;
1222 		}
1223 
1224 		emit_imm(RV_REG_T1, off, ctx);
1225 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1226 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1227 		break;
1228 	case BPF_STX | BPF_MEM | BPF_H:
1229 		if (is_12b_int(off)) {
1230 			emit(rv_sh(rd, off, rs), ctx);
1231 			break;
1232 		}
1233 
1234 		emit_imm(RV_REG_T1, off, ctx);
1235 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1236 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1237 		break;
1238 	case BPF_STX | BPF_MEM | BPF_W:
1239 		if (is_12b_int(off)) {
1240 			emit_sw(rd, off, rs, ctx);
1241 			break;
1242 		}
1243 
1244 		emit_imm(RV_REG_T1, off, ctx);
1245 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1246 		emit_sw(RV_REG_T1, 0, rs, ctx);
1247 		break;
1248 	case BPF_STX | BPF_MEM | BPF_DW:
1249 		if (is_12b_int(off)) {
1250 			emit_sd(rd, off, rs, ctx);
1251 			break;
1252 		}
1253 
1254 		emit_imm(RV_REG_T1, off, ctx);
1255 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1256 		emit_sd(RV_REG_T1, 0, rs, ctx);
1257 		break;
1258 	case BPF_STX | BPF_ATOMIC | BPF_W:
1259 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1260 		emit_atomic(rd, rs, off, imm,
1261 			    BPF_SIZE(code) == BPF_DW, ctx);
1262 		break;
1263 	default:
1264 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1265 		return -EINVAL;
1266 	}
1267 
1268 	return 0;
1269 }
1270 
1271 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1272 {
1273 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
1274 
1275 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1276 	if (bpf_stack_adjust)
1277 		mark_fp(ctx);
1278 
1279 	if (seen_reg(RV_REG_RA, ctx))
1280 		stack_adjust += 8;
1281 	stack_adjust += 8; /* RV_REG_FP */
1282 	if (seen_reg(RV_REG_S1, ctx))
1283 		stack_adjust += 8;
1284 	if (seen_reg(RV_REG_S2, ctx))
1285 		stack_adjust += 8;
1286 	if (seen_reg(RV_REG_S3, ctx))
1287 		stack_adjust += 8;
1288 	if (seen_reg(RV_REG_S4, ctx))
1289 		stack_adjust += 8;
1290 	if (seen_reg(RV_REG_S5, ctx))
1291 		stack_adjust += 8;
1292 	if (seen_reg(RV_REG_S6, ctx))
1293 		stack_adjust += 8;
1294 
1295 	stack_adjust = round_up(stack_adjust, 16);
1296 	stack_adjust += bpf_stack_adjust;
1297 
1298 	store_offset = stack_adjust - 8;
1299 
1300 	/* First instruction is always setting the tail-call-counter
1301 	 * (TCC) register. This instruction is skipped for tail calls.
1302 	 * Force using a 4-byte (non-compressed) instruction.
1303 	 */
1304 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1305 
1306 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1307 
1308 	if (seen_reg(RV_REG_RA, ctx)) {
1309 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1310 		store_offset -= 8;
1311 	}
1312 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1313 	store_offset -= 8;
1314 	if (seen_reg(RV_REG_S1, ctx)) {
1315 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1316 		store_offset -= 8;
1317 	}
1318 	if (seen_reg(RV_REG_S2, ctx)) {
1319 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1320 		store_offset -= 8;
1321 	}
1322 	if (seen_reg(RV_REG_S3, ctx)) {
1323 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1324 		store_offset -= 8;
1325 	}
1326 	if (seen_reg(RV_REG_S4, ctx)) {
1327 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1328 		store_offset -= 8;
1329 	}
1330 	if (seen_reg(RV_REG_S5, ctx)) {
1331 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1332 		store_offset -= 8;
1333 	}
1334 	if (seen_reg(RV_REG_S6, ctx)) {
1335 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1336 		store_offset -= 8;
1337 	}
1338 
1339 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1340 
1341 	if (bpf_stack_adjust)
1342 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1343 
1344 	/* Program contains calls and tail calls, so RV_REG_TCC need
1345 	 * to be saved across calls.
1346 	 */
1347 	if (seen_tail_call(ctx) && seen_call(ctx))
1348 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1349 
1350 	ctx->stack_size = stack_adjust;
1351 }
1352 
1353 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1354 {
1355 	__build_epilogue(false, ctx);
1356 }
1357