xref: /openbmc/qemu/tests/tcg/i386/test-avx.c (revision 197a137290103993b33f93c90e788ab4984f103a)
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 
6 typedef void (*testfn)(void);
7 
8 typedef struct {
9     uint64_t q0, q1, q2, q3;
10 } __attribute__((aligned(32))) v4di;
11 
12 typedef struct {
13     uint64_t mm[8];
14     v4di ymm[16];
15     uint64_t r[16];
16     uint64_t flags;
17     uint32_t ff;
18     uint64_t pad;
19     v4di mem[4];
20     v4di mem0[4];
21 } reg_state;
22 
23 typedef struct {
24     int n;
25     testfn fn;
26     const char *s;
27     reg_state *init;
28 } TestDef;
29 
30 reg_state initI;
31 reg_state initF16;
32 reg_state initF32;
33 reg_state initF64;
34 
35 static void dump_ymm(const char *name, int n, const v4di *r, int ff)
36 {
37     printf("%s%d = %016lx %016lx %016lx %016lx\n",
38            name, n, r->q3, r->q2, r->q1, r->q0);
39     if (ff == 64) {
40         double v[4];
41         memcpy(v, r, sizeof(v));
42         printf("        %16g %16g %16g %16g\n",
43                 v[3], v[2], v[1], v[0]);
44     } else if (ff == 32) {
45         float v[8];
46         memcpy(v, r, sizeof(v));
47         printf(" %8g %8g %8g %8g %8g %8g %8g %8g\n",
48                 v[7], v[6], v[5], v[4], v[3], v[2], v[1], v[0]);
49     }
50 }
51 
52 static void dump_regs(reg_state *s)
53 {
54     int i;
55 
56     for (i = 0; i < 16; i++) {
57         dump_ymm("ymm", i, &s->ymm[i], 0);
58     }
59     for (i = 0; i < 4; i++) {
60         dump_ymm("mem", i, &s->mem0[i], 0);
61     }
62 }
63 
64 static void compare_state(const reg_state *a, const reg_state *b)
65 {
66     int i;
67     for (i = 0; i < 8; i++) {
68         if (a->mm[i] != b->mm[i]) {
69             printf("MM%d = %016lx\n", i, b->mm[i]);
70         }
71     }
72     for (i = 0; i < 16; i++) {
73         if (a->r[i] != b->r[i]) {
74             printf("r%d = %016lx\n", i, b->r[i]);
75         }
76     }
77     for (i = 0; i < 16; i++) {
78         if (memcmp(&a->ymm[i], &b->ymm[i], 32)) {
79             dump_ymm("ymm", i, &b->ymm[i], a->ff);
80         }
81     }
82     for (i = 0; i < 4; i++) {
83         if (memcmp(&a->mem0[i], &a->mem[i], 32)) {
84             dump_ymm("mem", i, &a->mem[i], a->ff);
85         }
86     }
87     if (a->flags != b->flags) {
88         printf("FLAGS = %016lx\n", b->flags);
89     }
90 }
91 
92 #define LOADMM(r, o) "movq " #r ", " #o "[%0]\n\t"
93 #define LOADYMM(r, o) "vmovdqa " #r ", " #o "[%0]\n\t"
94 #define STOREMM(r, o) "movq " #o "[%1], " #r "\n\t"
95 #define STOREYMM(r, o) "vmovdqa " #o "[%1], " #r "\n\t"
96 #define MMREG(F) \
97     F(mm0, 0x00) \
98     F(mm1, 0x08) \
99     F(mm2, 0x10) \
100     F(mm3, 0x18) \
101     F(mm4, 0x20) \
102     F(mm5, 0x28) \
103     F(mm6, 0x30) \
104     F(mm7, 0x38)
105 #define YMMREG(F) \
106     F(ymm0, 0x040) \
107     F(ymm1, 0x060) \
108     F(ymm2, 0x080) \
109     F(ymm3, 0x0a0) \
110     F(ymm4, 0x0c0) \
111     F(ymm5, 0x0e0) \
112     F(ymm6, 0x100) \
113     F(ymm7, 0x120) \
114     F(ymm8, 0x140) \
115     F(ymm9, 0x160) \
116     F(ymm10, 0x180) \
117     F(ymm11, 0x1a0) \
118     F(ymm12, 0x1c0) \
119     F(ymm13, 0x1e0) \
120     F(ymm14, 0x200) \
121     F(ymm15, 0x220)
122 #define LOADREG(r, o) "mov " #r ", " #o "[rax]\n\t"
123 #define STOREREG(r, o) "mov " #o "[rax], " #r "\n\t"
124 #define REG(F) \
125     F(rbx, 0x248) \
126     F(rcx, 0x250) \
127     F(rdx, 0x258) \
128     F(rsi, 0x260) \
129     F(rdi, 0x268) \
130     F(r8, 0x280) \
131     F(r9, 0x288) \
132     F(r10, 0x290) \
133     F(r11, 0x298) \
134     F(r12, 0x2a0) \
135     F(r13, 0x2a8) \
136     F(r14, 0x2b0) \
137     F(r15, 0x2b8) \
138 
139 static void run_test(const TestDef *t)
140 {
141     reg_state result;
142     reg_state *init = t->init;
143     memcpy(init->mem, init->mem0, sizeof(init->mem));
144     printf("%5d %s\n", t->n, t->s);
145     asm volatile(
146             MMREG(LOADMM)
147             YMMREG(LOADYMM)
148             "sub rsp, 128\n\t"
149             "push rax\n\t"
150             "push rbx\n\t"
151             "push rcx\n\t"
152             "push rdx\n\t"
153             "push %1\n\t"
154             "push %2\n\t"
155             "mov rax, %0\n\t"
156             "pushf\n\t"
157             "pop rbx\n\t"
158             "shr rbx, 8\n\t"
159             "shl rbx, 8\n\t"
160             "mov rcx, 0x2c0[rax]\n\t"
161             "and rcx, 0xff\n\t"
162             "or rbx, rcx\n\t"
163             "push rbx\n\t"
164             "popf\n\t"
165             REG(LOADREG)
166             "mov rax, 0x240[rax]\n\t"
167             "call [rsp]\n\t"
168             "mov [rsp], rax\n\t"
169             "mov rax, 8[rsp]\n\t"
170             REG(STOREREG)
171             "mov rbx, [rsp]\n\t"
172             "mov 0x240[rax], rbx\n\t"
173             "mov rbx, 0\n\t"
174             "mov 0x270[rax], rbx\n\t"
175             "mov 0x278[rax], rbx\n\t"
176             "pushf\n\t"
177             "pop rbx\n\t"
178             "and rbx, 0xff\n\t"
179             "mov 0x2c0[rax], rbx\n\t"
180             "add rsp, 16\n\t"
181             "pop rdx\n\t"
182             "pop rcx\n\t"
183             "pop rbx\n\t"
184             "pop rax\n\t"
185             "add rsp, 128\n\t"
186             MMREG(STOREMM)
187             YMMREG(STOREYMM)
188             : : "r"(init), "r"(&result), "r"(t->fn)
189             : "memory", "cc",
190             "rsi", "rdi",
191             "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
192             "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7",
193             "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5",
194             "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11",
195             "ymm12", "ymm13", "ymm14", "ymm15"
196             );
197     compare_state(init, &result);
198 }
199 
200 #define TEST(n, cmd, type) \
201 static void __attribute__((naked)) test_##n(void) \
202 { \
203     asm volatile(cmd); \
204     asm volatile("ret"); \
205 }
206 #include "test-avx.h"
207 
208 
209 static const TestDef test_table[] = {
210 #define TEST(n, cmd, type) {n, test_##n, cmd, &init##type},
211 #include "test-avx.h"
212     {-1, NULL, "", NULL}
213 };
214 
215 static void run_all(void)
216 {
217     const TestDef *t;
218     for (t = test_table; t->fn; t++) {
219         run_test(t);
220     }
221 }
222 
223 #define ARRAY_LEN(x) (sizeof(x) / sizeof(x[0]))
224 
225 uint16_t val_f16[] = { 0x4000, 0xbc00, 0x44cd, 0x3a66, 0x4200, 0x7a1a, 0x4780, 0x4826 };
226 float val_f32[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5, 8.3};
227 double val_f64[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5};
228 v4di val_i64[] = {
229     {0x3d6b3b6a9e4118f2lu, 0x355ae76d2774d78clu,
230      0xac3ff76c4daa4b28lu, 0xe7fabd204cb54083lu},
231     {0xd851c54a56bf1f29lu, 0x4a84d1d50bf4c4fflu,
232      0x56621e553d52b56clu, 0xd0069553da8f584alu},
233     {0x5826475e2c5fd799lu, 0xfd32edc01243f5e9lu,
234      0x738ba2c66d3fe126lu, 0x5707219c6e6c26b4lu},
235 };
236 
237 v4di deadbeef = {0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull,
238                  0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull};
239 v4di indexq = {0x000000000000001full, 0x000000000000008full,
240                0xffffffffffffffffull, 0xffffffffffffff5full};
241 v4di indexd = {0x00000002000000efull, 0xfffffff500000010ull,
242                0x0000000afffffff0ull, 0x000000000000000eull};
243 
244 v4di gather_mem[0x20];
245 
246 void init_f16reg(v4di *r)
247 {
248     memset(r, 0, sizeof(*r));
249     memcpy(r, val_f16, sizeof(val_f16));
250 }
251 
252 void init_f32reg(v4di *r)
253 {
254     static int n;
255     float v[8];
256     int i;
257     for (i = 0; i < 8; i++) {
258         v[i] = val_f32[n++];
259         if (n == ARRAY_LEN(val_f32)) {
260             n = 0;
261         }
262     }
263     memcpy(r, v, sizeof(*r));
264 }
265 
266 void init_f64reg(v4di *r)
267 {
268     static int n;
269     double v[4];
270     int i;
271     for (i = 0; i < 4; i++) {
272         v[i] = val_f64[n++];
273         if (n == ARRAY_LEN(val_f64)) {
274             n = 0;
275         }
276     }
277     memcpy(r, v, sizeof(*r));
278 }
279 
280 void init_intreg(v4di *r)
281 {
282     static uint64_t mask;
283     static int n;
284 
285     r->q0 = val_i64[n].q0 ^ mask;
286     r->q1 = val_i64[n].q1 ^ mask;
287     r->q2 = val_i64[n].q2 ^ mask;
288     r->q3 = val_i64[n].q3 ^ mask;
289     n++;
290     if (n == ARRAY_LEN(val_i64)) {
291         n = 0;
292         mask *= 0x104C11DB7;
293     }
294 }
295 
296 static void init_all(reg_state *s)
297 {
298     int i;
299 
300     s->r[3] = (uint64_t)&s->mem[0]; /* rdx */
301     s->r[4] = (uint64_t)&gather_mem[ARRAY_LEN(gather_mem) / 2]; /* rsi */
302     s->r[5] = (uint64_t)&s->mem[2]; /* rdi */
303     s->flags = 2;
304     for (i = 0; i < 16; i++) {
305         s->ymm[i] = deadbeef;
306     }
307     s->ymm[13] = indexd;
308     s->ymm[14] = indexq;
309     for (i = 0; i < 4; i++) {
310         s->mem0[i] = deadbeef;
311     }
312 }
313 
314 int main(int argc, char *argv[])
315 {
316     int i;
317 
318     init_all(&initI);
319     init_intreg(&initI.ymm[10]);
320     init_intreg(&initI.ymm[11]);
321     init_intreg(&initI.ymm[12]);
322     init_intreg(&initI.mem0[1]);
323     printf("Int:\n");
324     dump_regs(&initI);
325 
326     init_all(&initF16);
327     init_f16reg(&initF16.ymm[10]);
328     init_f16reg(&initF16.ymm[11]);
329     init_f16reg(&initF16.ymm[12]);
330     init_f16reg(&initF16.mem0[1]);
331     initF16.ff = 16;
332     printf("F16:\n");
333     dump_regs(&initF16);
334 
335     init_all(&initF32);
336     init_f32reg(&initF32.ymm[10]);
337     init_f32reg(&initF32.ymm[11]);
338     init_f32reg(&initF32.ymm[12]);
339     init_f32reg(&initF32.mem0[1]);
340     initF32.ff = 32;
341     printf("F32:\n");
342     dump_regs(&initF32);
343 
344     init_all(&initF64);
345     init_f64reg(&initF64.ymm[10]);
346     init_f64reg(&initF64.ymm[11]);
347     init_f64reg(&initF64.ymm[12]);
348     init_f64reg(&initF64.mem0[1]);
349     initF64.ff = 64;
350     printf("F64:\n");
351     dump_regs(&initF64);
352 
353     for (i = 0; i < ARRAY_LEN(gather_mem); i++) {
354         init_intreg(&gather_mem[i]);
355     }
356 
357     if (argc > 1) {
358         int n = atoi(argv[1]);
359         run_test(&test_table[n]);
360     } else {
361         run_all();
362     }
363     return 0;
364 }
365