#include <inttypes.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
5
/* Entry point of one generated test (see the TEST() macro). */
typedef void (*testfn)(void);

/*
 * One 256-bit YMM value as four 64-bit lanes; q0 is stored at the lowest
 * address.  The 32-byte alignment is required because run_test() moves
 * these with vmovdqa, which needs aligned memory operands.
 */
typedef struct {
    uint64_t q0, q1, q2, q3;
} __attribute__((aligned(32))) v4di;

/*
 * Register and memory state before or after one test.
 *
 * run_test()'s inline asm addresses these fields by hard-coded byte offsets
 * (MMREG, YMMREG, REG and the 0x240/0x2c0 literals), so the offsets the asm
 * depends on are pinned by the static assertions below.
 */
typedef struct {
    uint64_t mm[8];     /* 0x000: mm0..mm7 */
    v4di ymm[16];       /* 0x040: ymm0..ymm15 */
    uint64_t r[16];     /* 0x240: rax, rbx, rcx, rdx, rsi, rdi, two slots
                         * that are never loaded and read back as 0
                         * (0x270/0x278), then r8..r15 */
    uint64_t flags;     /* 0x2c0: only the low 8 bits of RFLAGS are used */
    uint32_t ff;        /* float format for dumps: 0, 16, 32 or 64 */
    uint64_t pad;
    v4di mem[4];        /* memory the test reaches via rdx (mem[0]) and
                         * rdi (mem[2]); reset from mem0 before each test */
    v4di mem0[4];       /* pristine initial contents of mem[] */
} reg_state;

_Static_assert(offsetof(reg_state, mm) == 0x000,
               "reg_state.mm offset is hard-coded in run_test's asm");
_Static_assert(offsetof(reg_state, ymm) == 0x040,
               "reg_state.ymm offset is hard-coded in run_test's asm");
_Static_assert(offsetof(reg_state, r) == 0x240,
               "reg_state.r offset is hard-coded in run_test's asm");
_Static_assert(offsetof(reg_state, flags) == 0x2c0,
               "reg_state.flags offset is hard-coded in run_test's asm");

/* Description of one generated test. */
typedef struct {
    int n;              /* test number printed in the log */
    testfn fn;          /* naked function executing the code under test */
    const char *s;      /* instruction text, printed in the log */
    reg_state *init;    /* initial state the test runs with */
} TestDef;

/* Initial states, one per operand flavour (picked by TEST's type field). */
reg_state initI;
reg_state initF16;
reg_state initF32;
reg_state initF64;
34
dump_ymm(const char * name,int n,const v4di * r,int ff)35 static void dump_ymm(const char *name, int n, const v4di *r, int ff)
36 {
37 printf("%s%d = %016lx %016lx %016lx %016lx\n",
38 name, n, r->q3, r->q2, r->q1, r->q0);
39 if (ff == 64) {
40 double v[4];
41 memcpy(v, r, sizeof(v));
42 printf(" %16g %16g %16g %16g\n",
43 v[3], v[2], v[1], v[0]);
44 } else if (ff == 32) {
45 float v[8];
46 memcpy(v, r, sizeof(v));
47 printf(" %8g %8g %8g %8g %8g %8g %8g %8g\n",
48 v[7], v[6], v[5], v[4], v[3], v[2], v[1], v[0]);
49 }
50 }
51
/*
 * Print an initial state in raw hex: the sixteen YMM registers, then the
 * four pristine memory slots (mem0).
 */
static void dump_regs(reg_state *s)
{
    const v4di *regs = s->ymm;
    const v4di *slots = s->mem0;
    int idx;

    for (idx = 0; idx < 16; ++idx) {
        dump_ymm("ymm", idx, regs + idx, 0);
    }
    for (idx = 0; idx < 4; ++idx) {
        dump_ymm("mem", idx, slots + idx, 0);
    }
}
63
compare_state(const reg_state * a,const reg_state * b)64 static void compare_state(const reg_state *a, const reg_state *b)
65 {
66 int i;
67 for (i = 0; i < 8; i++) {
68 if (a->mm[i] != b->mm[i]) {
69 printf("MM%d = %016lx\n", i, b->mm[i]);
70 }
71 }
72 for (i = 0; i < 16; i++) {
73 if (a->r[i] != b->r[i]) {
74 printf("r%d = %016lx\n", i, b->r[i]);
75 }
76 }
77 for (i = 0; i < 16; i++) {
78 if (memcmp(&a->ymm[i], &b->ymm[i], 32)) {
79 dump_ymm("ymm", i, &b->ymm[i], a->ff);
80 }
81 }
82 for (i = 0; i < 4; i++) {
83 if (memcmp(&a->mem0[i], &a->mem[i], 32)) {
84 dump_ymm("mem", i, &a->mem[i], a->ff);
85 }
86 }
87 if (a->flags != b->flags) {
88 printf("FLAGS = %016lx\n", b->flags);
89 }
90 }
91
/*
 * Instruction templates for run_test()'s Intel-syntax inline asm.
 * LOADMM/LOADYMM read a register from the initial state (asm operand %0,
 * the init pointer); STOREMM/STOREYMM write it to the result state
 * (asm operand %1, &result).  o is the field's byte offset in reg_state.
 */
#define LOADMM(r, o) "movq " #r ", " #o "[%0]\n\t"
#define LOADYMM(r, o) "vmovdqa " #r ", " #o "[%0]\n\t"
#define STOREMM(r, o) "movq " #o "[%1], " #r "\n\t"
#define STOREYMM(r, o) "vmovdqa " #o "[%1], " #r "\n\t"
/* X-macro: MMX registers and their offsets in reg_state.mm[] (stride 8). */
#define MMREG(F) \
F(mm0, 0x00) \
F(mm1, 0x08) \
F(mm2, 0x10) \
F(mm3, 0x18) \
F(mm4, 0x20) \
F(mm5, 0x28) \
F(mm6, 0x30) \
F(mm7, 0x38)
/* X-macro: YMM registers and their offsets in reg_state.ymm[]
 * (base 0x40, stride 32). */
#define YMMREG(F) \
F(ymm0, 0x040) \
F(ymm1, 0x060) \
F(ymm2, 0x080) \
F(ymm3, 0x0a0) \
F(ymm4, 0x0c0) \
F(ymm5, 0x0e0) \
F(ymm6, 0x100) \
F(ymm7, 0x120) \
F(ymm8, 0x140) \
F(ymm9, 0x160) \
F(ymm10, 0x180) \
F(ymm11, 0x1a0) \
F(ymm12, 0x1c0) \
F(ymm13, 0x1e0) \
F(ymm14, 0x200) \
F(ymm15, 0x220)
/* General-register load/store relative to rax.  In run_test(), rax holds the
 * init pointer while loading and the result pointer while storing. */
#define LOADREG(r, o) "mov " #r ", " #o "[rax]\n\t"
#define STOREREG(r, o) "mov " #o "[rax], " #r "\n\t"
/*
 * X-macro: general registers and their offsets in reg_state.r[] (base 0x240).
 * rax (0x240) is handled separately by run_test() because it holds the state
 * pointer.  Offsets 0x270/0x278 (r[6], r[7]) are deliberately absent.
 */
#define REG(F) \
F(rbx, 0x248) \
F(rcx, 0x250) \
F(rdx, 0x258) \
F(rsi, 0x260) \
F(rdi, 0x268) \
F(r8, 0x280) \
F(r9, 0x288) \
F(r10, 0x290) \
F(r11, 0x298) \
F(r12, 0x2a0) \
F(r13, 0x2a8) \
F(r14, 0x2b0) \
F(r15, 0x2b8) \
138
/*
 * Run one test.  Loads t->init into the MMX, YMM and general registers and
 * into the low 8 bits of RFLAGS, calls t->fn, and captures the resulting
 * registers and low RFLAGS byte into a local reg_state.  It then prints every
 * difference via compare_state().
 *
 * The asm is written in Intel syntax.  NOTE(review): this presumably relies
 * on the file being built with -masm=intel; confirm in the build rules.
 * Only result's mm, ymm, r and flags are written; its other fields are left
 * uninitialized and are not read by compare_state().
 */
static void run_test(const TestDef *t)
{
    reg_state result;
    reg_state *init = t->init;
    /* Reset the memory slots the test may modify through rdx/rdi. */
    memcpy(init->mem, init->mem0, sizeof(init->mem));
    printf("%5d %s\n", t->n, t->s);
    asm volatile(
        /* Load MMX and YMM registers from *init (%0). */
        MMREG(LOADMM)
        YMMREG(LOADYMM)
        /* Move below the area under rsp before pushing (presumably to avoid
         * clobbering the 128-byte SysV red zone of the enclosing frame). */
        "sub rsp, 128\n\t"
        /* rax..rdx are not in the clobber list and may hold the operands,
         * so save them; then push &result (%1) and the test function (%2). */
        "push rax\n\t"
        "push rbx\n\t"
        "push rcx\n\t"
        "push rdx\n\t"
        "push %1\n\t"
        "push %2\n\t"
        "mov rax, %0\n\t"
        /* Replace the low 8 bits of RFLAGS with init->flags & 0xff. */
        "pushf\n\t"
        "pop rbx\n\t"
        "shr rbx, 8\n\t"
        "shl rbx, 8\n\t"
        "mov rcx, 0x2c0[rax]\n\t"
        "and rcx, 0xff\n\t"
        "or rbx, rcx\n\t"
        "push rbx\n\t"
        "popf\n\t"
        /* Load the general registers; rax goes last since it is the base. */
        REG(LOADREG)
        "mov rax, 0x240[rax]\n\t"
        /* [rsp] is the test function pushed above. */
        "call [rsp]\n\t"
        /* Park the test's rax in the function slot, reload &result. */
        "mov [rsp], rax\n\t"
        "mov rax, 8[rsp]\n\t"
        REG(STOREREG)
        "mov rbx, [rsp]\n\t"
        "mov 0x240[rax], rbx\n\t"
        /* r[6] and r[7] were never loaded; record them as 0. */
        "mov rbx, 0\n\t"
        "mov 0x270[rax], rbx\n\t"
        "mov 0x278[rax], rbx\n\t"
        /* Capture the low RFLAGS byte (the movs above leave flags intact). */
        "pushf\n\t"
        "pop rbx\n\t"
        "and rbx, 0xff\n\t"
        "mov 0x2c0[rax], rbx\n\t"
        /* Drop the fn/result slots and restore the saved registers. */
        "add rsp, 16\n\t"
        "pop rdx\n\t"
        "pop rcx\n\t"
        "pop rbx\n\t"
        "pop rax\n\t"
        "add rsp, 128\n\t"
        /* Store MMX and YMM registers into result (%1). */
        MMREG(STOREMM)
        YMMREG(STOREYMM)
        : : "r"(init), "r"(&result), "r"(t->fn)
        : "memory", "cc",
        "rsi", "rdi",
        "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
        "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7",
        "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5",
        "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11",
        "ymm12", "ymm13", "ymm14", "ymm15"
    );
    compare_state(init, &result);
}
199
200 #define TEST(n, cmd, type) \
201 static void __attribute__((naked)) test_##n(void) \
202 { \
203 asm volatile(cmd); \
204 asm volatile("ret"); \
205 }
206 #include "test-avx.h"
207
208
/*
 * Second expansion of test-avx.h: one TestDef per TEST(n, cmd, type) entry.
 * Each pairs test_<n> with its instruction text and the initial state
 * init<type>, i.e. one of the init* states declared near the top.
 * The final entry has a NULL fn and terminates the table for run_all().
 */
static const TestDef test_table[] = {
#define TEST(n, cmd, type) {n, test_##n, cmd, &init##type},
#include "test-avx.h"
{-1, NULL, "", NULL}
};
214
run_all(void)215 static void run_all(void)
216 {
217 const TestDef *t;
218 for (t = test_table; t->fn; t++) {
219 run_test(t);
220 }
221 }
222
/* Number of elements in a true array (not a pointer).  The argument is
 * parenthesized so expressions such as ARRAY_LEN(s->a) expand safely. */
#define ARRAY_LEN(x) (sizeof(x) / sizeof((x)[0]))
224
225 uint16_t val_f16[] = { 0x4000, 0xbc00, 0x44cd, 0x3a66, 0x4200, 0x7a1a, 0x4780, 0x4826 };
226 float val_f32[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5, 8.3};
227 double val_f64[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5};
228 v4di val_i64[] = {
229 {0x3d6b3b6a9e4118f2lu, 0x355ae76d2774d78clu,
230 0xac3ff76c4daa4b28lu, 0xe7fabd204cb54083lu},
231 {0xd851c54a56bf1f29lu, 0x4a84d1d50bf4c4fflu,
232 0x56621e553d52b56clu, 0xd0069553da8f584alu},
233 {0x5826475e2c5fd799lu, 0xfd32edc01243f5e9lu,
234 0x738ba2c66d3fe126lu, 0x5707219c6e6c26b4lu},
235 };
236
237 v4di deadbeef = {0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull,
238 0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull};
239 /* &gather_mem[0x10] is 512 bytes from the base; indices must be >=-64, <64
240 * to account for scaling by 8 */
241 v4di indexq = {0x000000000000001full, 0x000000000000003dull,
242 0xffffffffffffffffull, 0xffffffffffffffdfull};
243 v4di indexd = {0x00000002ffffffcdull, 0xfffffff500000010ull,
244 0x0000003afffffff0ull, 0x000000000000000eull};
245
246 v4di gather_mem[0x20];
247 _Static_assert(sizeof(gather_mem) == 1024);
248
/* Fill *r with all of val_f16 (eight half-precision values) starting at its
 * lowest byte; every remaining byte is zero. */
void init_f16reg(v4di *r)
{
    *r = (v4di){0};
    memcpy(r, val_f16, sizeof val_f16);
}
254
/* Fill *r with the next eight floats from val_f32, wrapping round-robin.
 * The cursor persists across calls, so successive registers differ. */
void init_f32reg(v4di *r)
{
    static int pos;
    float lanes[8];
    int lane;

    for (lane = 0; lane < 8; lane++) {
        lanes[lane] = val_f32[pos];
        pos = (pos + 1 == ARRAY_LEN(val_f32)) ? 0 : pos + 1;
    }
    memcpy(r, lanes, sizeof(*r));
}
268
/* Fill *r with the next four doubles from val_f64, wrapping round-robin.
 * The cursor persists across calls. */
void init_f64reg(v4di *r)
{
    static int next;
    double lanes[4];

    for (int k = 0; k < 4; k++) {
        lanes[k] = val_f64[next];
        if (++next == ARRAY_LEN(val_f64)) {
            next = 0;
        }
    }
    memcpy(r, lanes, sizeof(*r));
}
282
/*
 * Fill *r with the next val_i64 entry XORed with a per-round mask.  Entries
 * are consumed round-robin.  Each time the table wraps, the mask advances,
 * so later rounds (e.g. most of gather_mem) differ from earlier ones.
 */
void init_intreg(v4di *r)
{
    static uint64_t mask;
    static int n;

    r->q0 = val_i64[n].q0 ^ mask;
    r->q1 = val_i64[n].q1 ^ mask;
    r->q2 = val_i64[n].q2 ^ mask;
    r->q3 = val_i64[n].q3 ^ mask;
    n++;
    if (n == ARRAY_LEN(val_i64)) {
        n = 0;
        /*
         * mask starts at 0 so the first round uses the raw table, but
         * "mask *= K" on 0 stays 0 forever.  Seed it with 1 on the first
         * wrap.  K is odd, so the product is always odd and never returns to 0.
         */
        mask = (mask ? mask : 1) * 0x104C11DB7ull;
    }
}
298
/*
 * Common setup for every initial state:
 * - rdx and rdi point at this state's memory slots mem[0] and mem[2];
 * - rsi points at the middle of gather_mem;
 * - flags = 2;
 * - the deadbeef pattern fills every YMM register and memory slot;
 * - ymm13/ymm14 hold the gather index vectors.
 */
static void init_all(reg_state *s)
{
    int k;

    s->r[3] = (uint64_t)&s->mem[0];                              /* rdx */
    s->r[4] = (uint64_t)&gather_mem[ARRAY_LEN(gather_mem) / 2];  /* rsi */
    s->r[5] = (uint64_t)&s->mem[2];                              /* rdi */
    s->flags = 2;

    for (k = 0; k < 16; k++) {
        if (k == 13) {
            s->ymm[k] = indexd;
        } else if (k == 14) {
            s->ymm[k] = indexq;
        } else {
            s->ymm[k] = deadbeef;
        }
    }
    for (k = 0; k < 4; k++) {
        s->mem0[k] = deadbeef;
    }
}
316
main(int argc,char * argv[])317 int main(int argc, char *argv[])
318 {
319 int i;
320
321 init_all(&initI);
322 init_intreg(&initI.ymm[0]);
323 init_intreg(&initI.ymm[9]);
324 init_intreg(&initI.ymm[10]);
325 init_intreg(&initI.ymm[11]);
326 init_intreg(&initI.ymm[12]);
327 init_intreg(&initI.mem0[1]);
328 printf("Int:\n");
329 dump_regs(&initI);
330
331 init_all(&initF16);
332 init_f16reg(&initF16.ymm[0]);
333 init_f16reg(&initF16.ymm[9]);
334 init_f16reg(&initF16.ymm[10]);
335 init_f16reg(&initF16.ymm[11]);
336 init_f16reg(&initF16.ymm[12]);
337 init_f16reg(&initF16.mem0[1]);
338 initF16.ff = 16;
339 printf("F16:\n");
340 dump_regs(&initF16);
341
342 init_all(&initF32);
343 init_f32reg(&initF32.ymm[0]);
344 init_f32reg(&initF32.ymm[9]);
345 init_f32reg(&initF32.ymm[10]);
346 init_f32reg(&initF32.ymm[11]);
347 init_f32reg(&initF32.ymm[12]);
348 init_f32reg(&initF32.mem0[1]);
349 initF32.ff = 32;
350 printf("F32:\n");
351 dump_regs(&initF32);
352
353 init_all(&initF64);
354 init_f64reg(&initF64.ymm[0]);
355 init_f64reg(&initF64.ymm[9]);
356 init_f64reg(&initF64.ymm[10]);
357 init_f64reg(&initF64.ymm[11]);
358 init_f64reg(&initF64.ymm[12]);
359 init_f64reg(&initF64.mem0[1]);
360 initF64.ff = 64;
361 printf("F64:\n");
362 dump_regs(&initF64);
363
364 for (i = 0; i < ARRAY_LEN(gather_mem); i++) {
365 init_intreg(&gather_mem[i]);
366 }
367
368 if (argc > 1) {
369 int n = atoi(argv[1]);
370 run_test(&test_table[n]);
371 } else {
372 run_all();
373 }
374 return 0;
375 }
376