// SPDX-License-Identifier: GPL-2.0-only
/*
 * bpf_jit_comp.c: BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	if (len == 1)
		*ptr = bytes;
	else if (len == 2)
		*(u16 *)ptr = bytes;
	else {
		*(u32 *)ptr = bytes;
		barrier();
	}
	return ptr + len;
}

#define EMIT(bytes, len) \
	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)		EMIT(b1, 1)
#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)	EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)

#define EMIT1_off32(b1, off) \
	do { EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)

static bool is_imm8(int value)
{
	return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
	return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
	return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC)								 \
	do {										 \
		if (DST != SRC)								 \
			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
	} while (0)

static int bpf_size_to_x86_bytes(int bpf_size)
{
	if (bpf_size == BPF_W)
		return 4;
	else if (bpf_size == BPF_H)
		return 2;
	else if (bpf_size == BPF_B)
		return 1;
	else if (bpf_size == BPF_DW)
		return 4; /* imm32 */
	else
		return 0;
}

/*
 * List of x86 cond jumps opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F

/* Pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)
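
/*
 * AUX_REG maps to R11 (see reg2hex[] below): a caller-saved register that no
 * BPF register uses, so the JIT is free to clobber it as scratch space when
 * emitting the div/mod, mul and shift-by-rcx sequences further down.
 */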

/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * Also x86-64 register R9 is unused. x86-64 register R10 is
 * used for blinding (if enabled).
 */
static const int reg2hex[] = {
	[BPF_REG_0] = 0,  /* RAX */
	[BPF_REG_1] = 7,  /* RDI */
	[BPF_REG_2] = 6,  /* RSI */
	[BPF_REG_3] = 2,  /* RDX */
	[BPF_REG_4] = 1,  /* RCX */
	[BPF_REG_5] = 0,  /* R8 */
	[BPF_REG_6] = 3,  /* RBX callee saved */
	[BPF_REG_7] = 5,  /* R13 callee saved */
	[BPF_REG_8] = 6,  /* R14 callee saved */
	[BPF_REG_9] = 7,  /* R15 callee saved */
	[BPF_REG_FP] = 5, /* RBP readonly */
	[BPF_REG_AX] = 2, /* R10 temp register */
	[AUX_REG] = 3,    /* R11 temp register */
};

/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
	return (1 << reg) & (BIT(BPF_REG_5) |
			     BIT(AUX_REG) |
			     BIT(BPF_REG_7) |
			     BIT(BPF_REG_8) |
			     BIT(BPF_REG_9) |
			     BIT(BPF_REG_AX));
}

static bool is_axreg(u32 reg)
{
	return reg == BPF_REG_0;
}

/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
static u8 add_1mod(u8 byte, u32 reg)
{
	if (is_ereg(reg))
		byte |= 1;
	return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	if (is_ereg(r1))
		byte |= 1;
	if (is_ereg(r2))
		byte |= 4;
	return byte;
}

/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	return byte + reg2hex[dst_reg];
}

/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}

static void jit_fill_hole(void *area, unsigned int size)
{
	/* Fill whole space with INT3 instructions */
	memset(area, 0xcc, size);
}

struct jit_context {
	int cleanup_addr; /* Epilogue code offset */
};

/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE	128
#define BPF_INSN_SAFETY		64

#define PROLOGUE_SIZE		20

/*
 * Emit x86-64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
	u8 *prog = *pprog;
	int cnt = 0;

	EMIT1(0x55);             /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
	/* sub rsp, rounded_stack_depth */
	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
	EMIT1(0x53);             /* push rbx */
	EMIT2(0x41, 0x55);       /* push r13 */
	EMIT2(0x41, 0x56);       /* push r14 */
	EMIT2(0x41, 0x57);       /* push r15 */
	if (!ebpf_from_cbpf) {
		/* zero init tail_call_cnt */
		EMIT2(0x6a, 0x00);
		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
	}
	*pprog = prog;
}
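
/*
 * Size accounting for PROLOGUE_SIZE, derived from the emits above:
 * push rbp (1) + mov rbp, rsp (3) + sub rsp, imm32 (7) + push rbx (1) +
 * push r13/r14/r15 (3 x 2) + push 0 for tail_call_cnt (2) = 20 bytes,
 * which is what the BUILD_BUG_ON() in the !ebpf_from_cbpf path verifies.
 */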

/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	u8 *prog = *pprog;
	int label1, label2, label3;
	int cnt = 0;

	/*
	 * rdi - pointer to ctx
	 * rsi - pointer to bpf_array
	 * rdx - index in bpf_array
	 */

	/*
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
	label1 = cnt;

	/*
	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
	EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JA, OFFSET2);                   /* ja out */
	label2 = cnt;
	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
	EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */

	/* prog = array->ptrs[index]; */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
		    offsetof(struct bpf_array, ptrs));

	/*
	 * if (prog == NULL)
	 *	goto out;
	 */
	EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JE, OFFSET3);                   /* je out */
	label3 = cnt;

	/* goto *(prog->bpf_func + prologue_size); */
	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

	/*
	 * Wow we're ready to jump into next BPF program
	 * rdi == ctx (1st arg)
	 * rax == prog->bpf_func + prologue_size
	 */
	RETPOLINE_RAX_BPF_JIT();

	/* out: */
	BUILD_BUG_ON(cnt - label1 != OFFSET1);
	BUILD_BUG_ON(cnt - label2 != OFFSET2);
	BUILD_BUG_ON(cnt - label3 != OFFSET3);
	*pprog = prog;
}
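
/*
 * Note: OFFSET1/2/3 above are the hand-counted byte distances from each
 * conditional jump to the "out" label (including the retpoline thunk size);
 * the BUILD_BUG_ON()s recompute them from cnt, so any change in the emitted
 * byte sequence breaks the build instead of silently producing wrong jumps.
 */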

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
			   u32 dst_reg, const u32 imm32)
{
	u8 *prog = *pprog;
	u8 b1, b2, b3;
	int cnt = 0;

	/*
	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
	 * (which zero-extends imm32) to save 2 bytes.
	 */
	if (sign_propagate && (s32)imm32 < 0) {
		/* 'mov %rax, imm32' sign extends imm32 */
		b1 = add_1mod(0x48, dst_reg);
		b2 = 0xC7;
		b3 = 0xC0;
		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
		goto done;
	}

	/*
	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
	 * to save 3 bytes.
	 */
	if (imm32 == 0) {
		if (is_ereg(dst_reg))
			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
		b2 = 0x31; /* xor */
		b3 = 0xC0;
		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
		goto done;
	}

	/* mov %eax, imm32 */
	if (is_ereg(dst_reg))
		EMIT1(add_1mod(0x40, dst_reg));
	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
	*pprog = prog;
}

static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
			   const u32 imm32_hi, const u32 imm32_lo)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/*
		 * For emitting plain u32, where the sign bit must not be
		 * propagated, LLVM tends to load imm64 over mov32
		 * directly, so save a couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
	} else {
		/* movabsq %rax, imm64 */
		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
		EMIT(imm32_lo, 4);
		EMIT(imm32_hi, 4);
	}

	*pprog = prog;
}

static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is64) {
		/* mov dst, src */
		EMIT_mov(dst_reg, src_reg);
	} else {
		/* mov32 dst, src */
		if (is_ereg(dst_reg) || is_ereg(src_reg))
			EMIT1(add_2mod(0x40, dst_reg, src_reg));
		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
	}

	*pprog = prog;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth,
		      bpf_prog_was_classic(bpf_prog));
	addrs[0] = prog - temp;

	for (i = 1; i <= insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		int ilen;
		u8 *func;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

		case BPF_ALU64 | BPF_MOV | BPF_X:
		case BPF_ALU | BPF_MOV | BPF_X:
			emit_mov_reg(&prog,
				     BPF_CLASS(insn->code) == BPF_ALU64,
				     dst_reg, src_reg);
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;
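
			/*
			 * Illustrative example: BPF_ALU64 | BPF_NEG with
			 * dst_reg == BPF_REG_0 emits 0x48 0xF7 0xD8,
			 * i.e. "neg rax".
			 */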

		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			/*
			 * b3 holds 'normal' opcode, b2 short form only valid
			 * in case dst is eax/rax.
			 */
			switch (BPF_OP(insn->code)) {
			case BPF_ADD:
				b3 = 0xC0;
				b2 = 0x05;
				break;
			case BPF_SUB:
				b3 = 0xE8;
				b2 = 0x2D;
				break;
			case BPF_AND:
				b3 = 0xE0;
				b2 = 0x25;
				break;
			case BPF_OR:
				b3 = 0xC8;
				b2 = 0x0D;
				break;
			case BPF_XOR:
				b3 = 0xF0;
				b2 = 0x35;
				break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else if (is_axreg(dst_reg))
				EMIT1_off32(b2, imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
		case BPF_ALU | BPF_MOV | BPF_K:
			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
				       dst_reg, imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/*
			 * xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;
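
			/*
			 * Note on the DIV/MOD sequence above: x86 div takes
			 * its dividend implicitly in rdx:rax and leaves the
			 * quotient in rax and the remainder in rdx, hence the
			 * save/restore of rax/rdx around it, the zeroing of
			 * edx beforehand and the copy of the wanted half
			 * into r11.
			 */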

		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
		{
			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

			if (dst_reg != BPF_REG_0)
				EMIT1(0x50); /* push rax */
			if (dst_reg != BPF_REG_3)
				EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
			else
				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

			if (is64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			if (dst_reg != BPF_REG_3)
				EMIT1(0x5A); /* pop rdx */
			if (dst_reg != BPF_REG_0) {
				/* mov dst_reg, rax */
				EMIT_mov(dst_reg, BPF_REG_0);
				EMIT1(0x58); /* pop rax */
			}
			break;
		}
			/* Shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}

			if (imm32 == 1)
				EMIT2(0xD1, add_1reg(b3, dst_reg));
			else
				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* Check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* Emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* Emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/*
				 * Emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;
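
			/*
			 * The BPF_END cases above rely on x86-64 being
			 * little-endian: BPF_FROM_LE only needs to zero-extend
			 * (or do nothing for 64 bit), while BPF_FROM_BE has to
			 * actually byte-swap via ror/bswap.
			 */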

			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* Emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* We have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* Emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* Emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* Emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* Emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/*
			 * If insn->off == 0 we can save one extra byte, but
			 * special case of x86 R13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* Emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported BPF func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;
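
			/*
			 * Note on BPF_CALL above: 0xE8 takes a rel32
			 * displacement measured from the end of the 5-byte
			 * call instruction; addrs[i] holds the offset just
			 * past insn i in the image, so 'image + addrs[i]' is
			 * exactly that end address.
			 */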

			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
		case BPF_JMP32 | BPF_JEQ | BPF_X:
		case BPF_JMP32 | BPF_JNE | BPF_X:
		case BPF_JMP32 | BPF_JGT | BPF_X:
		case BPF_JMP32 | BPF_JLT | BPF_X:
		case BPF_JMP32 | BPF_JGE | BPF_X:
		case BPF_JMP32 | BPF_JLE | BPF_X:
		case BPF_JMP32 | BPF_JSGT | BPF_X:
		case BPF_JMP32 | BPF_JSLT | BPF_X:
		case BPF_JMP32 | BPF_JSGE | BPF_X:
		case BPF_JMP32 | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
		case BPF_JMP32 | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
		case BPF_JMP32 | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;
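
			/*
			 * For all conditional jumps, BPF_JMP compares full
			 * 64-bit registers (REX.W via the 0x48 prefix) while
			 * BPF_JMP32 compares the low 32 bits (optional 0x40
			 * prefix only when an extended register needs it).
			 * E.g. a BPF_JMP cmp of r1 against r2 emits
			 * 0x48 0x39 0xF7, i.e. "cmp rdi, rsi".
			 */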

		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
		case BPF_JMP32 | BPF_JEQ | BPF_K:
		case BPF_JMP32 | BPF_JNE | BPF_K:
		case BPF_JMP32 | BPF_JGT | BPF_K:
		case BPF_JMP32 | BPF_JLT | BPF_K:
		case BPF_JMP32 | BPF_JGE | BPF_K:
		case BPF_JMP32 | BPF_JLE | BPF_K:
		case BPF_JMP32 | BPF_JSGT | BPF_K:
		case BPF_JMP32 | BPF_JSLT | BPF_K:
		case BPF_JMP32 | BPF_JSGE | BPF_K:
		case BPF_JMP32 | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* Convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* Signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* Signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* Signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* Signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence GCC warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			if (insn->off == -1)
				/* -1 jmp instructions will always jump
				 * backwards two bytes. Explicitly handling
				 * this case avoids wasting too many passes
				 * when there are long sequences of replaced
				 * dead code.
				 */
				jmp_offset = -2;
			else
				jmp_offset = addrs[i + insn->off] - addrs[i];

			if (!jmp_offset)
				/* Optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* Update cleanup_addr */
			ctx->cleanup_addr = proglen;
			if (!bpf_prog_was_classic(bpf_prog))
				EMIT1(0x5B); /* get rid of tail_call_cnt */
			EMIT2(0x41, 0x5F); /* pop r15 */
			EMIT2(0x41, 0x5E); /* pop r14 */
			EMIT2(0x41, 0x5D); /* pop r13 */
			EMIT1(0x5B); /* pop rbx */
			EMIT1(0xC9); /* leave */
			EMIT1(0xC3); /* ret */
			break;

		default:
			/*
			 * By design x86-64 JIT should support all BPF instructions.
			 * This error will be seen if new instruction was added
			 * to the interpreter, but not to the JIT, or if there is
			 * junk in bpf_prog.
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}
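
/*
 * JIT state kept across bpf_int_jit_compile() invocations for programs with
 * subprograms: the first invocation stashes addrs[]/image/header here when
 * prog->is_func is set (see the bottom of bpf_int_jit_compile()), so that a
 * later extra pass can re-emit the image with final call addresses without
 * redoing the converging passes.
 */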
struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/*
	 * Before first pass, make a rough estimation of addrs[]:
	 * each BPF instruction is translated to less than 64 bytes.
	 */
	for (proglen = 0, i = 0; i <= prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/*
	 * JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large BPF programs
	 * may converge on the last pass. In such case do one more
	 * pass to emit the final image.
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
		if (image)
			bpf_prog_fill_jited_linfo(prog, addrs + 1);
out_addrs:
		kfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}