1 #include <stdio.h> 2 #include <stdint.h> 3 #include <stdlib.h> 4 #include <string.h> 5 6 typedef void (*testfn)(void); 7 8 typedef struct { 9 uint64_t q0, q1, q2, q3; 10 } __attribute__((aligned(32))) v4di; 11 12 typedef struct { 13 uint64_t mm[8]; 14 v4di ymm[16]; 15 uint64_t r[16]; 16 uint64_t flags; 17 uint32_t ff; 18 uint64_t pad; 19 v4di mem[4]; 20 v4di mem0[4]; 21 } reg_state; 22 23 typedef struct { 24 int n; 25 testfn fn; 26 const char *s; 27 reg_state *init; 28 } TestDef; 29 30 reg_state initI; 31 reg_state initF16; 32 reg_state initF32; 33 reg_state initF64; 34 35 static void dump_ymm(const char *name, int n, const v4di *r, int ff) 36 { 37 printf("%s%d = %016lx %016lx %016lx %016lx\n", 38 name, n, r->q3, r->q2, r->q1, r->q0); 39 if (ff == 64) { 40 double v[4]; 41 memcpy(v, r, sizeof(v)); 42 printf(" %16g %16g %16g %16g\n", 43 v[3], v[2], v[1], v[0]); 44 } else if (ff == 32) { 45 float v[8]; 46 memcpy(v, r, sizeof(v)); 47 printf(" %8g %8g %8g %8g %8g %8g %8g %8g\n", 48 v[7], v[6], v[5], v[4], v[3], v[2], v[1], v[0]); 49 } 50 } 51 52 static void dump_regs(reg_state *s) 53 { 54 int i; 55 56 for (i = 0; i < 16; i++) { 57 dump_ymm("ymm", i, &s->ymm[i], 0); 58 } 59 for (i = 0; i < 4; i++) { 60 dump_ymm("mem", i, &s->mem0[i], 0); 61 } 62 } 63 64 static void compare_state(const reg_state *a, const reg_state *b) 65 { 66 int i; 67 for (i = 0; i < 8; i++) { 68 if (a->mm[i] != b->mm[i]) { 69 printf("MM%d = %016lx\n", i, b->mm[i]); 70 } 71 } 72 for (i = 0; i < 16; i++) { 73 if (a->r[i] != b->r[i]) { 74 printf("r%d = %016lx\n", i, b->r[i]); 75 } 76 } 77 for (i = 0; i < 16; i++) { 78 if (memcmp(&a->ymm[i], &b->ymm[i], 32)) { 79 dump_ymm("ymm", i, &b->ymm[i], a->ff); 80 } 81 } 82 for (i = 0; i < 4; i++) { 83 if (memcmp(&a->mem0[i], &a->mem[i], 32)) { 84 dump_ymm("mem", i, &a->mem[i], a->ff); 85 } 86 } 87 if (a->flags != b->flags) { 88 printf("FLAGS = %016lx\n", b->flags); 89 } 90 } 91 92 #define LOADMM(r, o) "movq " #r ", " #o "[%0]\n\t" 93 #define LOADYMM(r, o) "vmovdqa " #r ", " #o "[%0]\n\t" 94 #define STOREMM(r, o) "movq " #o "[%1], " #r "\n\t" 95 #define STOREYMM(r, o) "vmovdqa " #o "[%1], " #r "\n\t" 96 #define MMREG(F) \ 97 F(mm0, 0x00) \ 98 F(mm1, 0x08) \ 99 F(mm2, 0x10) \ 100 F(mm3, 0x18) \ 101 F(mm4, 0x20) \ 102 F(mm5, 0x28) \ 103 F(mm6, 0x30) \ 104 F(mm7, 0x38) 105 #define YMMREG(F) \ 106 F(ymm0, 0x040) \ 107 F(ymm1, 0x060) \ 108 F(ymm2, 0x080) \ 109 F(ymm3, 0x0a0) \ 110 F(ymm4, 0x0c0) \ 111 F(ymm5, 0x0e0) \ 112 F(ymm6, 0x100) \ 113 F(ymm7, 0x120) \ 114 F(ymm8, 0x140) \ 115 F(ymm9, 0x160) \ 116 F(ymm10, 0x180) \ 117 F(ymm11, 0x1a0) \ 118 F(ymm12, 0x1c0) \ 119 F(ymm13, 0x1e0) \ 120 F(ymm14, 0x200) \ 121 F(ymm15, 0x220) 122 #define LOADREG(r, o) "mov " #r ", " #o "[rax]\n\t" 123 #define STOREREG(r, o) "mov " #o "[rax], " #r "\n\t" 124 #define REG(F) \ 125 F(rbx, 0x248) \ 126 F(rcx, 0x250) \ 127 F(rdx, 0x258) \ 128 F(rsi, 0x260) \ 129 F(rdi, 0x268) \ 130 F(r8, 0x280) \ 131 F(r9, 0x288) \ 132 F(r10, 0x290) \ 133 F(r11, 0x298) \ 134 F(r12, 0x2a0) \ 135 F(r13, 0x2a8) \ 136 F(r14, 0x2b0) \ 137 F(r15, 0x2b8) \ 138 139 static void run_test(const TestDef *t) 140 { 141 reg_state result; 142 reg_state *init = t->init; 143 memcpy(init->mem, init->mem0, sizeof(init->mem)); 144 printf("%5d %s\n", t->n, t->s); 145 asm volatile( 146 MMREG(LOADMM) 147 YMMREG(LOADYMM) 148 "sub rsp, 128\n\t" 149 "push rax\n\t" 150 "push rbx\n\t" 151 "push rcx\n\t" 152 "push rdx\n\t" 153 "push %1\n\t" 154 "push %2\n\t" 155 "mov rax, %0\n\t" 156 "pushf\n\t" 157 "pop rbx\n\t" 158 "shr rbx, 8\n\t" 159 "shl rbx, 8\n\t" 160 "mov rcx, 0x2c0[rax]\n\t" 161 "and rcx, 0xff\n\t" 162 "or rbx, rcx\n\t" 163 "push rbx\n\t" 164 "popf\n\t" 165 REG(LOADREG) 166 "mov rax, 0x240[rax]\n\t" 167 "call [rsp]\n\t" 168 "mov [rsp], rax\n\t" 169 "mov rax, 8[rsp]\n\t" 170 REG(STOREREG) 171 "mov rbx, [rsp]\n\t" 172 "mov 0x240[rax], rbx\n\t" 173 "mov rbx, 0\n\t" 174 "mov 0x270[rax], rbx\n\t" 175 "mov 0x278[rax], rbx\n\t" 176 "pushf\n\t" 177 "pop rbx\n\t" 178 "and rbx, 0xff\n\t" 179 "mov 0x2c0[rax], rbx\n\t" 180 "add rsp, 16\n\t" 181 "pop rdx\n\t" 182 "pop rcx\n\t" 183 "pop rbx\n\t" 184 "pop rax\n\t" 185 "add rsp, 128\n\t" 186 MMREG(STOREMM) 187 YMMREG(STOREYMM) 188 : : "r"(init), "r"(&result), "r"(t->fn) 189 : "memory", "cc", 190 "rsi", "rdi", 191 "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", 192 "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7", 193 "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", 194 "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", 195 "ymm12", "ymm13", "ymm14", "ymm15" 196 ); 197 compare_state(init, &result); 198 } 199 200 #define TEST(n, cmd, type) \ 201 static void __attribute__((naked)) test_##n(void) \ 202 { \ 203 asm volatile(cmd); \ 204 asm volatile("ret"); \ 205 } 206 #include "test-avx.h" 207 208 209 static const TestDef test_table[] = { 210 #define TEST(n, cmd, type) {n, test_##n, cmd, &init##type}, 211 #include "test-avx.h" 212 {-1, NULL, "", NULL} 213 }; 214 215 static void run_all(void) 216 { 217 const TestDef *t; 218 for (t = test_table; t->fn; t++) { 219 run_test(t); 220 } 221 } 222 223 #define ARRAY_LEN(x) (sizeof(x) / sizeof(x[0])) 224 225 uint16_t val_f16[] = { 0x4000, 0xbc00, 0x44cd, 0x3a66, 0x4200, 0x7a1a, 0x4780, 0x4826 }; 226 float val_f32[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5, 8.3}; 227 double val_f64[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5}; 228 v4di val_i64[] = { 229 {0x3d6b3b6a9e4118f2lu, 0x355ae76d2774d78clu, 230 0xac3ff76c4daa4b28lu, 0xe7fabd204cb54083lu}, 231 {0xd851c54a56bf1f29lu, 0x4a84d1d50bf4c4fflu, 232 0x56621e553d52b56clu, 0xd0069553da8f584alu}, 233 {0x5826475e2c5fd799lu, 0xfd32edc01243f5e9lu, 234 0x738ba2c66d3fe126lu, 0x5707219c6e6c26b4lu}, 235 }; 236 237 v4di deadbeef = {0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull, 238 0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull}; 239 /* &gather_mem[0x10] is 512 bytes from the base; indices must be >=-64, <64 240 * to account for scaling by 8 */ 241 v4di indexq = {0x000000000000001full, 0x000000000000003dull, 242 0xffffffffffffffffull, 0xffffffffffffffdfull}; 243 v4di indexd = {0x00000002ffffffcdull, 0xfffffff500000010ull, 244 0x0000003afffffff0ull, 0x000000000000000eull}; 245 246 v4di gather_mem[0x20]; 247 _Static_assert(sizeof(gather_mem) == 1024); 248 249 void init_f16reg(v4di *r) 250 { 251 memset(r, 0, sizeof(*r)); 252 memcpy(r, val_f16, sizeof(val_f16)); 253 } 254 255 void init_f32reg(v4di *r) 256 { 257 static int n; 258 float v[8]; 259 int i; 260 for (i = 0; i < 8; i++) { 261 v[i] = val_f32[n++]; 262 if (n == ARRAY_LEN(val_f32)) { 263 n = 0; 264 } 265 } 266 memcpy(r, v, sizeof(*r)); 267 } 268 269 void init_f64reg(v4di *r) 270 { 271 static int n; 272 double v[4]; 273 int i; 274 for (i = 0; i < 4; i++) { 275 v[i] = val_f64[n++]; 276 if (n == ARRAY_LEN(val_f64)) { 277 n = 0; 278 } 279 } 280 memcpy(r, v, sizeof(*r)); 281 } 282 283 void init_intreg(v4di *r) 284 { 285 static uint64_t mask; 286 static int n; 287 288 r->q0 = val_i64[n].q0 ^ mask; 289 r->q1 = val_i64[n].q1 ^ mask; 290 r->q2 = val_i64[n].q2 ^ mask; 291 r->q3 = val_i64[n].q3 ^ mask; 292 n++; 293 if (n == ARRAY_LEN(val_i64)) { 294 n = 0; 295 mask *= 0x104C11DB7; 296 } 297 } 298 299 static void init_all(reg_state *s) 300 { 301 int i; 302 303 s->r[3] = (uint64_t)&s->mem[0]; /* rdx */ 304 s->r[4] = (uint64_t)&gather_mem[ARRAY_LEN(gather_mem) / 2]; /* rsi */ 305 s->r[5] = (uint64_t)&s->mem[2]; /* rdi */ 306 s->flags = 2; 307 for (i = 0; i < 16; i++) { 308 s->ymm[i] = deadbeef; 309 } 310 s->ymm[13] = indexd; 311 s->ymm[14] = indexq; 312 for (i = 0; i < 4; i++) { 313 s->mem0[i] = deadbeef; 314 } 315 } 316 317 int main(int argc, char *argv[]) 318 { 319 int i; 320 321 init_all(&initI); 322 init_intreg(&initI.ymm[0]); 323 init_intreg(&initI.ymm[9]); 324 init_intreg(&initI.ymm[10]); 325 init_intreg(&initI.ymm[11]); 326 init_intreg(&initI.ymm[12]); 327 init_intreg(&initI.mem0[1]); 328 printf("Int:\n"); 329 dump_regs(&initI); 330 331 init_all(&initF16); 332 init_f16reg(&initF16.ymm[0]); 333 init_f16reg(&initF16.ymm[9]); 334 init_f16reg(&initF16.ymm[10]); 335 init_f16reg(&initF16.ymm[11]); 336 init_f16reg(&initF16.ymm[12]); 337 init_f16reg(&initF16.mem0[1]); 338 initF16.ff = 16; 339 printf("F16:\n"); 340 dump_regs(&initF16); 341 342 init_all(&initF32); 343 init_f32reg(&initF32.ymm[0]); 344 init_f32reg(&initF32.ymm[9]); 345 init_f32reg(&initF32.ymm[10]); 346 init_f32reg(&initF32.ymm[11]); 347 init_f32reg(&initF32.ymm[12]); 348 init_f32reg(&initF32.mem0[1]); 349 initF32.ff = 32; 350 printf("F32:\n"); 351 dump_regs(&initF32); 352 353 init_all(&initF64); 354 init_f64reg(&initF64.ymm[0]); 355 init_f64reg(&initF64.ymm[9]); 356 init_f64reg(&initF64.ymm[10]); 357 init_f64reg(&initF64.ymm[11]); 358 init_f64reg(&initF64.ymm[12]); 359 init_f64reg(&initF64.mem0[1]); 360 initF64.ff = 64; 361 printf("F64:\n"); 362 dump_regs(&initF64); 363 364 for (i = 0; i < ARRAY_LEN(gather_mem); i++) { 365 init_intreg(&gather_mem[i]); 366 } 367 368 if (argc > 1) { 369 int n = atoi(argv[1]); 370 run_test(&test_table[n]); 371 } else { 372 run_all(); 373 } 374 return 0; 375 } 376