xref: /openbmc/qemu/tests/tcg/i386/test-avx.c (revision 4921d0a7)
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 
6 typedef void (*testfn)(void);
7 
8 typedef struct {
9     uint64_t q0, q1, q2, q3;
10 } __attribute__((aligned(32))) v4di;
11 
12 typedef struct {
13     uint64_t mm[8];
14     v4di ymm[16];
15     uint64_t r[16];
16     uint64_t flags;
17     uint32_t ff;
18     uint64_t pad;
19     v4di mem[4];
20     v4di mem0[4];
21 } reg_state;
22 
23 typedef struct {
24     int n;
25     testfn fn;
26     const char *s;
27     reg_state *init;
28 } TestDef;
29 
30 reg_state initI;
31 reg_state initF16;
32 reg_state initF32;
33 reg_state initF64;
34 
35 static void dump_ymm(const char *name, int n, const v4di *r, int ff)
36 {
37     printf("%s%d = %016lx %016lx %016lx %016lx\n",
38            name, n, r->q3, r->q2, r->q1, r->q0);
39     if (ff == 64) {
40         double v[4];
41         memcpy(v, r, sizeof(v));
42         printf("        %16g %16g %16g %16g\n",
43                 v[3], v[2], v[1], v[0]);
44     } else if (ff == 32) {
45         float v[8];
46         memcpy(v, r, sizeof(v));
47         printf(" %8g %8g %8g %8g %8g %8g %8g %8g\n",
48                 v[7], v[6], v[5], v[4], v[3], v[2], v[1], v[0]);
49     }
50 }
51 
52 static void dump_regs(reg_state *s)
53 {
54     int i;
55 
56     for (i = 0; i < 16; i++) {
57         dump_ymm("ymm", i, &s->ymm[i], 0);
58     }
59     for (i = 0; i < 4; i++) {
60         dump_ymm("mem", i, &s->mem0[i], 0);
61     }
62 }
63 
64 static void compare_state(const reg_state *a, const reg_state *b)
65 {
66     int i;
67     for (i = 0; i < 8; i++) {
68         if (a->mm[i] != b->mm[i]) {
69             printf("MM%d = %016lx\n", i, b->mm[i]);
70         }
71     }
72     for (i = 0; i < 16; i++) {
73         if (a->r[i] != b->r[i]) {
74             printf("r%d = %016lx\n", i, b->r[i]);
75         }
76     }
77     for (i = 0; i < 16; i++) {
78         if (memcmp(&a->ymm[i], &b->ymm[i], 32)) {
79             dump_ymm("ymm", i, &b->ymm[i], a->ff);
80         }
81     }
82     for (i = 0; i < 4; i++) {
83         if (memcmp(&a->mem0[i], &a->mem[i], 32)) {
84             dump_ymm("mem", i, &a->mem[i], a->ff);
85         }
86     }
87     if (a->flags != b->flags) {
88         printf("FLAGS = %016lx\n", b->flags);
89     }
90 }
91 
92 #define LOADMM(r, o) "movq " #r ", " #o "[%0]\n\t"
93 #define LOADYMM(r, o) "vmovdqa " #r ", " #o "[%0]\n\t"
94 #define STOREMM(r, o) "movq " #o "[%1], " #r "\n\t"
95 #define STOREYMM(r, o) "vmovdqa " #o "[%1], " #r "\n\t"
96 #define MMREG(F) \
97     F(mm0, 0x00) \
98     F(mm1, 0x08) \
99     F(mm2, 0x10) \
100     F(mm3, 0x18) \
101     F(mm4, 0x20) \
102     F(mm5, 0x28) \
103     F(mm6, 0x30) \
104     F(mm7, 0x38)
105 #define YMMREG(F) \
106     F(ymm0, 0x040) \
107     F(ymm1, 0x060) \
108     F(ymm2, 0x080) \
109     F(ymm3, 0x0a0) \
110     F(ymm4, 0x0c0) \
111     F(ymm5, 0x0e0) \
112     F(ymm6, 0x100) \
113     F(ymm7, 0x120) \
114     F(ymm8, 0x140) \
115     F(ymm9, 0x160) \
116     F(ymm10, 0x180) \
117     F(ymm11, 0x1a0) \
118     F(ymm12, 0x1c0) \
119     F(ymm13, 0x1e0) \
120     F(ymm14, 0x200) \
121     F(ymm15, 0x220)
122 #define LOADREG(r, o) "mov " #r ", " #o "[rax]\n\t"
123 #define STOREREG(r, o) "mov " #o "[rax], " #r "\n\t"
124 #define REG(F) \
125     F(rbx, 0x248) \
126     F(rcx, 0x250) \
127     F(rdx, 0x258) \
128     F(rsi, 0x260) \
129     F(rdi, 0x268) \
130     F(r8, 0x280) \
131     F(r9, 0x288) \
132     F(r10, 0x290) \
133     F(r11, 0x298) \
134     F(r12, 0x2a0) \
135     F(r13, 0x2a8) \
136     F(r14, 0x2b0) \
137     F(r15, 0x2b8) \
138 
139 static void run_test(const TestDef *t)
140 {
141     reg_state result;
142     reg_state *init = t->init;
143     memcpy(init->mem, init->mem0, sizeof(init->mem));
144     printf("%5d %s\n", t->n, t->s);
145     asm volatile(
146             MMREG(LOADMM)
147             YMMREG(LOADYMM)
148             "sub rsp, 128\n\t"
149             "push rax\n\t"
150             "push rbx\n\t"
151             "push rcx\n\t"
152             "push rdx\n\t"
153             "push %1\n\t"
154             "push %2\n\t"
155             "mov rax, %0\n\t"
156             "pushf\n\t"
157             "pop rbx\n\t"
158             "shr rbx, 8\n\t"
159             "shl rbx, 8\n\t"
160             "mov rcx, 0x2c0[rax]\n\t"
161             "and rcx, 0xff\n\t"
162             "or rbx, rcx\n\t"
163             "push rbx\n\t"
164             "popf\n\t"
165             REG(LOADREG)
166             "mov rax, 0x240[rax]\n\t"
167             "call [rsp]\n\t"
168             "mov [rsp], rax\n\t"
169             "mov rax, 8[rsp]\n\t"
170             REG(STOREREG)
171             "mov rbx, [rsp]\n\t"
172             "mov 0x240[rax], rbx\n\t"
173             "mov rbx, 0\n\t"
174             "mov 0x270[rax], rbx\n\t"
175             "mov 0x278[rax], rbx\n\t"
176             "pushf\n\t"
177             "pop rbx\n\t"
178             "and rbx, 0xff\n\t"
179             "mov 0x2c0[rax], rbx\n\t"
180             "add rsp, 16\n\t"
181             "pop rdx\n\t"
182             "pop rcx\n\t"
183             "pop rbx\n\t"
184             "pop rax\n\t"
185             "add rsp, 128\n\t"
186             MMREG(STOREMM)
187             YMMREG(STOREYMM)
188             : : "r"(init), "r"(&result), "r"(t->fn)
189             : "memory", "cc",
190             "rsi", "rdi",
191             "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
192             "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7",
193             "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5",
194             "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11",
195             "ymm12", "ymm13", "ymm14", "ymm15"
196             );
197     compare_state(init, &result);
198 }
199 
200 #define TEST(n, cmd, type) \
201 static void __attribute__((naked)) test_##n(void) \
202 { \
203     asm volatile(cmd); \
204     asm volatile("ret"); \
205 }
206 #include "test-avx.h"
207 
208 
209 static const TestDef test_table[] = {
210 #define TEST(n, cmd, type) {n, test_##n, cmd, &init##type},
211 #include "test-avx.h"
212     {-1, NULL, "", NULL}
213 };
214 
215 static void run_all(void)
216 {
217     const TestDef *t;
218     for (t = test_table; t->fn; t++) {
219         run_test(t);
220     }
221 }
222 
223 #define ARRAY_LEN(x) (sizeof(x) / sizeof(x[0]))
224 
225 uint16_t val_f16[] = { 0x4000, 0xbc00, 0x44cd, 0x3a66, 0x4200, 0x7a1a, 0x4780, 0x4826 };
226 float val_f32[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5, 8.3};
227 double val_f64[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5};
228 v4di val_i64[] = {
229     {0x3d6b3b6a9e4118f2lu, 0x355ae76d2774d78clu,
230      0xac3ff76c4daa4b28lu, 0xe7fabd204cb54083lu},
231     {0xd851c54a56bf1f29lu, 0x4a84d1d50bf4c4fflu,
232      0x56621e553d52b56clu, 0xd0069553da8f584alu},
233     {0x5826475e2c5fd799lu, 0xfd32edc01243f5e9lu,
234      0x738ba2c66d3fe126lu, 0x5707219c6e6c26b4lu},
235 };
236 
237 v4di deadbeef = {0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull,
238                  0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull};
239 /* &gather_mem[0x10] is 512 bytes from the base; indices must be >=-64, <64
240  * to account for scaling by 8 */
241 v4di indexq = {0x000000000000001full, 0x000000000000003dull,
242                0xffffffffffffffffull, 0xffffffffffffffdfull};
243 v4di indexd = {0x00000002ffffffcdull, 0xfffffff500000010ull,
244                0x0000003afffffff0ull, 0x000000000000000eull};
245 
246 v4di gather_mem[0x20];
247 _Static_assert(sizeof(gather_mem) == 1024);
248 
249 void init_f16reg(v4di *r)
250 {
251     memset(r, 0, sizeof(*r));
252     memcpy(r, val_f16, sizeof(val_f16));
253 }
254 
255 void init_f32reg(v4di *r)
256 {
257     static int n;
258     float v[8];
259     int i;
260     for (i = 0; i < 8; i++) {
261         v[i] = val_f32[n++];
262         if (n == ARRAY_LEN(val_f32)) {
263             n = 0;
264         }
265     }
266     memcpy(r, v, sizeof(*r));
267 }
268 
269 void init_f64reg(v4di *r)
270 {
271     static int n;
272     double v[4];
273     int i;
274     for (i = 0; i < 4; i++) {
275         v[i] = val_f64[n++];
276         if (n == ARRAY_LEN(val_f64)) {
277             n = 0;
278         }
279     }
280     memcpy(r, v, sizeof(*r));
281 }
282 
283 void init_intreg(v4di *r)
284 {
285     static uint64_t mask;
286     static int n;
287 
288     r->q0 = val_i64[n].q0 ^ mask;
289     r->q1 = val_i64[n].q1 ^ mask;
290     r->q2 = val_i64[n].q2 ^ mask;
291     r->q3 = val_i64[n].q3 ^ mask;
292     n++;
293     if (n == ARRAY_LEN(val_i64)) {
294         n = 0;
295         mask *= 0x104C11DB7;
296     }
297 }
298 
299 static void init_all(reg_state *s)
300 {
301     int i;
302 
303     s->r[3] = (uint64_t)&s->mem[0]; /* rdx */
304     s->r[4] = (uint64_t)&gather_mem[ARRAY_LEN(gather_mem) / 2]; /* rsi */
305     s->r[5] = (uint64_t)&s->mem[2]; /* rdi */
306     s->flags = 2;
307     for (i = 0; i < 16; i++) {
308         s->ymm[i] = deadbeef;
309     }
310     s->ymm[13] = indexd;
311     s->ymm[14] = indexq;
312     for (i = 0; i < 4; i++) {
313         s->mem0[i] = deadbeef;
314     }
315 }
316 
317 int main(int argc, char *argv[])
318 {
319     int i;
320 
321     init_all(&initI);
322     init_intreg(&initI.ymm[0]);
323     init_intreg(&initI.ymm[9]);
324     init_intreg(&initI.ymm[10]);
325     init_intreg(&initI.ymm[11]);
326     init_intreg(&initI.ymm[12]);
327     init_intreg(&initI.mem0[1]);
328     printf("Int:\n");
329     dump_regs(&initI);
330 
331     init_all(&initF16);
332     init_f16reg(&initF16.ymm[0]);
333     init_f16reg(&initF16.ymm[9]);
334     init_f16reg(&initF16.ymm[10]);
335     init_f16reg(&initF16.ymm[11]);
336     init_f16reg(&initF16.ymm[12]);
337     init_f16reg(&initF16.mem0[1]);
338     initF16.ff = 16;
339     printf("F16:\n");
340     dump_regs(&initF16);
341 
342     init_all(&initF32);
343     init_f32reg(&initF32.ymm[0]);
344     init_f32reg(&initF32.ymm[9]);
345     init_f32reg(&initF32.ymm[10]);
346     init_f32reg(&initF32.ymm[11]);
347     init_f32reg(&initF32.ymm[12]);
348     init_f32reg(&initF32.mem0[1]);
349     initF32.ff = 32;
350     printf("F32:\n");
351     dump_regs(&initF32);
352 
353     init_all(&initF64);
354     init_f64reg(&initF64.ymm[0]);
355     init_f64reg(&initF64.ymm[9]);
356     init_f64reg(&initF64.ymm[10]);
357     init_f64reg(&initF64.ymm[11]);
358     init_f64reg(&initF64.ymm[12]);
359     init_f64reg(&initF64.mem0[1]);
360     initF64.ff = 64;
361     printf("F64:\n");
362     dump_regs(&initF64);
363 
364     for (i = 0; i < ARRAY_LEN(gather_mem); i++) {
365         init_intreg(&gather_mem[i]);
366     }
367 
368     if (argc > 1) {
369         int n = atoi(argv[1]);
370         run_test(&test_table[n]);
371     } else {
372         run_all();
373     }
374     return 0;
375 }
376