xref: /openbmc/linux/arch/riscv/net/bpf_jit_comp32.c (revision ebf7f6f0)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * BPF JIT compiler for RV32G
4  *
5  * Copyright (c) 2020 Luke Nelson <luke.r.nels@gmail.com>
6  * Copyright (c) 2020 Xi Wang <xi.wang@gmail.com>
7  *
8  * The code is based on the BPF JIT compiler for RV64G by Björn Töpel and
9  * the BPF JIT compiler for 32-bit ARM by Shubham Bansal and Mircea Gherzan.
10  */
11 
12 #include <linux/bpf.h>
13 #include <linux/filter.h>
14 #include "bpf_jit.h"
15 
16 /*
17  * Stack layout during BPF program execution:
18  *
19  *                     high
20  *     RV32 fp =>  +----------+
21  *                 | saved ra |
22  *                 | saved fp | RV32 callee-saved registers
23  *                 |   ...    |
24  *                 +----------+ <= (fp - 4 * NR_SAVED_REGISTERS)
25  *                 |  hi(R6)  |
26  *                 |  lo(R6)  |
27  *                 |  hi(R7)  | JIT scratch space for BPF registers
28  *                 |  lo(R7)  |
29  *                 |   ...    |
30  *  BPF_REG_FP =>  +----------+ <= (fp - 4 * NR_SAVED_REGISTERS
31  *                 |          |        - 4 * BPF_JIT_SCRATCH_REGS)
32  *                 |          |
33  *                 |   ...    | BPF program stack
34  *                 |          |
35  *     RV32 sp =>  +----------+
36  *                 |          |
37  *                 |   ...    | Function call stack
38  *                 |          |
39  *                 +----------+
40  *                     low
41  */
42 
43 enum {
44 	/* Stack layout - these are offsets from top of JIT scratch space. */
45 	BPF_R6_HI,
46 	BPF_R6_LO,
47 	BPF_R7_HI,
48 	BPF_R7_LO,
49 	BPF_R8_HI,
50 	BPF_R8_LO,
51 	BPF_R9_HI,
52 	BPF_R9_LO,
53 	BPF_AX_HI,
54 	BPF_AX_LO,
55 	/* Stack space for BPF_REG_6 through BPF_REG_9 and BPF_REG_AX. */
56 	BPF_JIT_SCRATCH_REGS,
57 };
58 
59 /* Number of callee-saved registers stored to stack: ra, fp, s1--s7. */
60 #define NR_SAVED_REGISTERS	9
61 
62 /* Offset from fp for BPF registers stored on stack. */
63 #define STACK_OFFSET(k)	(-4 - (4 * NR_SAVED_REGISTERS) - (4 * (k)))
64 
65 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
66 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)
67 
68 #define RV_REG_TCC		RV_REG_T6
69 #define RV_REG_TCC_SAVED	RV_REG_S7
70 
71 static const s8 bpf2rv32[][2] = {
72 	/* Return value from in-kernel function, and exit value from eBPF. */
73 	[BPF_REG_0] = {RV_REG_S2, RV_REG_S1},
74 	/* Arguments from eBPF program to in-kernel function. */
75 	[BPF_REG_1] = {RV_REG_A1, RV_REG_A0},
76 	[BPF_REG_2] = {RV_REG_A3, RV_REG_A2},
77 	[BPF_REG_3] = {RV_REG_A5, RV_REG_A4},
78 	[BPF_REG_4] = {RV_REG_A7, RV_REG_A6},
79 	[BPF_REG_5] = {RV_REG_S4, RV_REG_S3},
80 	/*
81 	 * Callee-saved registers that in-kernel function will preserve.
82 	 * Stored on the stack.
83 	 */
84 	[BPF_REG_6] = {STACK_OFFSET(BPF_R6_HI), STACK_OFFSET(BPF_R6_LO)},
85 	[BPF_REG_7] = {STACK_OFFSET(BPF_R7_HI), STACK_OFFSET(BPF_R7_LO)},
86 	[BPF_REG_8] = {STACK_OFFSET(BPF_R8_HI), STACK_OFFSET(BPF_R8_LO)},
87 	[BPF_REG_9] = {STACK_OFFSET(BPF_R9_HI), STACK_OFFSET(BPF_R9_LO)},
88 	/* Read-only frame pointer to access BPF stack. */
89 	[BPF_REG_FP] = {RV_REG_S6, RV_REG_S5},
90 	/* Temporary register for blinding constants. Stored on the stack. */
91 	[BPF_REG_AX] = {STACK_OFFSET(BPF_AX_HI), STACK_OFFSET(BPF_AX_LO)},
92 	/*
93 	 * Temporary registers used by the JIT to operate on registers stored
94 	 * on the stack. Save t0 and t1 to be used as temporaries in generated
95 	 * code.
96 	 */
97 	[TMP_REG_1] = {RV_REG_T3, RV_REG_T2},
98 	[TMP_REG_2] = {RV_REG_T5, RV_REG_T4},
99 };
100 
hi(const s8 * r)101 static s8 hi(const s8 *r)
102 {
103 	return r[0];
104 }
105 
lo(const s8 * r)106 static s8 lo(const s8 *r)
107 {
108 	return r[1];
109 }
110 
emit_imm(const s8 rd,s32 imm,struct rv_jit_context * ctx)111 static void emit_imm(const s8 rd, s32 imm, struct rv_jit_context *ctx)
112 {
113 	u32 upper = (imm + (1 << 11)) >> 12;
114 	u32 lower = imm & 0xfff;
115 
116 	if (upper) {
117 		emit(rv_lui(rd, upper), ctx);
118 		emit(rv_addi(rd, rd, lower), ctx);
119 	} else {
120 		emit(rv_addi(rd, RV_REG_ZERO, lower), ctx);
121 	}
122 }
123 
emit_imm32(const s8 * rd,s32 imm,struct rv_jit_context * ctx)124 static void emit_imm32(const s8 *rd, s32 imm, struct rv_jit_context *ctx)
125 {
126 	/* Emit immediate into lower bits. */
127 	emit_imm(lo(rd), imm, ctx);
128 
129 	/* Sign-extend into upper bits. */
130 	if (imm >= 0)
131 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
132 	else
133 		emit(rv_addi(hi(rd), RV_REG_ZERO, -1), ctx);
134 }
135 
emit_imm64(const s8 * rd,s32 imm_hi,s32 imm_lo,struct rv_jit_context * ctx)136 static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
137 		       struct rv_jit_context *ctx)
138 {
139 	emit_imm(lo(rd), imm_lo, ctx);
140 	emit_imm(hi(rd), imm_hi, ctx);
141 }
142 
__build_epilogue(bool is_tail_call,struct rv_jit_context * ctx)143 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
144 {
145 	int stack_adjust = ctx->stack_size;
146 	const s8 *r0 = bpf2rv32[BPF_REG_0];
147 
148 	/* Set return value if not tail call. */
149 	if (!is_tail_call) {
150 		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
151 		emit(rv_addi(RV_REG_A1, hi(r0), 0), ctx);
152 	}
153 
154 	/* Restore callee-saved registers. */
155 	emit(rv_lw(RV_REG_RA, stack_adjust - 4, RV_REG_SP), ctx);
156 	emit(rv_lw(RV_REG_FP, stack_adjust - 8, RV_REG_SP), ctx);
157 	emit(rv_lw(RV_REG_S1, stack_adjust - 12, RV_REG_SP), ctx);
158 	emit(rv_lw(RV_REG_S2, stack_adjust - 16, RV_REG_SP), ctx);
159 	emit(rv_lw(RV_REG_S3, stack_adjust - 20, RV_REG_SP), ctx);
160 	emit(rv_lw(RV_REG_S4, stack_adjust - 24, RV_REG_SP), ctx);
161 	emit(rv_lw(RV_REG_S5, stack_adjust - 28, RV_REG_SP), ctx);
162 	emit(rv_lw(RV_REG_S6, stack_adjust - 32, RV_REG_SP), ctx);
163 	emit(rv_lw(RV_REG_S7, stack_adjust - 36, RV_REG_SP), ctx);
164 
165 	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
166 
167 	if (is_tail_call) {
168 		/*
169 		 * goto *(t0 + 4);
170 		 * Skips first instruction of prologue which initializes tail
171 		 * call counter. Assumes t0 contains address of target program,
172 		 * see emit_bpf_tail_call.
173 		 */
174 		emit(rv_jalr(RV_REG_ZERO, RV_REG_T0, 4), ctx);
175 	} else {
176 		emit(rv_jalr(RV_REG_ZERO, RV_REG_RA, 0), ctx);
177 	}
178 }
179 
is_stacked(s8 reg)180 static bool is_stacked(s8 reg)
181 {
182 	return reg < 0;
183 }
184 
bpf_get_reg64(const s8 * reg,const s8 * tmp,struct rv_jit_context * ctx)185 static const s8 *bpf_get_reg64(const s8 *reg, const s8 *tmp,
186 			       struct rv_jit_context *ctx)
187 {
188 	if (is_stacked(hi(reg))) {
189 		emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
190 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
191 		reg = tmp;
192 	}
193 	return reg;
194 }
195 
bpf_put_reg64(const s8 * reg,const s8 * src,struct rv_jit_context * ctx)196 static void bpf_put_reg64(const s8 *reg, const s8 *src,
197 			  struct rv_jit_context *ctx)
198 {
199 	if (is_stacked(hi(reg))) {
200 		emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
201 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
202 	}
203 }
204 
bpf_get_reg32(const s8 * reg,const s8 * tmp,struct rv_jit_context * ctx)205 static const s8 *bpf_get_reg32(const s8 *reg, const s8 *tmp,
206 			       struct rv_jit_context *ctx)
207 {
208 	if (is_stacked(lo(reg))) {
209 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
210 		reg = tmp;
211 	}
212 	return reg;
213 }
214 
bpf_put_reg32(const s8 * reg,const s8 * src,struct rv_jit_context * ctx)215 static void bpf_put_reg32(const s8 *reg, const s8 *src,
216 			  struct rv_jit_context *ctx)
217 {
218 	if (is_stacked(lo(reg))) {
219 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
220 		if (!ctx->prog->aux->verifier_zext)
221 			emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
222 	} else if (!ctx->prog->aux->verifier_zext) {
223 		emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
224 	}
225 }
226 
emit_jump_and_link(u8 rd,s32 rvoff,bool force_jalr,struct rv_jit_context * ctx)227 static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
228 			       struct rv_jit_context *ctx)
229 {
230 	s32 upper, lower;
231 
232 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
233 		emit(rv_jal(rd, rvoff >> 1), ctx);
234 		return;
235 	}
236 
237 	upper = (rvoff + (1 << 11)) >> 12;
238 	lower = rvoff & 0xfff;
239 	emit(rv_auipc(RV_REG_T1, upper), ctx);
240 	emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
241 }
242 
emit_alu_i64(const s8 * dst,s32 imm,struct rv_jit_context * ctx,const u8 op)243 static void emit_alu_i64(const s8 *dst, s32 imm,
244 			 struct rv_jit_context *ctx, const u8 op)
245 {
246 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
247 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
248 
249 	switch (op) {
250 	case BPF_MOV:
251 		emit_imm32(rd, imm, ctx);
252 		break;
253 	case BPF_AND:
254 		if (is_12b_int(imm)) {
255 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
256 		} else {
257 			emit_imm(RV_REG_T0, imm, ctx);
258 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
259 		}
260 		if (imm >= 0)
261 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
262 		break;
263 	case BPF_OR:
264 		if (is_12b_int(imm)) {
265 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
266 		} else {
267 			emit_imm(RV_REG_T0, imm, ctx);
268 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
269 		}
270 		if (imm < 0)
271 			emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
272 		break;
273 	case BPF_XOR:
274 		if (is_12b_int(imm)) {
275 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
276 		} else {
277 			emit_imm(RV_REG_T0, imm, ctx);
278 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
279 		}
280 		if (imm < 0)
281 			emit(rv_xori(hi(rd), hi(rd), -1), ctx);
282 		break;
283 	case BPF_LSH:
284 		if (imm >= 32) {
285 			emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
286 			emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
287 		} else if (imm == 0) {
288 			/* Do nothing. */
289 		} else {
290 			emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
291 			emit(rv_slli(hi(rd), hi(rd), imm), ctx);
292 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
293 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
294 		}
295 		break;
296 	case BPF_RSH:
297 		if (imm >= 32) {
298 			emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
299 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
300 		} else if (imm == 0) {
301 			/* Do nothing. */
302 		} else {
303 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
304 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
305 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
306 			emit(rv_srli(hi(rd), hi(rd), imm), ctx);
307 		}
308 		break;
309 	case BPF_ARSH:
310 		if (imm >= 32) {
311 			emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
312 			emit(rv_srai(hi(rd), hi(rd), 31), ctx);
313 		} else if (imm == 0) {
314 			/* Do nothing. */
315 		} else {
316 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
317 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
318 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
319 			emit(rv_srai(hi(rd), hi(rd), imm), ctx);
320 		}
321 		break;
322 	}
323 
324 	bpf_put_reg64(dst, rd, ctx);
325 }
326 
emit_alu_i32(const s8 * dst,s32 imm,struct rv_jit_context * ctx,const u8 op)327 static void emit_alu_i32(const s8 *dst, s32 imm,
328 			 struct rv_jit_context *ctx, const u8 op)
329 {
330 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
331 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
332 
333 	switch (op) {
334 	case BPF_MOV:
335 		emit_imm(lo(rd), imm, ctx);
336 		break;
337 	case BPF_ADD:
338 		if (is_12b_int(imm)) {
339 			emit(rv_addi(lo(rd), lo(rd), imm), ctx);
340 		} else {
341 			emit_imm(RV_REG_T0, imm, ctx);
342 			emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
343 		}
344 		break;
345 	case BPF_SUB:
346 		if (is_12b_int(-imm)) {
347 			emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
348 		} else {
349 			emit_imm(RV_REG_T0, imm, ctx);
350 			emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
351 		}
352 		break;
353 	case BPF_AND:
354 		if (is_12b_int(imm)) {
355 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
356 		} else {
357 			emit_imm(RV_REG_T0, imm, ctx);
358 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
359 		}
360 		break;
361 	case BPF_OR:
362 		if (is_12b_int(imm)) {
363 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
364 		} else {
365 			emit_imm(RV_REG_T0, imm, ctx);
366 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
367 		}
368 		break;
369 	case BPF_XOR:
370 		if (is_12b_int(imm)) {
371 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
372 		} else {
373 			emit_imm(RV_REG_T0, imm, ctx);
374 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
375 		}
376 		break;
377 	case BPF_LSH:
378 		if (is_12b_int(imm)) {
379 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
380 		} else {
381 			emit_imm(RV_REG_T0, imm, ctx);
382 			emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
383 		}
384 		break;
385 	case BPF_RSH:
386 		if (is_12b_int(imm)) {
387 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
388 		} else {
389 			emit_imm(RV_REG_T0, imm, ctx);
390 			emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
391 		}
392 		break;
393 	case BPF_ARSH:
394 		if (is_12b_int(imm)) {
395 			emit(rv_srai(lo(rd), lo(rd), imm), ctx);
396 		} else {
397 			emit_imm(RV_REG_T0, imm, ctx);
398 			emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
399 		}
400 		break;
401 	}
402 
403 	bpf_put_reg32(dst, rd, ctx);
404 }
405 
emit_alu_r64(const s8 * dst,const s8 * src,struct rv_jit_context * ctx,const u8 op)406 static void emit_alu_r64(const s8 *dst, const s8 *src,
407 			 struct rv_jit_context *ctx, const u8 op)
408 {
409 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
410 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
411 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
412 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
413 
414 	switch (op) {
415 	case BPF_MOV:
416 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
417 		emit(rv_addi(hi(rd), hi(rs), 0), ctx);
418 		break;
419 	case BPF_ADD:
420 		if (rd == rs) {
421 			emit(rv_srli(RV_REG_T0, lo(rd), 31), ctx);
422 			emit(rv_slli(hi(rd), hi(rd), 1), ctx);
423 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
424 			emit(rv_slli(lo(rd), lo(rd), 1), ctx);
425 		} else {
426 			emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
427 			emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
428 			emit(rv_add(hi(rd), hi(rd), hi(rs)), ctx);
429 			emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
430 		}
431 		break;
432 	case BPF_SUB:
433 		emit(rv_sub(RV_REG_T1, hi(rd), hi(rs)), ctx);
434 		emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
435 		emit(rv_sub(hi(rd), RV_REG_T1, RV_REG_T0), ctx);
436 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
437 		break;
438 	case BPF_AND:
439 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
440 		emit(rv_and(hi(rd), hi(rd), hi(rs)), ctx);
441 		break;
442 	case BPF_OR:
443 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
444 		emit(rv_or(hi(rd), hi(rd), hi(rs)), ctx);
445 		break;
446 	case BPF_XOR:
447 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
448 		emit(rv_xor(hi(rd), hi(rd), hi(rs)), ctx);
449 		break;
450 	case BPF_MUL:
451 		emit(rv_mul(RV_REG_T0, hi(rs), lo(rd)), ctx);
452 		emit(rv_mul(hi(rd), hi(rd), lo(rs)), ctx);
453 		emit(rv_mulhu(RV_REG_T1, lo(rd), lo(rs)), ctx);
454 		emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
455 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
456 		emit(rv_add(hi(rd), hi(rd), RV_REG_T1), ctx);
457 		break;
458 	case BPF_LSH:
459 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
460 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
461 		emit(rv_sll(hi(rd), lo(rd), RV_REG_T0), ctx);
462 		emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
463 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
464 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
465 		emit(rv_srli(RV_REG_T0, lo(rd), 1), ctx);
466 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
467 		emit(rv_srl(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
468 		emit(rv_sll(hi(rd), hi(rd), lo(rs)), ctx);
469 		emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
470 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
471 		break;
472 	case BPF_RSH:
473 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
474 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
475 		emit(rv_srl(lo(rd), hi(rd), RV_REG_T0), ctx);
476 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
477 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
478 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
479 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
480 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
481 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
482 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
483 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
484 		emit(rv_srl(hi(rd), hi(rd), lo(rs)), ctx);
485 		break;
486 	case BPF_ARSH:
487 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
488 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
489 		emit(rv_sra(lo(rd), hi(rd), RV_REG_T0), ctx);
490 		emit(rv_srai(hi(rd), hi(rd), 31), ctx);
491 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
492 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
493 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
494 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
495 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
496 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
497 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
498 		emit(rv_sra(hi(rd), hi(rd), lo(rs)), ctx);
499 		break;
500 	case BPF_NEG:
501 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
502 		emit(rv_sltu(RV_REG_T0, RV_REG_ZERO, lo(rd)), ctx);
503 		emit(rv_sub(hi(rd), RV_REG_ZERO, hi(rd)), ctx);
504 		emit(rv_sub(hi(rd), hi(rd), RV_REG_T0), ctx);
505 		break;
506 	}
507 
508 	bpf_put_reg64(dst, rd, ctx);
509 }
510 
emit_alu_r32(const s8 * dst,const s8 * src,struct rv_jit_context * ctx,const u8 op)511 static void emit_alu_r32(const s8 *dst, const s8 *src,
512 			 struct rv_jit_context *ctx, const u8 op)
513 {
514 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
515 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
516 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
517 	const s8 *rs = bpf_get_reg32(src, tmp2, ctx);
518 
519 	switch (op) {
520 	case BPF_MOV:
521 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
522 		break;
523 	case BPF_ADD:
524 		emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
525 		break;
526 	case BPF_SUB:
527 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
528 		break;
529 	case BPF_AND:
530 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
531 		break;
532 	case BPF_OR:
533 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
534 		break;
535 	case BPF_XOR:
536 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
537 		break;
538 	case BPF_MUL:
539 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
540 		break;
541 	case BPF_DIV:
542 		emit(rv_divu(lo(rd), lo(rd), lo(rs)), ctx);
543 		break;
544 	case BPF_MOD:
545 		emit(rv_remu(lo(rd), lo(rd), lo(rs)), ctx);
546 		break;
547 	case BPF_LSH:
548 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
549 		break;
550 	case BPF_RSH:
551 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
552 		break;
553 	case BPF_ARSH:
554 		emit(rv_sra(lo(rd), lo(rd), lo(rs)), ctx);
555 		break;
556 	case BPF_NEG:
557 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
558 		break;
559 	}
560 
561 	bpf_put_reg32(dst, rd, ctx);
562 }
563 
emit_branch_r64(const s8 * src1,const s8 * src2,s32 rvoff,struct rv_jit_context * ctx,const u8 op)564 static int emit_branch_r64(const s8 *src1, const s8 *src2, s32 rvoff,
565 			   struct rv_jit_context *ctx, const u8 op)
566 {
567 	int e, s = ctx->ninsns;
568 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
569 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
570 
571 	const s8 *rs1 = bpf_get_reg64(src1, tmp1, ctx);
572 	const s8 *rs2 = bpf_get_reg64(src2, tmp2, ctx);
573 
574 	/*
575 	 * NO_JUMP skips over the rest of the instructions and the
576 	 * emit_jump_and_link, meaning the BPF branch is not taken.
577 	 * JUMP skips directly to the emit_jump_and_link, meaning
578 	 * the BPF branch is taken.
579 	 *
580 	 * The fallthrough case results in the BPF branch being taken.
581 	 */
582 #define NO_JUMP(idx) (6 + (2 * (idx)))
583 #define JUMP(idx) (2 + (2 * (idx)))
584 
585 	switch (op) {
586 	case BPF_JEQ:
587 		emit(rv_bne(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
588 		emit(rv_bne(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
589 		break;
590 	case BPF_JGT:
591 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
592 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
593 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
594 		break;
595 	case BPF_JLT:
596 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
597 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
598 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
599 		break;
600 	case BPF_JGE:
601 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
602 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
603 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
604 		break;
605 	case BPF_JLE:
606 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
607 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
608 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
609 		break;
610 	case BPF_JNE:
611 		emit(rv_bne(hi(rs1), hi(rs2), JUMP(1)), ctx);
612 		emit(rv_beq(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
613 		break;
614 	case BPF_JSGT:
615 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
616 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
617 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
618 		break;
619 	case BPF_JSLT:
620 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
621 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
622 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
623 		break;
624 	case BPF_JSGE:
625 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
626 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
627 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
628 		break;
629 	case BPF_JSLE:
630 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
631 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
632 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
633 		break;
634 	case BPF_JSET:
635 		emit(rv_and(RV_REG_T0, hi(rs1), hi(rs2)), ctx);
636 		emit(rv_bne(RV_REG_T0, RV_REG_ZERO, JUMP(2)), ctx);
637 		emit(rv_and(RV_REG_T0, lo(rs1), lo(rs2)), ctx);
638 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, NO_JUMP(0)), ctx);
639 		break;
640 	}
641 
642 #undef NO_JUMP
643 #undef JUMP
644 
645 	e = ctx->ninsns;
646 	/* Adjust for extra insns. */
647 	rvoff -= ninsns_rvoff(e - s);
648 	emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
649 	return 0;
650 }
651 
emit_bcc(u8 op,u8 rd,u8 rs,int rvoff,struct rv_jit_context * ctx)652 static int emit_bcc(u8 op, u8 rd, u8 rs, int rvoff, struct rv_jit_context *ctx)
653 {
654 	int e, s = ctx->ninsns;
655 	bool far = false;
656 	int off;
657 
658 	if (op == BPF_JSET) {
659 		/*
660 		 * BPF_JSET is a special case: it has no inverse so we always
661 		 * treat it as a far branch.
662 		 */
663 		far = true;
664 	} else if (!is_13b_int(rvoff)) {
665 		op = invert_bpf_cond(op);
666 		far = true;
667 	}
668 
669 	/*
670 	 * For a far branch, the condition is negated and we jump over the
671 	 * branch itself, and the two instructions from emit_jump_and_link.
672 	 * For a near branch, just use rvoff.
673 	 */
674 	off = far ? 6 : (rvoff >> 1);
675 
676 	switch (op) {
677 	case BPF_JEQ:
678 		emit(rv_beq(rd, rs, off), ctx);
679 		break;
680 	case BPF_JGT:
681 		emit(rv_bgtu(rd, rs, off), ctx);
682 		break;
683 	case BPF_JLT:
684 		emit(rv_bltu(rd, rs, off), ctx);
685 		break;
686 	case BPF_JGE:
687 		emit(rv_bgeu(rd, rs, off), ctx);
688 		break;
689 	case BPF_JLE:
690 		emit(rv_bleu(rd, rs, off), ctx);
691 		break;
692 	case BPF_JNE:
693 		emit(rv_bne(rd, rs, off), ctx);
694 		break;
695 	case BPF_JSGT:
696 		emit(rv_bgt(rd, rs, off), ctx);
697 		break;
698 	case BPF_JSLT:
699 		emit(rv_blt(rd, rs, off), ctx);
700 		break;
701 	case BPF_JSGE:
702 		emit(rv_bge(rd, rs, off), ctx);
703 		break;
704 	case BPF_JSLE:
705 		emit(rv_ble(rd, rs, off), ctx);
706 		break;
707 	case BPF_JSET:
708 		emit(rv_and(RV_REG_T0, rd, rs), ctx);
709 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, off), ctx);
710 		break;
711 	}
712 
713 	if (far) {
714 		e = ctx->ninsns;
715 		/* Adjust for extra insns. */
716 		rvoff -= ninsns_rvoff(e - s);
717 		emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
718 	}
719 	return 0;
720 }
721 
emit_branch_r32(const s8 * src1,const s8 * src2,s32 rvoff,struct rv_jit_context * ctx,const u8 op)722 static int emit_branch_r32(const s8 *src1, const s8 *src2, s32 rvoff,
723 			   struct rv_jit_context *ctx, const u8 op)
724 {
725 	int e, s = ctx->ninsns;
726 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
727 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
728 
729 	const s8 *rs1 = bpf_get_reg32(src1, tmp1, ctx);
730 	const s8 *rs2 = bpf_get_reg32(src2, tmp2, ctx);
731 
732 	e = ctx->ninsns;
733 	/* Adjust for extra insns. */
734 	rvoff -= ninsns_rvoff(e - s);
735 
736 	if (emit_bcc(op, lo(rs1), lo(rs2), rvoff, ctx))
737 		return -1;
738 
739 	return 0;
740 }
741 
emit_call(bool fixed,u64 addr,struct rv_jit_context * ctx)742 static void emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
743 {
744 	const s8 *r0 = bpf2rv32[BPF_REG_0];
745 	const s8 *r5 = bpf2rv32[BPF_REG_5];
746 	u32 upper = ((u32)addr + (1 << 11)) >> 12;
747 	u32 lower = addr & 0xfff;
748 
749 	/* R1-R4 already in correct registers---need to push R5 to stack. */
750 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -16), ctx);
751 	emit(rv_sw(RV_REG_SP, 0, lo(r5)), ctx);
752 	emit(rv_sw(RV_REG_SP, 4, hi(r5)), ctx);
753 
754 	/* Backup TCC. */
755 	emit(rv_addi(RV_REG_TCC_SAVED, RV_REG_TCC, 0), ctx);
756 
757 	/*
758 	 * Use lui/jalr pair to jump to absolute address. Don't use emit_imm as
759 	 * the number of emitted instructions should not depend on the value of
760 	 * addr.
761 	 */
762 	emit(rv_lui(RV_REG_T1, upper), ctx);
763 	emit(rv_jalr(RV_REG_RA, RV_REG_T1, lower), ctx);
764 
765 	/* Restore TCC. */
766 	emit(rv_addi(RV_REG_TCC, RV_REG_TCC_SAVED, 0), ctx);
767 
768 	/* Set return value and restore stack. */
769 	emit(rv_addi(lo(r0), RV_REG_A0, 0), ctx);
770 	emit(rv_addi(hi(r0), RV_REG_A1, 0), ctx);
771 	emit(rv_addi(RV_REG_SP, RV_REG_SP, 16), ctx);
772 }
773 
emit_bpf_tail_call(int insn,struct rv_jit_context * ctx)774 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
775 {
776 	/*
777 	 * R1 -> &ctx
778 	 * R2 -> &array
779 	 * R3 -> index
780 	 */
781 	int tc_ninsn, off, start_insn = ctx->ninsns;
782 	const s8 *arr_reg = bpf2rv32[BPF_REG_2];
783 	const s8 *idx_reg = bpf2rv32[BPF_REG_3];
784 
785 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
786 		ctx->offset[0];
787 
788 	/* max_entries = array->map.max_entries; */
789 	off = offsetof(struct bpf_array, map.max_entries);
790 	if (is_12b_check(off, insn))
791 		return -1;
792 	emit(rv_lw(RV_REG_T1, off, lo(arr_reg)), ctx);
793 
794 	/*
795 	 * if (index >= max_entries)
796 	 *   goto out;
797 	 */
798 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
799 	emit_bcc(BPF_JGE, lo(idx_reg), RV_REG_T1, off, ctx);
800 
801 	/*
802 	 * if (--tcc < 0)
803 	 *   goto out;
804 	 */
805 	emit(rv_addi(RV_REG_TCC, RV_REG_TCC, -1), ctx);
806 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
807 	emit_bcc(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
808 
809 	/*
810 	 * prog = array->ptrs[index];
811 	 * if (!prog)
812 	 *   goto out;
813 	 */
814 	emit(rv_slli(RV_REG_T0, lo(idx_reg), 2), ctx);
815 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(arr_reg)), ctx);
816 	off = offsetof(struct bpf_array, ptrs);
817 	if (is_12b_check(off, insn))
818 		return -1;
819 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
820 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
821 	emit_bcc(BPF_JEQ, RV_REG_T0, RV_REG_ZERO, off, ctx);
822 
823 	/*
824 	 * tcc = temp_tcc;
825 	 * goto *(prog->bpf_func + 4);
826 	 */
827 	off = offsetof(struct bpf_prog, bpf_func);
828 	if (is_12b_check(off, insn))
829 		return -1;
830 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
831 	/* Epilogue jumps to *(t0 + 4). */
832 	__build_epilogue(true, ctx);
833 	return 0;
834 }
835 
emit_load_r64(const s8 * dst,const s8 * src,s16 off,struct rv_jit_context * ctx,const u8 size)836 static int emit_load_r64(const s8 *dst, const s8 *src, s16 off,
837 			 struct rv_jit_context *ctx, const u8 size)
838 {
839 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
840 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
841 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
842 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
843 
844 	emit_imm(RV_REG_T0, off, ctx);
845 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rs)), ctx);
846 
847 	switch (size) {
848 	case BPF_B:
849 		emit(rv_lbu(lo(rd), 0, RV_REG_T0), ctx);
850 		if (!ctx->prog->aux->verifier_zext)
851 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
852 		break;
853 	case BPF_H:
854 		emit(rv_lhu(lo(rd), 0, RV_REG_T0), ctx);
855 		if (!ctx->prog->aux->verifier_zext)
856 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
857 		break;
858 	case BPF_W:
859 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
860 		if (!ctx->prog->aux->verifier_zext)
861 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
862 		break;
863 	case BPF_DW:
864 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
865 		emit(rv_lw(hi(rd), 4, RV_REG_T0), ctx);
866 		break;
867 	}
868 
869 	bpf_put_reg64(dst, rd, ctx);
870 	return 0;
871 }
872 
emit_store_r64(const s8 * dst,const s8 * src,s16 off,struct rv_jit_context * ctx,const u8 size,const u8 mode)873 static int emit_store_r64(const s8 *dst, const s8 *src, s16 off,
874 			  struct rv_jit_context *ctx, const u8 size,
875 			  const u8 mode)
876 {
877 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
878 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
879 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
880 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
881 
882 	if (mode == BPF_ATOMIC && size != BPF_W)
883 		return -1;
884 
885 	emit_imm(RV_REG_T0, off, ctx);
886 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rd)), ctx);
887 
888 	switch (size) {
889 	case BPF_B:
890 		emit(rv_sb(RV_REG_T0, 0, lo(rs)), ctx);
891 		break;
892 	case BPF_H:
893 		emit(rv_sh(RV_REG_T0, 0, lo(rs)), ctx);
894 		break;
895 	case BPF_W:
896 		switch (mode) {
897 		case BPF_MEM:
898 			emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
899 			break;
900 		case BPF_ATOMIC: /* Only BPF_ADD supported */
901 			emit(rv_amoadd_w(RV_REG_ZERO, lo(rs), RV_REG_T0, 0, 0),
902 			     ctx);
903 			break;
904 		}
905 		break;
906 	case BPF_DW:
907 		emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
908 		emit(rv_sw(RV_REG_T0, 4, hi(rs)), ctx);
909 		break;
910 	}
911 
912 	return 0;
913 }
914 
emit_rev16(const s8 rd,struct rv_jit_context * ctx)915 static void emit_rev16(const s8 rd, struct rv_jit_context *ctx)
916 {
917 	emit(rv_slli(rd, rd, 16), ctx);
918 	emit(rv_slli(RV_REG_T1, rd, 8), ctx);
919 	emit(rv_srli(rd, rd, 8), ctx);
920 	emit(rv_add(RV_REG_T1, rd, RV_REG_T1), ctx);
921 	emit(rv_srli(rd, RV_REG_T1, 16), ctx);
922 }
923 
emit_rev32(const s8 rd,struct rv_jit_context * ctx)924 static void emit_rev32(const s8 rd, struct rv_jit_context *ctx)
925 {
926 	emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 0), ctx);
927 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
928 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
929 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
930 	emit(rv_srli(rd, rd, 8), ctx);
931 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
932 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
933 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
934 	emit(rv_srli(rd, rd, 8), ctx);
935 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
936 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
937 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
938 	emit(rv_srli(rd, rd, 8), ctx);
939 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
940 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
941 	emit(rv_addi(rd, RV_REG_T1, 0), ctx);
942 }
943 
emit_zext64(const s8 * dst,struct rv_jit_context * ctx)944 static void emit_zext64(const s8 *dst, struct rv_jit_context *ctx)
945 {
946 	const s8 *rd;
947 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
948 
949 	rd = bpf_get_reg64(dst, tmp1, ctx);
950 	emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
951 	bpf_put_reg64(dst, rd, ctx);
952 }
953 
bpf_jit_emit_insn(const struct bpf_insn * insn,struct rv_jit_context * ctx,bool extra_pass)954 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
955 		      bool extra_pass)
956 {
957 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
958 		BPF_CLASS(insn->code) == BPF_JMP;
959 	int s, e, rvoff, i = insn - ctx->prog->insnsi;
960 	u8 code = insn->code;
961 	s16 off = insn->off;
962 	s32 imm = insn->imm;
963 
964 	const s8 *dst = bpf2rv32[insn->dst_reg];
965 	const s8 *src = bpf2rv32[insn->src_reg];
966 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
967 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
968 
969 	switch (code) {
970 	case BPF_ALU64 | BPF_MOV | BPF_X:
971 
972 	case BPF_ALU64 | BPF_ADD | BPF_X:
973 	case BPF_ALU64 | BPF_ADD | BPF_K:
974 
975 	case BPF_ALU64 | BPF_SUB | BPF_X:
976 	case BPF_ALU64 | BPF_SUB | BPF_K:
977 
978 	case BPF_ALU64 | BPF_AND | BPF_X:
979 	case BPF_ALU64 | BPF_OR | BPF_X:
980 	case BPF_ALU64 | BPF_XOR | BPF_X:
981 
982 	case BPF_ALU64 | BPF_MUL | BPF_X:
983 	case BPF_ALU64 | BPF_MUL | BPF_K:
984 
985 	case BPF_ALU64 | BPF_LSH | BPF_X:
986 	case BPF_ALU64 | BPF_RSH | BPF_X:
987 	case BPF_ALU64 | BPF_ARSH | BPF_X:
988 		if (BPF_SRC(code) == BPF_K) {
989 			emit_imm32(tmp2, imm, ctx);
990 			src = tmp2;
991 		}
992 		emit_alu_r64(dst, src, ctx, BPF_OP(code));
993 		break;
994 
995 	case BPF_ALU64 | BPF_NEG:
996 		emit_alu_r64(dst, tmp2, ctx, BPF_OP(code));
997 		break;
998 
999 	case BPF_ALU64 | BPF_DIV | BPF_X:
1000 	case BPF_ALU64 | BPF_DIV | BPF_K:
1001 	case BPF_ALU64 | BPF_MOD | BPF_X:
1002 	case BPF_ALU64 | BPF_MOD | BPF_K:
1003 		goto notsupported;
1004 
1005 	case BPF_ALU64 | BPF_MOV | BPF_K:
1006 	case BPF_ALU64 | BPF_AND | BPF_K:
1007 	case BPF_ALU64 | BPF_OR | BPF_K:
1008 	case BPF_ALU64 | BPF_XOR | BPF_K:
1009 	case BPF_ALU64 | BPF_LSH | BPF_K:
1010 	case BPF_ALU64 | BPF_RSH | BPF_K:
1011 	case BPF_ALU64 | BPF_ARSH | BPF_K:
1012 		emit_alu_i64(dst, imm, ctx, BPF_OP(code));
1013 		break;
1014 
1015 	case BPF_ALU | BPF_MOV | BPF_X:
1016 		if (imm == 1) {
1017 			/* Special mov32 for zext. */
1018 			emit_zext64(dst, ctx);
1019 			break;
1020 		}
1021 		fallthrough;
1022 
1023 	case BPF_ALU | BPF_ADD | BPF_X:
1024 	case BPF_ALU | BPF_SUB | BPF_X:
1025 	case BPF_ALU | BPF_AND | BPF_X:
1026 	case BPF_ALU | BPF_OR | BPF_X:
1027 	case BPF_ALU | BPF_XOR | BPF_X:
1028 
1029 	case BPF_ALU | BPF_MUL | BPF_X:
1030 	case BPF_ALU | BPF_MUL | BPF_K:
1031 
1032 	case BPF_ALU | BPF_DIV | BPF_X:
1033 	case BPF_ALU | BPF_DIV | BPF_K:
1034 
1035 	case BPF_ALU | BPF_MOD | BPF_X:
1036 	case BPF_ALU | BPF_MOD | BPF_K:
1037 
1038 	case BPF_ALU | BPF_LSH | BPF_X:
1039 	case BPF_ALU | BPF_RSH | BPF_X:
1040 	case BPF_ALU | BPF_ARSH | BPF_X:
1041 		if (BPF_SRC(code) == BPF_K) {
1042 			emit_imm32(tmp2, imm, ctx);
1043 			src = tmp2;
1044 		}
1045 		emit_alu_r32(dst, src, ctx, BPF_OP(code));
1046 		break;
1047 
1048 	case BPF_ALU | BPF_MOV | BPF_K:
1049 	case BPF_ALU | BPF_ADD | BPF_K:
1050 	case BPF_ALU | BPF_SUB | BPF_K:
1051 	case BPF_ALU | BPF_AND | BPF_K:
1052 	case BPF_ALU | BPF_OR | BPF_K:
1053 	case BPF_ALU | BPF_XOR | BPF_K:
1054 	case BPF_ALU | BPF_LSH | BPF_K:
1055 	case BPF_ALU | BPF_RSH | BPF_K:
1056 	case BPF_ALU | BPF_ARSH | BPF_K:
1057 		/*
1058 		 * mul,div,mod are handled in the BPF_X case since there are
1059 		 * no RISC-V I-type equivalents.
1060 		 */
1061 		emit_alu_i32(dst, imm, ctx, BPF_OP(code));
1062 		break;
1063 
1064 	case BPF_ALU | BPF_NEG:
1065 		/*
1066 		 * src is ignored---choose tmp2 as a dummy register since it
1067 		 * is not on the stack.
1068 		 */
1069 		emit_alu_r32(dst, tmp2, ctx, BPF_OP(code));
1070 		break;
1071 
1072 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1073 	{
1074 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1075 
1076 		switch (imm) {
1077 		case 16:
1078 			emit(rv_slli(lo(rd), lo(rd), 16), ctx);
1079 			emit(rv_srli(lo(rd), lo(rd), 16), ctx);
1080 			fallthrough;
1081 		case 32:
1082 			if (!ctx->prog->aux->verifier_zext)
1083 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1084 			break;
1085 		case 64:
1086 			/* Do nothing. */
1087 			break;
1088 		default:
1089 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1090 			return -1;
1091 		}
1092 
1093 		bpf_put_reg64(dst, rd, ctx);
1094 		break;
1095 	}
1096 
1097 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1098 	{
1099 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1100 
1101 		switch (imm) {
1102 		case 16:
1103 			emit_rev16(lo(rd), ctx);
1104 			if (!ctx->prog->aux->verifier_zext)
1105 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1106 			break;
1107 		case 32:
1108 			emit_rev32(lo(rd), ctx);
1109 			if (!ctx->prog->aux->verifier_zext)
1110 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1111 			break;
1112 		case 64:
1113 			/* Swap upper and lower halves. */
1114 			emit(rv_addi(RV_REG_T0, lo(rd), 0), ctx);
1115 			emit(rv_addi(lo(rd), hi(rd), 0), ctx);
1116 			emit(rv_addi(hi(rd), RV_REG_T0, 0), ctx);
1117 
1118 			/* Swap each half. */
1119 			emit_rev32(lo(rd), ctx);
1120 			emit_rev32(hi(rd), ctx);
1121 			break;
1122 		default:
1123 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1124 			return -1;
1125 		}
1126 
1127 		bpf_put_reg64(dst, rd, ctx);
1128 		break;
1129 	}
1130 
1131 	case BPF_JMP | BPF_JA:
1132 		rvoff = rv_offset(i, off, ctx);
1133 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1134 		break;
1135 
1136 	case BPF_JMP | BPF_CALL:
1137 	{
1138 		bool fixed;
1139 		int ret;
1140 		u64 addr;
1141 
1142 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
1143 					    &fixed);
1144 		if (ret < 0)
1145 			return ret;
1146 		emit_call(fixed, addr, ctx);
1147 		break;
1148 	}
1149 
1150 	case BPF_JMP | BPF_TAIL_CALL:
1151 		if (emit_bpf_tail_call(i, ctx))
1152 			return -1;
1153 		break;
1154 
1155 	case BPF_JMP | BPF_JEQ | BPF_X:
1156 	case BPF_JMP | BPF_JEQ | BPF_K:
1157 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1158 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1159 
1160 	case BPF_JMP | BPF_JNE | BPF_X:
1161 	case BPF_JMP | BPF_JNE | BPF_K:
1162 	case BPF_JMP32 | BPF_JNE | BPF_X:
1163 	case BPF_JMP32 | BPF_JNE | BPF_K:
1164 
1165 	case BPF_JMP | BPF_JLE | BPF_X:
1166 	case BPF_JMP | BPF_JLE | BPF_K:
1167 	case BPF_JMP32 | BPF_JLE | BPF_X:
1168 	case BPF_JMP32 | BPF_JLE | BPF_K:
1169 
1170 	case BPF_JMP | BPF_JLT | BPF_X:
1171 	case BPF_JMP | BPF_JLT | BPF_K:
1172 	case BPF_JMP32 | BPF_JLT | BPF_X:
1173 	case BPF_JMP32 | BPF_JLT | BPF_K:
1174 
1175 	case BPF_JMP | BPF_JGE | BPF_X:
1176 	case BPF_JMP | BPF_JGE | BPF_K:
1177 	case BPF_JMP32 | BPF_JGE | BPF_X:
1178 	case BPF_JMP32 | BPF_JGE | BPF_K:
1179 
1180 	case BPF_JMP | BPF_JGT | BPF_X:
1181 	case BPF_JMP | BPF_JGT | BPF_K:
1182 	case BPF_JMP32 | BPF_JGT | BPF_X:
1183 	case BPF_JMP32 | BPF_JGT | BPF_K:
1184 
1185 	case BPF_JMP | BPF_JSLE | BPF_X:
1186 	case BPF_JMP | BPF_JSLE | BPF_K:
1187 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1188 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1189 
1190 	case BPF_JMP | BPF_JSLT | BPF_X:
1191 	case BPF_JMP | BPF_JSLT | BPF_K:
1192 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1193 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1194 
1195 	case BPF_JMP | BPF_JSGE | BPF_X:
1196 	case BPF_JMP | BPF_JSGE | BPF_K:
1197 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1198 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1199 
1200 	case BPF_JMP | BPF_JSGT | BPF_X:
1201 	case BPF_JMP | BPF_JSGT | BPF_K:
1202 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1203 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1204 
1205 	case BPF_JMP | BPF_JSET | BPF_X:
1206 	case BPF_JMP | BPF_JSET | BPF_K:
1207 	case BPF_JMP32 | BPF_JSET | BPF_X:
1208 	case BPF_JMP32 | BPF_JSET | BPF_K:
1209 		rvoff = rv_offset(i, off, ctx);
1210 		if (BPF_SRC(code) == BPF_K) {
1211 			s = ctx->ninsns;
1212 			emit_imm32(tmp2, imm, ctx);
1213 			src = tmp2;
1214 			e = ctx->ninsns;
1215 			rvoff -= ninsns_rvoff(e - s);
1216 		}
1217 
1218 		if (is64)
1219 			emit_branch_r64(dst, src, rvoff, ctx, BPF_OP(code));
1220 		else
1221 			emit_branch_r32(dst, src, rvoff, ctx, BPF_OP(code));
1222 		break;
1223 
1224 	case BPF_JMP | BPF_EXIT:
1225 		if (i == ctx->prog->len - 1)
1226 			break;
1227 
1228 		rvoff = epilogue_offset(ctx);
1229 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1230 		break;
1231 
1232 	case BPF_LD | BPF_IMM | BPF_DW:
1233 	{
1234 		struct bpf_insn insn1 = insn[1];
1235 		s32 imm_lo = imm;
1236 		s32 imm_hi = insn1.imm;
1237 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1238 
1239 		emit_imm64(rd, imm_hi, imm_lo, ctx);
1240 		bpf_put_reg64(dst, rd, ctx);
1241 		return 1;
1242 	}
1243 
1244 	case BPF_LDX | BPF_MEM | BPF_B:
1245 	case BPF_LDX | BPF_MEM | BPF_H:
1246 	case BPF_LDX | BPF_MEM | BPF_W:
1247 	case BPF_LDX | BPF_MEM | BPF_DW:
1248 		if (emit_load_r64(dst, src, off, ctx, BPF_SIZE(code)))
1249 			return -1;
1250 		break;
1251 
1252 	/* speculation barrier */
1253 	case BPF_ST | BPF_NOSPEC:
1254 		break;
1255 
1256 	case BPF_ST | BPF_MEM | BPF_B:
1257 	case BPF_ST | BPF_MEM | BPF_H:
1258 	case BPF_ST | BPF_MEM | BPF_W:
1259 	case BPF_ST | BPF_MEM | BPF_DW:
1260 
1261 	case BPF_STX | BPF_MEM | BPF_B:
1262 	case BPF_STX | BPF_MEM | BPF_H:
1263 	case BPF_STX | BPF_MEM | BPF_W:
1264 	case BPF_STX | BPF_MEM | BPF_DW:
1265 		if (BPF_CLASS(code) == BPF_ST) {
1266 			emit_imm32(tmp2, imm, ctx);
1267 			src = tmp2;
1268 		}
1269 
1270 		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
1271 				   BPF_MODE(code)))
1272 			return -1;
1273 		break;
1274 
1275 	case BPF_STX | BPF_ATOMIC | BPF_W:
1276 		if (insn->imm != BPF_ADD) {
1277 			pr_info_once(
1278 				"bpf-jit: not supported: atomic operation %02x ***\n",
1279 				insn->imm);
1280 			return -EFAULT;
1281 		}
1282 
1283 		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
1284 				   BPF_MODE(code)))
1285 			return -1;
1286 		break;
1287 
1288 	/* No hardware support for 8-byte atomics in RV32. */
1289 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1290 		/* Fallthrough. */
1291 
1292 notsupported:
1293 		pr_info_once("bpf-jit: not supported: opcode %02x ***\n", code);
1294 		return -EFAULT;
1295 
1296 	default:
1297 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1298 		return -EINVAL;
1299 	}
1300 
1301 	return 0;
1302 }
1303 
bpf_jit_build_prologue(struct rv_jit_context * ctx)1304 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1305 {
1306 	const s8 *fp = bpf2rv32[BPF_REG_FP];
1307 	const s8 *r1 = bpf2rv32[BPF_REG_1];
1308 	int stack_adjust = 0;
1309 	int bpf_stack_adjust =
1310 		round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);
1311 
1312 	/* Make space for callee-saved registers. */
1313 	stack_adjust += NR_SAVED_REGISTERS * sizeof(u32);
1314 	/* Make space for BPF registers on stack. */
1315 	stack_adjust += BPF_JIT_SCRATCH_REGS * sizeof(u32);
1316 	/* Make space for BPF stack. */
1317 	stack_adjust += bpf_stack_adjust;
1318 	/* Round up for stack alignment. */
1319 	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
1320 
1321 	/*
1322 	 * The first instruction sets the tail-call-counter (TCC) register.
1323 	 * This instruction is skipped by tail calls.
1324 	 */
1325 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1326 
1327 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);
1328 
1329 	/* Save callee-save registers. */
1330 	emit(rv_sw(RV_REG_SP, stack_adjust - 4, RV_REG_RA), ctx);
1331 	emit(rv_sw(RV_REG_SP, stack_adjust - 8, RV_REG_FP), ctx);
1332 	emit(rv_sw(RV_REG_SP, stack_adjust - 12, RV_REG_S1), ctx);
1333 	emit(rv_sw(RV_REG_SP, stack_adjust - 16, RV_REG_S2), ctx);
1334 	emit(rv_sw(RV_REG_SP, stack_adjust - 20, RV_REG_S3), ctx);
1335 	emit(rv_sw(RV_REG_SP, stack_adjust - 24, RV_REG_S4), ctx);
1336 	emit(rv_sw(RV_REG_SP, stack_adjust - 28, RV_REG_S5), ctx);
1337 	emit(rv_sw(RV_REG_SP, stack_adjust - 32, RV_REG_S6), ctx);
1338 	emit(rv_sw(RV_REG_SP, stack_adjust - 36, RV_REG_S7), ctx);
1339 
1340 	/* Set fp: used as the base address for stacked BPF registers. */
1341 	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);
1342 
1343 	/* Set up BPF frame pointer. */
1344 	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
1345 	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);
1346 
1347 	/* Set up BPF context pointer. */
1348 	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
1349 	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);
1350 
1351 	ctx->stack_size = stack_adjust;
1352 }
1353 
bpf_jit_build_epilogue(struct rv_jit_context * ctx)1354 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1355 {
1356 	__build_epilogue(false, ctx);
1357 }
1358