// SPDX-License-Identifier: GPL-2.0-only
/*
 * bpf_jit_comp.c: BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	if (len == 1)
		*ptr = bytes;
	else if (len == 2)
		*(u16 *)ptr = bytes;
	else {
		*(u32 *)ptr = bytes;
		barrier();
	}
	return ptr + len;
}

#define EMIT(bytes, len) \
	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)		EMIT(b1, 1)
#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)	EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)

#define EMIT1_off32(b1, off) \
	do { EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
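/*
 * Illustrative note (not emitted code): emit_code() stores the packed
 * 'bytes' value little-endian, so the low byte of an EMIT* argument list
 * is the first opcode byte written. For example, EMIT3(0x48, 0x89, 0xE5)
 * appends the byte sequence 48 89 e5, i.e. 'mov rbp, rsp'.
 */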
static bool is_imm8(int value)
{
	return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
	return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
	return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC)								 \
	do {										 \
		if (DST != SRC)								 \
			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
	} while (0)

static int bpf_size_to_x86_bytes(int bpf_size)
{
	if (bpf_size == BPF_W)
		return 4;
	else if (bpf_size == BPF_H)
		return 2;
	else if (bpf_size == BPF_B)
		return 1;
	else if (bpf_size == BPF_DW)
		return 4; /* imm32 */
	else
		return 0;
}

/*
 * List of x86 cond jumps opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F

/* Pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * Also x86-64 register R9 is unused. x86-64 register R10 is
 * used for blinding (if enabled).
 */
static const int reg2hex[] = {
	[BPF_REG_0] = 0,  /* RAX */
	[BPF_REG_1] = 7,  /* RDI */
	[BPF_REG_2] = 6,  /* RSI */
	[BPF_REG_3] = 2,  /* RDX */
	[BPF_REG_4] = 1,  /* RCX */
	[BPF_REG_5] = 0,  /* R8 */
	[BPF_REG_6] = 3,  /* RBX callee saved */
	[BPF_REG_7] = 5,  /* R13 callee saved */
	[BPF_REG_8] = 6,  /* R14 callee saved */
	[BPF_REG_9] = 7,  /* R15 callee saved */
	[BPF_REG_FP] = 5, /* RBP readonly */
	[BPF_REG_AX] = 2, /* R10 temp register */
	[AUX_REG] = 3,    /* R11 temp register */
};

/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
	return (1 << reg) & (BIT(BPF_REG_5) |
			     BIT(AUX_REG) |
			     BIT(BPF_REG_7) |
			     BIT(BPF_REG_8) |
			     BIT(BPF_REG_9) |
			     BIT(BPF_REG_AX));
}

static bool is_axreg(u32 reg)
{
	return reg == BPF_REG_0;
}

/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
static u8 add_1mod(u8 byte, u32 reg)
{
	if (is_ereg(reg))
		byte |= 1;
	return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	if (is_ereg(r1))
		byte |= 1;
	if (is_ereg(r2))
		byte |= 4;
	return byte;
}

/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	return byte + reg2hex[dst_reg];
}

/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}

static void jit_fill_hole(void *area, unsigned int size)
{
	/* Fill whole space with INT3 instructions */
	memset(area, 0xcc, size);
}

struct jit_context {
	int cleanup_addr; /* Epilogue code offset */
};

/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE	128
#define BPF_INSN_SAFETY		64

#define PROLOGUE_SIZE		20

/*
 * Emit x86-64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
	u8 *prog = *pprog;
	int cnt = 0;

	EMIT1(0x55);             /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
	/* sub rsp, rounded_stack_depth */
	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
	EMIT1(0x53);             /* push rbx */
	EMIT2(0x41, 0x55);       /* push r13 */
	EMIT2(0x41, 0x56);       /* push r14 */
	EMIT2(0x41, 0x57);       /* push r15 */
	if (!ebpf_from_cbpf) {
		/* zero init tail_call_cnt */
		EMIT2(0x6a, 0x00);
		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
	}
	*pprog = prog;
}
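/*
 * Illustrative byte accounting for the prologue above (counted by hand,
 * the BUILD_BUG_ON() is the authoritative check): push rbp (1) +
 * mov rbp,rsp (3) + sub rsp,imm32 (7) + push rbx (1) +
 * push r13/r14/r15 (3 * 2) + push 0 (2) = 20 bytes = PROLOGUE_SIZE,
 * which is exactly what emit_bpf_tail_call() skips over below.
 */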
/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	u8 *prog = *pprog;
	int label1, label2, label3;
	int cnt = 0;

	/*
	 * rdi - pointer to ctx
	 * rsi - pointer to bpf_array
	 * rdx - index in bpf_array
	 */

	/*
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
	label1 = cnt;

	/*
	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
	EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JA, OFFSET2);                   /* ja out */
	label2 = cnt;
	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
	EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp - 548], eax */

	/* prog = array->ptrs[index]; */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
		    offsetof(struct bpf_array, ptrs));

	/*
	 * if (prog == NULL)
	 *	goto out;
	 */
	EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JE, OFFSET3);                   /* je out */
	label3 = cnt;

	/* goto *(prog->bpf_func + prologue_size); */
	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

	/*
	 * Wow we're ready to jump into next BPF program
	 * rdi == ctx (1st arg)
	 * rax == prog->bpf_func + prologue_size
	 */
	RETPOLINE_RAX_BPF_JIT();

	/* out: */
	BUILD_BUG_ON(cnt - label1 != OFFSET1);
	BUILD_BUG_ON(cnt - label2 != OFFSET2);
	BUILD_BUG_ON(cnt - label3 != OFFSET3);
	*pprog = prog;
}
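/*
 * Illustrative note on the immediate moves emitted below, using rax as the
 * example destination (other registers differ only in the REX bits and the
 * register field of the opcode/ModRM byte):
 *
 *   48 C7 C0 imm32   mov rax, imm32   sign-extends imm32 to 64 bit
 *   31 C0            xor eax, eax     shortest way to load zero
 *   B8 imm32         mov eax, imm32   zero-extends imm32 to 64 bit
 */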
static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
			   u32 dst_reg, const u32 imm32)
{
	u8 *prog = *pprog;
	u8 b1, b2, b3;
	int cnt = 0;

	/*
	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
	 * (which zero-extends imm32) to save 2 bytes.
	 */
	if (sign_propagate && (s32)imm32 < 0) {
		/* 'mov %rax, imm32' sign extends imm32 */
		b1 = add_1mod(0x48, dst_reg);
		b2 = 0xC7;
		b3 = 0xC0;
		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
		goto done;
	}

	/*
	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
	 * to save 3 bytes.
	 */
	if (imm32 == 0) {
		if (is_ereg(dst_reg))
			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
		b2 = 0x31; /* xor */
		b3 = 0xC0;
		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
		goto done;
	}

	/* mov %eax, imm32 */
	if (is_ereg(dst_reg))
		EMIT1(add_1mod(0x40, dst_reg));
	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
	*pprog = prog;
}

static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
			   const u32 imm32_hi, const u32 imm32_lo)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/*
		 * For emitting plain u32, where sign bit must not be
		 * propagated LLVM tends to load imm64 over mov32
		 * directly, so save couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
	} else {
		/* movabsq %rax, imm64 */
		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
		EMIT(imm32_lo, 4);
		EMIT(imm32_hi, 4);
	}

	*pprog = prog;
}

static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is64) {
		/* mov dst, src */
		EMIT_mov(dst_reg, src_reg);
	} else {
		/* mov32 dst, src */
		if (is_ereg(dst_reg) || is_ereg(src_reg))
			EMIT1(add_2mod(0x40, dst_reg, src_reg));
		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
	}

	*pprog = prog;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth,
		      bpf_prog_was_classic(bpf_prog));

	for (i = 0; i < insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		int ilen;
		u8 *func;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

		case BPF_ALU64 | BPF_MOV | BPF_X:
		case BPF_ALU | BPF_MOV | BPF_X:
			emit_mov_reg(&prog,
				     BPF_CLASS(insn->code) == BPF_ALU64,
				     dst_reg, src_reg);
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;
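			/*
			 * Descriptive note on the recurring REX prefix
			 * pattern in the cases above and below: a 0x48-based
			 * prefix (REX.W) is emitted for 64-bit operations, a
			 * 0x40-based prefix only when a 32-bit operation
			 * touches r8..r15, and no prefix otherwise;
			 * add_1mod()/add_2mod() fill in the REX.B/REX.R bits
			 * for the extended registers.
			 */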
		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			/*
			 * b3 holds 'normal' opcode, b2 short form only valid
			 * in case dst is eax/rax.
			 */
			switch (BPF_OP(insn->code)) {
			case BPF_ADD:
				b3 = 0xC0;
				b2 = 0x05;
				break;
			case BPF_SUB:
				b3 = 0xE8;
				b2 = 0x2D;
				break;
			case BPF_AND:
				b3 = 0xE0;
				b2 = 0x25;
				break;
			case BPF_OR:
				b3 = 0xC8;
				b2 = 0x0D;
				break;
			case BPF_XOR:
				b3 = 0xF0;
				b2 = 0x35;
				break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else if (is_axreg(dst_reg))
				EMIT1_off32(b2, imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
		case BPF_ALU | BPF_MOV | BPF_K:
			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
				       dst_reg, imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/*
			 * xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;
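			/*
			 * Descriptive note: like the div/mod case above, the
			 * one-operand mul below implicitly uses rdx:rax, so
			 * rax and rdx are preserved around it unless dst_reg
			 * already maps onto one of them.
			 */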
		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
		{
			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

			if (dst_reg != BPF_REG_0)
				EMIT1(0x50); /* push rax */
			if (dst_reg != BPF_REG_3)
				EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
			else
				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

			if (is64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			if (dst_reg != BPF_REG_3)
				EMIT1(0x5A); /* pop rdx */
			if (dst_reg != BPF_REG_0) {
				/* mov dst_reg, rax */
				EMIT_mov(dst_reg, BPF_REG_0);
				EMIT1(0x58); /* pop rax */
			}
			break;
		}
			/* Shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}

			if (imm32 == 1)
				EMIT2(0xD1, add_1reg(b3, dst_reg));
			else
				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* Check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* Emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* Emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/*
				 * Emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;
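			/*
			 * Descriptive note for the memory accesses below: the
			 * is_imm8(insn->off) checks pick the ModRM addressing
			 * form, 0x40 | reg for [base + disp8] versus
			 * 0x80 | reg for [base + disp32].
			 */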
			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* Emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* We have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* Emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* Emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* Emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* Emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/*
			 * If insn->off == 0 we can save one extra byte, but
			 * special case of x86 R13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* Emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported BPF func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;
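			/*
			 * Descriptive note: BPF_JMP compares the full 64-bit
			 * registers (REX.W prefixed cmp/test), while BPF_JMP32
			 * only compares the low 32 bits. The branch itself is
			 * emitted at emit_cond_jmp, as a 2-byte short jcc when
			 * the target fits in s8 and as a 6-byte 0x0F-prefixed
			 * jcc otherwise.
			 */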
			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
		case BPF_JMP32 | BPF_JEQ | BPF_X:
		case BPF_JMP32 | BPF_JNE | BPF_X:
		case BPF_JMP32 | BPF_JGT | BPF_X:
		case BPF_JMP32 | BPF_JLT | BPF_X:
		case BPF_JMP32 | BPF_JGE | BPF_X:
		case BPF_JMP32 | BPF_JLE | BPF_X:
		case BPF_JMP32 | BPF_JSGT | BPF_X:
		case BPF_JMP32 | BPF_JSLT | BPF_X:
		case BPF_JMP32 | BPF_JSGE | BPF_X:
		case BPF_JMP32 | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
		case BPF_JMP32 | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
		case BPF_JMP32 | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;
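			/*
			 * Descriptive note: the compare against an immediate
			 * below uses the sign-extended 0x83 /7 ib form when
			 * imm32 fits in s8 and the longer 0x81 /7 id form
			 * otherwise.
			 */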
		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
		case BPF_JMP32 | BPF_JEQ | BPF_K:
		case BPF_JMP32 | BPF_JNE | BPF_K:
		case BPF_JMP32 | BPF_JGT | BPF_K:
		case BPF_JMP32 | BPF_JLT | BPF_K:
		case BPF_JMP32 | BPF_JGE | BPF_K:
		case BPF_JMP32 | BPF_JLE | BPF_K:
		case BPF_JMP32 | BPF_JSGT | BPF_K:
		case BPF_JMP32 | BPF_JSLT | BPF_K:
		case BPF_JMP32 | BPF_JSGE | BPF_K:
		case BPF_JMP32 | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* Convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* Signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* Signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* Signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* Signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence GCC warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			if (insn->off == -1)
				/* -1 jmp instructions will always jump
				 * backwards two bytes. Explicitly handling
				 * this case avoids wasting too many passes
				 * when there are long sequences of replaced
				 * dead code.
				 */
				jmp_offset = -2;
			else
				jmp_offset = addrs[i + insn->off] - addrs[i];

			if (!jmp_offset)
				/* Optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* Update cleanup_addr */
			ctx->cleanup_addr = proglen;
			if (!bpf_prog_was_classic(bpf_prog))
				EMIT1(0x5B); /* get rid of tail_call_cnt */
			EMIT2(0x41, 0x5F);	/* pop r15 */
			EMIT2(0x41, 0x5E);	/* pop r14 */
			EMIT2(0x41, 0x5D);	/* pop r13 */
			EMIT1(0x5B);		/* pop rbx */
			EMIT1(0xC9);		/* leave */
			EMIT1(0xC3);		/* ret */
			break;

		default:
			/*
			 * By design x86-64 JIT should support all BPF instructions.
			 * This error will be seen if new instruction was added
			 * to the interpreter, but not to the JIT, or if there is
			 * junk in bpf_prog.
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}

struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kmalloc_array(prog->len, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/*
	 * Before first pass, make a rough estimation of addrs[]
	 * each BPF instruction is translated to less than 64 bytes
	 */
	for (proglen = 0, i = 0; i < prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/*
	 * JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large BPF programs
	 * may converge on the last pass. In such case do one more
	 * pass to emit the final image.
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
		if (image)
			bpf_prog_fill_jited_linfo(prog, addrs);
out_addrs:
		kfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}