xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp32.c (revision 15e3ae36)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * BPF JIT compiler for RV32G
4  *
5  * Copyright (c) 2020 Luke Nelson <luke.r.nels@gmail.com>
6  * Copyright (c) 2020 Xi Wang <xi.wang@gmail.com>
7  *
8  * The code is based on the BPF JIT compiler for RV64G by Björn Töpel and
9  * the BPF JIT compiler for 32-bit ARM by Shubham Bansal and Mircea Gherzan.
10  */
11 
12 #include <linux/bpf.h>
13 #include <linux/filter.h>
14 #include "bpf_jit.h"
15 
16 enum {
17 	/* Stack layout - these are offsets from (top of stack - 4). */
18 	BPF_R6_HI,
19 	BPF_R6_LO,
20 	BPF_R7_HI,
21 	BPF_R7_LO,
22 	BPF_R8_HI,
23 	BPF_R8_LO,
24 	BPF_R9_HI,
25 	BPF_R9_LO,
26 	BPF_AX_HI,
27 	BPF_AX_LO,
28 	/* Stack space for BPF_REG_6 through BPF_REG_9 and BPF_REG_AX. */
29 	BPF_JIT_SCRATCH_REGS,
30 };
31 
32 #define STACK_OFFSET(k) (-4 - ((k) * 4))
33 
34 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
35 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)
36 
37 #define RV_REG_TCC		RV_REG_T6
38 #define RV_REG_TCC_SAVED	RV_REG_S7
39 
40 static const s8 bpf2rv32[][2] = {
41 	/* Return value from in-kernel function, and exit value from eBPF. */
42 	[BPF_REG_0] = {RV_REG_S2, RV_REG_S1},
43 	/* Arguments from eBPF program to in-kernel function. */
44 	[BPF_REG_1] = {RV_REG_A1, RV_REG_A0},
45 	[BPF_REG_2] = {RV_REG_A3, RV_REG_A2},
46 	[BPF_REG_3] = {RV_REG_A5, RV_REG_A4},
47 	[BPF_REG_4] = {RV_REG_A7, RV_REG_A6},
48 	[BPF_REG_5] = {RV_REG_S4, RV_REG_S3},
49 	/*
50 	 * Callee-saved registers that in-kernel function will preserve.
51 	 * Stored on the stack.
52 	 */
53 	[BPF_REG_6] = {STACK_OFFSET(BPF_R6_HI), STACK_OFFSET(BPF_R6_LO)},
54 	[BPF_REG_7] = {STACK_OFFSET(BPF_R7_HI), STACK_OFFSET(BPF_R7_LO)},
55 	[BPF_REG_8] = {STACK_OFFSET(BPF_R8_HI), STACK_OFFSET(BPF_R8_LO)},
56 	[BPF_REG_9] = {STACK_OFFSET(BPF_R9_HI), STACK_OFFSET(BPF_R9_LO)},
57 	/* Read-only frame pointer to access BPF stack. */
58 	[BPF_REG_FP] = {RV_REG_S6, RV_REG_S5},
59 	/* Temporary register for blinding constants. Stored on the stack. */
60 	[BPF_REG_AX] = {STACK_OFFSET(BPF_AX_HI), STACK_OFFSET(BPF_AX_LO)},
61 	/*
62 	 * Temporary registers used by the JIT to operate on registers stored
63 	 * on the stack. Save t0 and t1 to be used as temporaries in generated
64 	 * code.
65 	 */
66 	[TMP_REG_1] = {RV_REG_T3, RV_REG_T2},
67 	[TMP_REG_2] = {RV_REG_T5, RV_REG_T4},
68 };
69 
70 static s8 hi(const s8 *r)
71 {
72 	return r[0];
73 }
74 
75 static s8 lo(const s8 *r)
76 {
77 	return r[1];
78 }
79 
80 static void emit_imm(const s8 rd, s32 imm, struct rv_jit_context *ctx)
81 {
82 	u32 upper = (imm + (1 << 11)) >> 12;
83 	u32 lower = imm & 0xfff;
84 
85 	if (upper) {
86 		emit(rv_lui(rd, upper), ctx);
87 		emit(rv_addi(rd, rd, lower), ctx);
88 	} else {
89 		emit(rv_addi(rd, RV_REG_ZERO, lower), ctx);
90 	}
91 }
92 
93 static void emit_imm32(const s8 *rd, s32 imm, struct rv_jit_context *ctx)
94 {
95 	/* Emit immediate into lower bits. */
96 	emit_imm(lo(rd), imm, ctx);
97 
98 	/* Sign-extend into upper bits. */
99 	if (imm >= 0)
100 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
101 	else
102 		emit(rv_addi(hi(rd), RV_REG_ZERO, -1), ctx);
103 }
104 
105 static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
106 		       struct rv_jit_context *ctx)
107 {
108 	emit_imm(lo(rd), imm_lo, ctx);
109 	emit_imm(hi(rd), imm_hi, ctx);
110 }
111 
112 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
113 {
114 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 4;
115 	const s8 *r0 = bpf2rv32[BPF_REG_0];
116 
117 	store_offset -= 4 * BPF_JIT_SCRATCH_REGS;
118 
119 	/* Set return value if not tail call. */
120 	if (!is_tail_call) {
121 		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
122 		emit(rv_addi(RV_REG_A1, hi(r0), 0), ctx);
123 	}
124 
125 	/* Restore callee-saved registers. */
126 	emit(rv_lw(RV_REG_RA, store_offset - 0, RV_REG_SP), ctx);
127 	emit(rv_lw(RV_REG_FP, store_offset - 4, RV_REG_SP), ctx);
128 	emit(rv_lw(RV_REG_S1, store_offset - 8, RV_REG_SP), ctx);
129 	emit(rv_lw(RV_REG_S2, store_offset - 12, RV_REG_SP), ctx);
130 	emit(rv_lw(RV_REG_S3, store_offset - 16, RV_REG_SP), ctx);
131 	emit(rv_lw(RV_REG_S4, store_offset - 20, RV_REG_SP), ctx);
132 	emit(rv_lw(RV_REG_S5, store_offset - 24, RV_REG_SP), ctx);
133 	emit(rv_lw(RV_REG_S6, store_offset - 28, RV_REG_SP), ctx);
134 	emit(rv_lw(RV_REG_S7, store_offset - 32, RV_REG_SP), ctx);
135 
136 	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
137 
138 	if (is_tail_call) {
139 		/*
140 		 * goto *(t0 + 4);
141 		 * Skips first instruction of prologue which initializes tail
142 		 * call counter. Assumes t0 contains address of target program,
143 		 * see emit_bpf_tail_call.
144 		 */
145 		emit(rv_jalr(RV_REG_ZERO, RV_REG_T0, 4), ctx);
146 	} else {
147 		emit(rv_jalr(RV_REG_ZERO, RV_REG_RA, 0), ctx);
148 	}
149 }
150 
151 static bool is_stacked(s8 reg)
152 {
153 	return reg < 0;
154 }
155 
156 static const s8 *bpf_get_reg64(const s8 *reg, const s8 *tmp,
157 			       struct rv_jit_context *ctx)
158 {
159 	if (is_stacked(hi(reg))) {
160 		emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
161 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
162 		reg = tmp;
163 	}
164 	return reg;
165 }
166 
167 static void bpf_put_reg64(const s8 *reg, const s8 *src,
168 			  struct rv_jit_context *ctx)
169 {
170 	if (is_stacked(hi(reg))) {
171 		emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
172 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
173 	}
174 }
175 
176 static const s8 *bpf_get_reg32(const s8 *reg, const s8 *tmp,
177 			       struct rv_jit_context *ctx)
178 {
179 	if (is_stacked(lo(reg))) {
180 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
181 		reg = tmp;
182 	}
183 	return reg;
184 }
185 
186 static void bpf_put_reg32(const s8 *reg, const s8 *src,
187 			  struct rv_jit_context *ctx)
188 {
189 	if (is_stacked(lo(reg))) {
190 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
191 		if (!ctx->prog->aux->verifier_zext)
192 			emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
193 	} else if (!ctx->prog->aux->verifier_zext) {
194 		emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
195 	}
196 }
197 
198 static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
199 			       struct rv_jit_context *ctx)
200 {
201 	s32 upper, lower;
202 
203 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
204 		emit(rv_jal(rd, rvoff >> 1), ctx);
205 		return;
206 	}
207 
208 	upper = (rvoff + (1 << 11)) >> 12;
209 	lower = rvoff & 0xfff;
210 	emit(rv_auipc(RV_REG_T1, upper), ctx);
211 	emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
212 }
213 
214 static void emit_alu_i64(const s8 *dst, s32 imm,
215 			 struct rv_jit_context *ctx, const u8 op)
216 {
217 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
218 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
219 
220 	switch (op) {
221 	case BPF_MOV:
222 		emit_imm32(rd, imm, ctx);
223 		break;
224 	case BPF_AND:
225 		if (is_12b_int(imm)) {
226 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
227 		} else {
228 			emit_imm(RV_REG_T0, imm, ctx);
229 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
230 		}
231 		if (imm >= 0)
232 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
233 		break;
234 	case BPF_OR:
235 		if (is_12b_int(imm)) {
236 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
237 		} else {
238 			emit_imm(RV_REG_T0, imm, ctx);
239 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
240 		}
241 		if (imm < 0)
242 			emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
243 		break;
244 	case BPF_XOR:
245 		if (is_12b_int(imm)) {
246 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
247 		} else {
248 			emit_imm(RV_REG_T0, imm, ctx);
249 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
250 		}
251 		if (imm < 0)
252 			emit(rv_xori(hi(rd), hi(rd), -1), ctx);
253 		break;
254 	case BPF_LSH:
255 		if (imm >= 32) {
256 			emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
257 			emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
258 		} else if (imm == 0) {
259 			/* Do nothing. */
260 		} else {
261 			emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
262 			emit(rv_slli(hi(rd), hi(rd), imm), ctx);
263 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
264 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
265 		}
266 		break;
267 	case BPF_RSH:
268 		if (imm >= 32) {
269 			emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
270 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
271 		} else if (imm == 0) {
272 			/* Do nothing. */
273 		} else {
274 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
275 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
276 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
277 			emit(rv_srli(hi(rd), hi(rd), imm), ctx);
278 		}
279 		break;
280 	case BPF_ARSH:
281 		if (imm >= 32) {
282 			emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
283 			emit(rv_srai(hi(rd), hi(rd), 31), ctx);
284 		} else if (imm == 0) {
285 			/* Do nothing. */
286 		} else {
287 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
288 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
289 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
290 			emit(rv_srai(hi(rd), hi(rd), imm), ctx);
291 		}
292 		break;
293 	}
294 
295 	bpf_put_reg64(dst, rd, ctx);
296 }
297 
298 static void emit_alu_i32(const s8 *dst, s32 imm,
299 			 struct rv_jit_context *ctx, const u8 op)
300 {
301 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
302 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
303 
304 	switch (op) {
305 	case BPF_MOV:
306 		emit_imm(lo(rd), imm, ctx);
307 		break;
308 	case BPF_ADD:
309 		if (is_12b_int(imm)) {
310 			emit(rv_addi(lo(rd), lo(rd), imm), ctx);
311 		} else {
312 			emit_imm(RV_REG_T0, imm, ctx);
313 			emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
314 		}
315 		break;
316 	case BPF_SUB:
317 		if (is_12b_int(-imm)) {
318 			emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
319 		} else {
320 			emit_imm(RV_REG_T0, imm, ctx);
321 			emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
322 		}
323 		break;
324 	case BPF_AND:
325 		if (is_12b_int(imm)) {
326 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
327 		} else {
328 			emit_imm(RV_REG_T0, imm, ctx);
329 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
330 		}
331 		break;
332 	case BPF_OR:
333 		if (is_12b_int(imm)) {
334 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
335 		} else {
336 			emit_imm(RV_REG_T0, imm, ctx);
337 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
338 		}
339 		break;
340 	case BPF_XOR:
341 		if (is_12b_int(imm)) {
342 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
343 		} else {
344 			emit_imm(RV_REG_T0, imm, ctx);
345 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
346 		}
347 		break;
348 	case BPF_LSH:
349 		if (is_12b_int(imm)) {
350 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
351 		} else {
352 			emit_imm(RV_REG_T0, imm, ctx);
353 			emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
354 		}
355 		break;
356 	case BPF_RSH:
357 		if (is_12b_int(imm)) {
358 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
359 		} else {
360 			emit_imm(RV_REG_T0, imm, ctx);
361 			emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
362 		}
363 		break;
364 	case BPF_ARSH:
365 		if (is_12b_int(imm)) {
366 			emit(rv_srai(lo(rd), lo(rd), imm), ctx);
367 		} else {
368 			emit_imm(RV_REG_T0, imm, ctx);
369 			emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
370 		}
371 		break;
372 	}
373 
374 	bpf_put_reg32(dst, rd, ctx);
375 }
376 
377 static void emit_alu_r64(const s8 *dst, const s8 *src,
378 			 struct rv_jit_context *ctx, const u8 op)
379 {
380 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
381 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
382 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
383 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
384 
385 	switch (op) {
386 	case BPF_MOV:
387 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
388 		emit(rv_addi(hi(rd), hi(rs), 0), ctx);
389 		break;
390 	case BPF_ADD:
391 		if (rd == rs) {
392 			emit(rv_srli(RV_REG_T0, lo(rd), 31), ctx);
393 			emit(rv_slli(hi(rd), hi(rd), 1), ctx);
394 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
395 			emit(rv_slli(lo(rd), lo(rd), 1), ctx);
396 		} else {
397 			emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
398 			emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
399 			emit(rv_add(hi(rd), hi(rd), hi(rs)), ctx);
400 			emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
401 		}
402 		break;
403 	case BPF_SUB:
404 		emit(rv_sub(RV_REG_T1, hi(rd), hi(rs)), ctx);
405 		emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
406 		emit(rv_sub(hi(rd), RV_REG_T1, RV_REG_T0), ctx);
407 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
408 		break;
409 	case BPF_AND:
410 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
411 		emit(rv_and(hi(rd), hi(rd), hi(rs)), ctx);
412 		break;
413 	case BPF_OR:
414 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
415 		emit(rv_or(hi(rd), hi(rd), hi(rs)), ctx);
416 		break;
417 	case BPF_XOR:
418 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
419 		emit(rv_xor(hi(rd), hi(rd), hi(rs)), ctx);
420 		break;
421 	case BPF_MUL:
422 		emit(rv_mul(RV_REG_T0, hi(rs), lo(rd)), ctx);
423 		emit(rv_mul(hi(rd), hi(rd), lo(rs)), ctx);
424 		emit(rv_mulhu(RV_REG_T1, lo(rd), lo(rs)), ctx);
425 		emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
426 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
427 		emit(rv_add(hi(rd), hi(rd), RV_REG_T1), ctx);
428 		break;
429 	case BPF_LSH:
430 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
431 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
432 		emit(rv_sll(hi(rd), lo(rd), RV_REG_T0), ctx);
433 		emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
434 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
435 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
436 		emit(rv_srli(RV_REG_T0, lo(rd), 1), ctx);
437 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
438 		emit(rv_srl(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
439 		emit(rv_sll(hi(rd), hi(rd), lo(rs)), ctx);
440 		emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
441 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
442 		break;
443 	case BPF_RSH:
444 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
445 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
446 		emit(rv_srl(lo(rd), hi(rd), RV_REG_T0), ctx);
447 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
448 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
449 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
450 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
451 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
452 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
453 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
454 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
455 		emit(rv_srl(hi(rd), hi(rd), lo(rs)), ctx);
456 		break;
457 	case BPF_ARSH:
458 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
459 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
460 		emit(rv_sra(lo(rd), hi(rd), RV_REG_T0), ctx);
461 		emit(rv_srai(hi(rd), hi(rd), 31), ctx);
462 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
463 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
464 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
465 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
466 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
467 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
468 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
469 		emit(rv_sra(hi(rd), hi(rd), lo(rs)), ctx);
470 		break;
471 	case BPF_NEG:
472 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
473 		emit(rv_sltu(RV_REG_T0, RV_REG_ZERO, lo(rd)), ctx);
474 		emit(rv_sub(hi(rd), RV_REG_ZERO, hi(rd)), ctx);
475 		emit(rv_sub(hi(rd), hi(rd), RV_REG_T0), ctx);
476 		break;
477 	}
478 
479 	bpf_put_reg64(dst, rd, ctx);
480 }
481 
482 static void emit_alu_r32(const s8 *dst, const s8 *src,
483 			 struct rv_jit_context *ctx, const u8 op)
484 {
485 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
486 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
487 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
488 	const s8 *rs = bpf_get_reg32(src, tmp2, ctx);
489 
490 	switch (op) {
491 	case BPF_MOV:
492 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
493 		break;
494 	case BPF_ADD:
495 		emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
496 		break;
497 	case BPF_SUB:
498 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
499 		break;
500 	case BPF_AND:
501 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
502 		break;
503 	case BPF_OR:
504 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
505 		break;
506 	case BPF_XOR:
507 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
508 		break;
509 	case BPF_MUL:
510 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
511 		break;
512 	case BPF_DIV:
513 		emit(rv_divu(lo(rd), lo(rd), lo(rs)), ctx);
514 		break;
515 	case BPF_MOD:
516 		emit(rv_remu(lo(rd), lo(rd), lo(rs)), ctx);
517 		break;
518 	case BPF_LSH:
519 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
520 		break;
521 	case BPF_RSH:
522 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
523 		break;
524 	case BPF_ARSH:
525 		emit(rv_sra(lo(rd), lo(rd), lo(rs)), ctx);
526 		break;
527 	case BPF_NEG:
528 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
529 		break;
530 	}
531 
532 	bpf_put_reg32(dst, rd, ctx);
533 }
534 
535 static int emit_branch_r64(const s8 *src1, const s8 *src2, s32 rvoff,
536 			   struct rv_jit_context *ctx, const u8 op)
537 {
538 	int e, s = ctx->ninsns;
539 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
540 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
541 
542 	const s8 *rs1 = bpf_get_reg64(src1, tmp1, ctx);
543 	const s8 *rs2 = bpf_get_reg64(src2, tmp2, ctx);
544 
545 	/*
546 	 * NO_JUMP skips over the rest of the instructions and the
547 	 * emit_jump_and_link, meaning the BPF branch is not taken.
548 	 * JUMP skips directly to the emit_jump_and_link, meaning
549 	 * the BPF branch is taken.
550 	 *
551 	 * The fallthrough case results in the BPF branch being taken.
552 	 */
553 #define NO_JUMP(idx) (6 + (2 * (idx)))
554 #define JUMP(idx) (2 + (2 * (idx)))
555 
556 	switch (op) {
557 	case BPF_JEQ:
558 		emit(rv_bne(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
559 		emit(rv_bne(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
560 		break;
561 	case BPF_JGT:
562 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
563 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
564 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
565 		break;
566 	case BPF_JLT:
567 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
568 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
569 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
570 		break;
571 	case BPF_JGE:
572 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
573 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
574 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
575 		break;
576 	case BPF_JLE:
577 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
578 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
579 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
580 		break;
581 	case BPF_JNE:
582 		emit(rv_bne(hi(rs1), hi(rs2), JUMP(1)), ctx);
583 		emit(rv_beq(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
584 		break;
585 	case BPF_JSGT:
586 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
587 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
588 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
589 		break;
590 	case BPF_JSLT:
591 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
592 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
593 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
594 		break;
595 	case BPF_JSGE:
596 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
597 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
598 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
599 		break;
600 	case BPF_JSLE:
601 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
602 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
603 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
604 		break;
605 	case BPF_JSET:
606 		emit(rv_and(RV_REG_T0, hi(rs1), hi(rs2)), ctx);
607 		emit(rv_bne(RV_REG_T0, RV_REG_ZERO, JUMP(2)), ctx);
608 		emit(rv_and(RV_REG_T0, lo(rs1), lo(rs2)), ctx);
609 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, NO_JUMP(0)), ctx);
610 		break;
611 	}
612 
613 #undef NO_JUMP
614 #undef JUMP
615 
616 	e = ctx->ninsns;
617 	/* Adjust for extra insns. */
618 	rvoff -= (e - s) << 2;
619 	emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
620 	return 0;
621 }
622 
623 static int emit_bcc(u8 op, u8 rd, u8 rs, int rvoff, struct rv_jit_context *ctx)
624 {
625 	int e, s = ctx->ninsns;
626 	bool far = false;
627 	int off;
628 
629 	if (op == BPF_JSET) {
630 		/*
631 		 * BPF_JSET is a special case: it has no inverse so we always
632 		 * treat it as a far branch.
633 		 */
634 		far = true;
635 	} else if (!is_13b_int(rvoff)) {
636 		op = invert_bpf_cond(op);
637 		far = true;
638 	}
639 
640 	/*
641 	 * For a far branch, the condition is negated and we jump over the
642 	 * branch itself, and the two instructions from emit_jump_and_link.
643 	 * For a near branch, just use rvoff.
644 	 */
645 	off = far ? 6 : (rvoff >> 1);
646 
647 	switch (op) {
648 	case BPF_JEQ:
649 		emit(rv_beq(rd, rs, off), ctx);
650 		break;
651 	case BPF_JGT:
652 		emit(rv_bgtu(rd, rs, off), ctx);
653 		break;
654 	case BPF_JLT:
655 		emit(rv_bltu(rd, rs, off), ctx);
656 		break;
657 	case BPF_JGE:
658 		emit(rv_bgeu(rd, rs, off), ctx);
659 		break;
660 	case BPF_JLE:
661 		emit(rv_bleu(rd, rs, off), ctx);
662 		break;
663 	case BPF_JNE:
664 		emit(rv_bne(rd, rs, off), ctx);
665 		break;
666 	case BPF_JSGT:
667 		emit(rv_bgt(rd, rs, off), ctx);
668 		break;
669 	case BPF_JSLT:
670 		emit(rv_blt(rd, rs, off), ctx);
671 		break;
672 	case BPF_JSGE:
673 		emit(rv_bge(rd, rs, off), ctx);
674 		break;
675 	case BPF_JSLE:
676 		emit(rv_ble(rd, rs, off), ctx);
677 		break;
678 	case BPF_JSET:
679 		emit(rv_and(RV_REG_T0, rd, rs), ctx);
680 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, off), ctx);
681 		break;
682 	}
683 
684 	if (far) {
685 		e = ctx->ninsns;
686 		/* Adjust for extra insns. */
687 		rvoff -= (e - s) << 2;
688 		emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
689 	}
690 	return 0;
691 }
692 
693 static int emit_branch_r32(const s8 *src1, const s8 *src2, s32 rvoff,
694 			   struct rv_jit_context *ctx, const u8 op)
695 {
696 	int e, s = ctx->ninsns;
697 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
698 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
699 
700 	const s8 *rs1 = bpf_get_reg32(src1, tmp1, ctx);
701 	const s8 *rs2 = bpf_get_reg32(src2, tmp2, ctx);
702 
703 	e = ctx->ninsns;
704 	/* Adjust for extra insns. */
705 	rvoff -= (e - s) << 2;
706 
707 	if (emit_bcc(op, lo(rs1), lo(rs2), rvoff, ctx))
708 		return -1;
709 
710 	return 0;
711 }
712 
713 static void emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
714 {
715 	const s8 *r0 = bpf2rv32[BPF_REG_0];
716 	const s8 *r5 = bpf2rv32[BPF_REG_5];
717 	u32 upper = ((u32)addr + (1 << 11)) >> 12;
718 	u32 lower = addr & 0xfff;
719 
720 	/* R1-R4 already in correct registers---need to push R5 to stack. */
721 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -16), ctx);
722 	emit(rv_sw(RV_REG_SP, 0, lo(r5)), ctx);
723 	emit(rv_sw(RV_REG_SP, 4, hi(r5)), ctx);
724 
725 	/* Backup TCC. */
726 	emit(rv_addi(RV_REG_TCC_SAVED, RV_REG_TCC, 0), ctx);
727 
728 	/*
729 	 * Use lui/jalr pair to jump to absolute address. Don't use emit_imm as
730 	 * the number of emitted instructions should not depend on the value of
731 	 * addr.
732 	 */
733 	emit(rv_lui(RV_REG_T1, upper), ctx);
734 	emit(rv_jalr(RV_REG_RA, RV_REG_T1, lower), ctx);
735 
736 	/* Restore TCC. */
737 	emit(rv_addi(RV_REG_TCC, RV_REG_TCC_SAVED, 0), ctx);
738 
739 	/* Set return value and restore stack. */
740 	emit(rv_addi(lo(r0), RV_REG_A0, 0), ctx);
741 	emit(rv_addi(hi(r0), RV_REG_A1, 0), ctx);
742 	emit(rv_addi(RV_REG_SP, RV_REG_SP, 16), ctx);
743 }
744 
745 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
746 {
747 	/*
748 	 * R1 -> &ctx
749 	 * R2 -> &array
750 	 * R3 -> index
751 	 */
752 	int tc_ninsn, off, start_insn = ctx->ninsns;
753 	const s8 *arr_reg = bpf2rv32[BPF_REG_2];
754 	const s8 *idx_reg = bpf2rv32[BPF_REG_3];
755 
756 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
757 		ctx->offset[0];
758 
759 	/* max_entries = array->map.max_entries; */
760 	off = offsetof(struct bpf_array, map.max_entries);
761 	if (is_12b_check(off, insn))
762 		return -1;
763 	emit(rv_lw(RV_REG_T1, off, lo(arr_reg)), ctx);
764 
765 	/*
766 	 * if (index >= max_entries)
767 	 *   goto out;
768 	 */
769 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
770 	emit_bcc(BPF_JGE, lo(idx_reg), RV_REG_T1, off, ctx);
771 
772 	/*
773 	 * if ((temp_tcc = tcc - 1) < 0)
774 	 *   goto out;
775 	 */
776 	emit(rv_addi(RV_REG_T1, RV_REG_TCC, -1), ctx);
777 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
778 	emit_bcc(BPF_JSLT, RV_REG_T1, RV_REG_ZERO, off, ctx);
779 
780 	/*
781 	 * prog = array->ptrs[index];
782 	 * if (!prog)
783 	 *   goto out;
784 	 */
785 	emit(rv_slli(RV_REG_T0, lo(idx_reg), 2), ctx);
786 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(arr_reg)), ctx);
787 	off = offsetof(struct bpf_array, ptrs);
788 	if (is_12b_check(off, insn))
789 		return -1;
790 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
791 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
792 	emit_bcc(BPF_JEQ, RV_REG_T0, RV_REG_ZERO, off, ctx);
793 
794 	/*
795 	 * tcc = temp_tcc;
796 	 * goto *(prog->bpf_func + 4);
797 	 */
798 	off = offsetof(struct bpf_prog, bpf_func);
799 	if (is_12b_check(off, insn))
800 		return -1;
801 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
802 	emit(rv_addi(RV_REG_TCC, RV_REG_T1, 0), ctx);
803 	/* Epilogue jumps to *(t0 + 4). */
804 	__build_epilogue(true, ctx);
805 	return 0;
806 }
807 
808 static int emit_load_r64(const s8 *dst, const s8 *src, s16 off,
809 			 struct rv_jit_context *ctx, const u8 size)
810 {
811 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
812 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
813 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
814 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
815 
816 	emit_imm(RV_REG_T0, off, ctx);
817 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rs)), ctx);
818 
819 	switch (size) {
820 	case BPF_B:
821 		emit(rv_lbu(lo(rd), 0, RV_REG_T0), ctx);
822 		if (!ctx->prog->aux->verifier_zext)
823 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
824 		break;
825 	case BPF_H:
826 		emit(rv_lhu(lo(rd), 0, RV_REG_T0), ctx);
827 		if (!ctx->prog->aux->verifier_zext)
828 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
829 		break;
830 	case BPF_W:
831 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
832 		if (!ctx->prog->aux->verifier_zext)
833 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
834 		break;
835 	case BPF_DW:
836 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
837 		emit(rv_lw(hi(rd), 4, RV_REG_T0), ctx);
838 		break;
839 	}
840 
841 	bpf_put_reg64(dst, rd, ctx);
842 	return 0;
843 }
844 
845 static int emit_store_r64(const s8 *dst, const s8 *src, s16 off,
846 			  struct rv_jit_context *ctx, const u8 size,
847 			  const u8 mode)
848 {
849 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
850 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
851 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
852 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
853 
854 	if (mode == BPF_XADD && size != BPF_W)
855 		return -1;
856 
857 	emit_imm(RV_REG_T0, off, ctx);
858 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rd)), ctx);
859 
860 	switch (size) {
861 	case BPF_B:
862 		emit(rv_sb(RV_REG_T0, 0, lo(rs)), ctx);
863 		break;
864 	case BPF_H:
865 		emit(rv_sh(RV_REG_T0, 0, lo(rs)), ctx);
866 		break;
867 	case BPF_W:
868 		switch (mode) {
869 		case BPF_MEM:
870 			emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
871 			break;
872 		case BPF_XADD:
873 			emit(rv_amoadd_w(RV_REG_ZERO, lo(rs), RV_REG_T0, 0, 0),
874 			     ctx);
875 			break;
876 		}
877 		break;
878 	case BPF_DW:
879 		emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
880 		emit(rv_sw(RV_REG_T0, 4, hi(rs)), ctx);
881 		break;
882 	}
883 
884 	return 0;
885 }
886 
887 static void emit_rev16(const s8 rd, struct rv_jit_context *ctx)
888 {
889 	emit(rv_slli(rd, rd, 16), ctx);
890 	emit(rv_slli(RV_REG_T1, rd, 8), ctx);
891 	emit(rv_srli(rd, rd, 8), ctx);
892 	emit(rv_add(RV_REG_T1, rd, RV_REG_T1), ctx);
893 	emit(rv_srli(rd, RV_REG_T1, 16), ctx);
894 }
895 
896 static void emit_rev32(const s8 rd, struct rv_jit_context *ctx)
897 {
898 	emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 0), ctx);
899 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
900 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
901 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
902 	emit(rv_srli(rd, rd, 8), ctx);
903 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
904 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
905 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
906 	emit(rv_srli(rd, rd, 8), ctx);
907 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
908 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
909 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
910 	emit(rv_srli(rd, rd, 8), ctx);
911 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
912 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
913 	emit(rv_addi(rd, RV_REG_T1, 0), ctx);
914 }
915 
916 static void emit_zext64(const s8 *dst, struct rv_jit_context *ctx)
917 {
918 	const s8 *rd;
919 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
920 
921 	rd = bpf_get_reg64(dst, tmp1, ctx);
922 	emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
923 	bpf_put_reg64(dst, rd, ctx);
924 }
925 
926 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
927 		      bool extra_pass)
928 {
929 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
930 		BPF_CLASS(insn->code) == BPF_JMP;
931 	int s, e, rvoff, i = insn - ctx->prog->insnsi;
932 	u8 code = insn->code;
933 	s16 off = insn->off;
934 	s32 imm = insn->imm;
935 
936 	const s8 *dst = bpf2rv32[insn->dst_reg];
937 	const s8 *src = bpf2rv32[insn->src_reg];
938 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
939 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
940 
941 	switch (code) {
942 	case BPF_ALU64 | BPF_MOV | BPF_X:
943 
944 	case BPF_ALU64 | BPF_ADD | BPF_X:
945 	case BPF_ALU64 | BPF_ADD | BPF_K:
946 
947 	case BPF_ALU64 | BPF_SUB | BPF_X:
948 	case BPF_ALU64 | BPF_SUB | BPF_K:
949 
950 	case BPF_ALU64 | BPF_AND | BPF_X:
951 	case BPF_ALU64 | BPF_OR | BPF_X:
952 	case BPF_ALU64 | BPF_XOR | BPF_X:
953 
954 	case BPF_ALU64 | BPF_MUL | BPF_X:
955 	case BPF_ALU64 | BPF_MUL | BPF_K:
956 
957 	case BPF_ALU64 | BPF_LSH | BPF_X:
958 	case BPF_ALU64 | BPF_RSH | BPF_X:
959 	case BPF_ALU64 | BPF_ARSH | BPF_X:
960 		if (BPF_SRC(code) == BPF_K) {
961 			emit_imm32(tmp2, imm, ctx);
962 			src = tmp2;
963 		}
964 		emit_alu_r64(dst, src, ctx, BPF_OP(code));
965 		break;
966 
967 	case BPF_ALU64 | BPF_NEG:
968 		emit_alu_r64(dst, tmp2, ctx, BPF_OP(code));
969 		break;
970 
971 	case BPF_ALU64 | BPF_DIV | BPF_X:
972 	case BPF_ALU64 | BPF_DIV | BPF_K:
973 	case BPF_ALU64 | BPF_MOD | BPF_X:
974 	case BPF_ALU64 | BPF_MOD | BPF_K:
975 		goto notsupported;
976 
977 	case BPF_ALU64 | BPF_MOV | BPF_K:
978 	case BPF_ALU64 | BPF_AND | BPF_K:
979 	case BPF_ALU64 | BPF_OR | BPF_K:
980 	case BPF_ALU64 | BPF_XOR | BPF_K:
981 	case BPF_ALU64 | BPF_LSH | BPF_K:
982 	case BPF_ALU64 | BPF_RSH | BPF_K:
983 	case BPF_ALU64 | BPF_ARSH | BPF_K:
984 		emit_alu_i64(dst, imm, ctx, BPF_OP(code));
985 		break;
986 
987 	case BPF_ALU | BPF_MOV | BPF_X:
988 		if (imm == 1) {
989 			/* Special mov32 for zext. */
990 			emit_zext64(dst, ctx);
991 			break;
992 		}
993 		/* Fallthrough. */
994 
995 	case BPF_ALU | BPF_ADD | BPF_X:
996 	case BPF_ALU | BPF_SUB | BPF_X:
997 	case BPF_ALU | BPF_AND | BPF_X:
998 	case BPF_ALU | BPF_OR | BPF_X:
999 	case BPF_ALU | BPF_XOR | BPF_X:
1000 
1001 	case BPF_ALU | BPF_MUL | BPF_X:
1002 	case BPF_ALU | BPF_MUL | BPF_K:
1003 
1004 	case BPF_ALU | BPF_DIV | BPF_X:
1005 	case BPF_ALU | BPF_DIV | BPF_K:
1006 
1007 	case BPF_ALU | BPF_MOD | BPF_X:
1008 	case BPF_ALU | BPF_MOD | BPF_K:
1009 
1010 	case BPF_ALU | BPF_LSH | BPF_X:
1011 	case BPF_ALU | BPF_RSH | BPF_X:
1012 	case BPF_ALU | BPF_ARSH | BPF_X:
1013 		if (BPF_SRC(code) == BPF_K) {
1014 			emit_imm32(tmp2, imm, ctx);
1015 			src = tmp2;
1016 		}
1017 		emit_alu_r32(dst, src, ctx, BPF_OP(code));
1018 		break;
1019 
1020 	case BPF_ALU | BPF_MOV | BPF_K:
1021 	case BPF_ALU | BPF_ADD | BPF_K:
1022 	case BPF_ALU | BPF_SUB | BPF_K:
1023 	case BPF_ALU | BPF_AND | BPF_K:
1024 	case BPF_ALU | BPF_OR | BPF_K:
1025 	case BPF_ALU | BPF_XOR | BPF_K:
1026 	case BPF_ALU | BPF_LSH | BPF_K:
1027 	case BPF_ALU | BPF_RSH | BPF_K:
1028 	case BPF_ALU | BPF_ARSH | BPF_K:
1029 		/*
1030 		 * mul,div,mod are handled in the BPF_X case since there are
1031 		 * no RISC-V I-type equivalents.
1032 		 */
1033 		emit_alu_i32(dst, imm, ctx, BPF_OP(code));
1034 		break;
1035 
1036 	case BPF_ALU | BPF_NEG:
1037 		/*
1038 		 * src is ignored---choose tmp2 as a dummy register since it
1039 		 * is not on the stack.
1040 		 */
1041 		emit_alu_r32(dst, tmp2, ctx, BPF_OP(code));
1042 		break;
1043 
1044 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1045 	{
1046 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1047 
1048 		switch (imm) {
1049 		case 16:
1050 			emit(rv_slli(lo(rd), lo(rd), 16), ctx);
1051 			emit(rv_srli(lo(rd), lo(rd), 16), ctx);
1052 			/* Fallthrough. */
1053 		case 32:
1054 			if (!ctx->prog->aux->verifier_zext)
1055 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1056 			break;
1057 		case 64:
1058 			/* Do nothing. */
1059 			break;
1060 		default:
1061 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1062 			return -1;
1063 		}
1064 
1065 		bpf_put_reg64(dst, rd, ctx);
1066 		break;
1067 	}
1068 
1069 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1070 	{
1071 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1072 
1073 		switch (imm) {
1074 		case 16:
1075 			emit_rev16(lo(rd), ctx);
1076 			if (!ctx->prog->aux->verifier_zext)
1077 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1078 			break;
1079 		case 32:
1080 			emit_rev32(lo(rd), ctx);
1081 			if (!ctx->prog->aux->verifier_zext)
1082 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1083 			break;
1084 		case 64:
1085 			/* Swap upper and lower halves. */
1086 			emit(rv_addi(RV_REG_T0, lo(rd), 0), ctx);
1087 			emit(rv_addi(lo(rd), hi(rd), 0), ctx);
1088 			emit(rv_addi(hi(rd), RV_REG_T0, 0), ctx);
1089 
1090 			/* Swap each half. */
1091 			emit_rev32(lo(rd), ctx);
1092 			emit_rev32(hi(rd), ctx);
1093 			break;
1094 		default:
1095 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1096 			return -1;
1097 		}
1098 
1099 		bpf_put_reg64(dst, rd, ctx);
1100 		break;
1101 	}
1102 
1103 	case BPF_JMP | BPF_JA:
1104 		rvoff = rv_offset(i, off, ctx);
1105 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1106 		break;
1107 
1108 	case BPF_JMP | BPF_CALL:
1109 	{
1110 		bool fixed;
1111 		int ret;
1112 		u64 addr;
1113 
1114 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
1115 					    &fixed);
1116 		if (ret < 0)
1117 			return ret;
1118 		emit_call(fixed, addr, ctx);
1119 		break;
1120 	}
1121 
1122 	case BPF_JMP | BPF_TAIL_CALL:
1123 		if (emit_bpf_tail_call(i, ctx))
1124 			return -1;
1125 		break;
1126 
1127 	case BPF_JMP | BPF_JEQ | BPF_X:
1128 	case BPF_JMP | BPF_JEQ | BPF_K:
1129 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1130 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1131 
1132 	case BPF_JMP | BPF_JNE | BPF_X:
1133 	case BPF_JMP | BPF_JNE | BPF_K:
1134 	case BPF_JMP32 | BPF_JNE | BPF_X:
1135 	case BPF_JMP32 | BPF_JNE | BPF_K:
1136 
1137 	case BPF_JMP | BPF_JLE | BPF_X:
1138 	case BPF_JMP | BPF_JLE | BPF_K:
1139 	case BPF_JMP32 | BPF_JLE | BPF_X:
1140 	case BPF_JMP32 | BPF_JLE | BPF_K:
1141 
1142 	case BPF_JMP | BPF_JLT | BPF_X:
1143 	case BPF_JMP | BPF_JLT | BPF_K:
1144 	case BPF_JMP32 | BPF_JLT | BPF_X:
1145 	case BPF_JMP32 | BPF_JLT | BPF_K:
1146 
1147 	case BPF_JMP | BPF_JGE | BPF_X:
1148 	case BPF_JMP | BPF_JGE | BPF_K:
1149 	case BPF_JMP32 | BPF_JGE | BPF_X:
1150 	case BPF_JMP32 | BPF_JGE | BPF_K:
1151 
1152 	case BPF_JMP | BPF_JGT | BPF_X:
1153 	case BPF_JMP | BPF_JGT | BPF_K:
1154 	case BPF_JMP32 | BPF_JGT | BPF_X:
1155 	case BPF_JMP32 | BPF_JGT | BPF_K:
1156 
1157 	case BPF_JMP | BPF_JSLE | BPF_X:
1158 	case BPF_JMP | BPF_JSLE | BPF_K:
1159 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1160 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1161 
1162 	case BPF_JMP | BPF_JSLT | BPF_X:
1163 	case BPF_JMP | BPF_JSLT | BPF_K:
1164 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1165 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1166 
1167 	case BPF_JMP | BPF_JSGE | BPF_X:
1168 	case BPF_JMP | BPF_JSGE | BPF_K:
1169 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1170 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1171 
1172 	case BPF_JMP | BPF_JSGT | BPF_X:
1173 	case BPF_JMP | BPF_JSGT | BPF_K:
1174 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1175 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1176 
1177 	case BPF_JMP | BPF_JSET | BPF_X:
1178 	case BPF_JMP | BPF_JSET | BPF_K:
1179 	case BPF_JMP32 | BPF_JSET | BPF_X:
1180 	case BPF_JMP32 | BPF_JSET | BPF_K:
1181 		rvoff = rv_offset(i, off, ctx);
1182 		if (BPF_SRC(code) == BPF_K) {
1183 			s = ctx->ninsns;
1184 			emit_imm32(tmp2, imm, ctx);
1185 			src = tmp2;
1186 			e = ctx->ninsns;
1187 			rvoff -= (e - s) << 2;
1188 		}
1189 
1190 		if (is64)
1191 			emit_branch_r64(dst, src, rvoff, ctx, BPF_OP(code));
1192 		else
1193 			emit_branch_r32(dst, src, rvoff, ctx, BPF_OP(code));
1194 		break;
1195 
1196 	case BPF_JMP | BPF_EXIT:
1197 		if (i == ctx->prog->len - 1)
1198 			break;
1199 
1200 		rvoff = epilogue_offset(ctx);
1201 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1202 		break;
1203 
1204 	case BPF_LD | BPF_IMM | BPF_DW:
1205 	{
1206 		struct bpf_insn insn1 = insn[1];
1207 		s32 imm_lo = imm;
1208 		s32 imm_hi = insn1.imm;
1209 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1210 
1211 		emit_imm64(rd, imm_hi, imm_lo, ctx);
1212 		bpf_put_reg64(dst, rd, ctx);
1213 		return 1;
1214 	}
1215 
1216 	case BPF_LDX | BPF_MEM | BPF_B:
1217 	case BPF_LDX | BPF_MEM | BPF_H:
1218 	case BPF_LDX | BPF_MEM | BPF_W:
1219 	case BPF_LDX | BPF_MEM | BPF_DW:
1220 		if (emit_load_r64(dst, src, off, ctx, BPF_SIZE(code)))
1221 			return -1;
1222 		break;
1223 
1224 	case BPF_ST | BPF_MEM | BPF_B:
1225 	case BPF_ST | BPF_MEM | BPF_H:
1226 	case BPF_ST | BPF_MEM | BPF_W:
1227 	case BPF_ST | BPF_MEM | BPF_DW:
1228 
1229 	case BPF_STX | BPF_MEM | BPF_B:
1230 	case BPF_STX | BPF_MEM | BPF_H:
1231 	case BPF_STX | BPF_MEM | BPF_W:
1232 	case BPF_STX | BPF_MEM | BPF_DW:
1233 	case BPF_STX | BPF_XADD | BPF_W:
1234 		if (BPF_CLASS(code) == BPF_ST) {
1235 			emit_imm32(tmp2, imm, ctx);
1236 			src = tmp2;
1237 		}
1238 
1239 		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
1240 				   BPF_MODE(code)))
1241 			return -1;
1242 		break;
1243 
1244 	/* No hardware support for 8-byte atomics in RV32. */
1245 	case BPF_STX | BPF_XADD | BPF_DW:
1246 		/* Fallthrough. */
1247 
1248 notsupported:
1249 		pr_info_once("bpf-jit: not supported: opcode %02x ***\n", code);
1250 		return -EFAULT;
1251 
1252 	default:
1253 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1254 		return -EINVAL;
1255 	}
1256 
1257 	return 0;
1258 }
1259 
1260 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1261 {
1262 	/* Make space to save 9 registers: ra, fp, s1--s7. */
1263 	int stack_adjust = 9 * sizeof(u32), store_offset, bpf_stack_adjust;
1264 	const s8 *fp = bpf2rv32[BPF_REG_FP];
1265 	const s8 *r1 = bpf2rv32[BPF_REG_1];
1266 
1267 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1268 	stack_adjust += bpf_stack_adjust;
1269 
1270 	store_offset = stack_adjust - 4;
1271 
1272 	stack_adjust += 4 * BPF_JIT_SCRATCH_REGS;
1273 
1274 	/*
1275 	 * The first instruction sets the tail-call-counter (TCC) register.
1276 	 * This instruction is skipped by tail calls.
1277 	 */
1278 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1279 
1280 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);
1281 
1282 	/* Save callee-save registers. */
1283 	emit(rv_sw(RV_REG_SP, store_offset - 0, RV_REG_RA), ctx);
1284 	emit(rv_sw(RV_REG_SP, store_offset - 4, RV_REG_FP), ctx);
1285 	emit(rv_sw(RV_REG_SP, store_offset - 8, RV_REG_S1), ctx);
1286 	emit(rv_sw(RV_REG_SP, store_offset - 12, RV_REG_S2), ctx);
1287 	emit(rv_sw(RV_REG_SP, store_offset - 16, RV_REG_S3), ctx);
1288 	emit(rv_sw(RV_REG_SP, store_offset - 20, RV_REG_S4), ctx);
1289 	emit(rv_sw(RV_REG_SP, store_offset - 24, RV_REG_S5), ctx);
1290 	emit(rv_sw(RV_REG_SP, store_offset - 28, RV_REG_S6), ctx);
1291 	emit(rv_sw(RV_REG_SP, store_offset - 32, RV_REG_S7), ctx);
1292 
1293 	/* Set fp: used as the base address for stacked BPF registers. */
1294 	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);
1295 
1296 	/* Set up BPF stack pointer. */
1297 	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
1298 	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);
1299 
1300 	/* Set up context pointer. */
1301 	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
1302 	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);
1303 
1304 	ctx->stack_size = stack_adjust;
1305 }
1306 
1307 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1308 {
1309 	__build_epilogue(false, ctx);
1310 }
1311