# xref: /openbmc/qemu/tests/tcg/i386/test-avx.py (revision fa3673e4)
#! /usr/bin/env python3

# Generate test-avx.h from x86.csv
4
5import csv
6import sys
7from fnmatch import fnmatch
8
# CPUID feature strings (column index 6 of x86.csv, read in main()) whose
# instructions are turned into test cases.
archs = [
    "SSE", "SSE2", "SSE3", "SSSE3", "SSE4_1", "SSE4_2",
    "AES", "AVX", "AVX2", "AES+AVX", "VAES+AVX",
    "F16C", "FMA", "SHA",
]

# Mnemonics that are skipped even when their CPUID feature is listed above.
ignore = {"FISTTP", "LDMXCSR", "VLDMXCSR", "STMXCSR", "VSTMXCSR"}

# Immediate-operand masks, keyed by fnmatch-style mnemonic pattern (a leading
# lowercase "v" also matches the "V"-prefixed form; see match()).  For an
# instruction with an imm8 operand, ArgImm8u takes the value of the FIRST key
# that matches, so insertion order matters, and every submask of that value is
# emitted as an immediate.
imask = {
    "vBLENDPD": 0xff,
    "vBLENDPS": 0x0f,
    "CMP[PS][SD]": 0x07,
    "VCMP[PS][SD]": 0x1f,
    "vCVTPS2PH": 0x07,
    "vDPPD": 0x33,
    "vDPPS": 0xff,
    "vEXTRACTPS": 0x03,
    "vINSERTPS": 0xff,
    "MPSADBW": 0x07,
    "VMPSADBW": 0x3f,
    "vPALIGNR": 0x3f,
    "vPBLENDW": 0xff,
    "vPCMP[EI]STR*": 0x0f,
    "vPEXTRB": 0x0f,
    "vPEXTRW": 0x07,
    "vPEXTRD": 0x03,
    "vPEXTRQ": 0x01,
    "vPINSRB": 0x0f,
    "vPINSRW": 0x07,
    "vPINSRD": 0x03,
    "vPINSRQ": 0x01,
    "vPSHUF[DW]": 0xff,
    "vPSHUF[LH]W": 0xff,
    "vPS[LR][AL][WDQ]": 0x3f,
    "vPS[RL]LDQ": 0x1f,
    "vROUND[PS][SD]": 0x07,
    "SHA1RNDS4": 0x03,
    "vSHUFPD": 0x0f,
    "vSHUFPS": 0xff,
    "vAESKEYGENASSIST": 0xff,
    "VEXTRACT[FI]128": 0x01,
    "VINSERT[FI]128": 0x01,
    "VPBLENDD": 0xff,
    "VPERM2[FI]128": 0xbb,
    "VPERMPD": 0xff,
    "VPERMQ": 0xff,
    "VPERMILPS": 0xff,
    "VPERMILPD": 0x0f,
}
59
def strip_comments(x):
    """Yield the lines of *x* that are neither blank nor ``#`` comments.

    *x* is any iterable of strings, typically an open text file.  Lines read
    from a file keep their terminator, so a blank line arrives as a bare
    newline rather than an empty string; testing ``line.strip()`` catches
    both (an empty row would otherwise reach csv.reader and crash the caller
    on ``row[0]``).  Kept lines are yielded unchanged.
    """
    for line in x:
        if line.strip() and not line.startswith('#'):
            yield line
64
def reg_w(w):
    """Return the accumulator register name that is *w* bits wide.

    Only widths 8, 16, 32 and 64 are known; anything else raises a plain
    Exception naming the width.
    """
    accumulators = {8: 'al', 16: 'ax', 32: 'eax', 64: 'rax'}
    name = accumulators.get(w)
    if name is None:
        raise Exception("bad reg_w %d" % w)
    return name
75
def mem_w(w):
    """Return the Intel-syntax memory operand for a *w*-bit access.

    The operand always addresses ``32[rdx]``; only the size keyword
    (``BYTE`` through ``YMMWORD``) depends on *w*.

    Raises:
        Exception: if *w* is not 8, 16, 32, 64, 128 or 256.
    """
    sizes = {
        8: "BYTE",
        16: "WORD",
        32: "DWORD",
        64: "QWORD",
        128: "XMMWORD",
        256: "YMMWORD",
    }
    if w not in sizes:
        # Name the offending width (as reg_w does) so the failure is
        # diagnosable instead of an empty Exception().
        raise Exception("bad mem_w %s" % (w,))
    return sizes[w] + " PTR 32[rdx]"
93
class XMMArg():
    """An ``xmm``/``ymm`` register operand, optionally also allowing memory.

    *reg* is the register-name prefix (ArgGenerator passes the first letter
    of the ``xmm...``/``ymm...`` spelling) and *mw* is the width in bits of
    the alternative memory form, or 0 if the operand is register-only.
    """
    isxmm = True

    def __init__(self, reg, mw):
        if mw not in [0, 8, 16, 32, 64, 128, 256]:
            raise Exception("Bad /m width: %s" % mw)
        self.reg = reg
        self.mw = mw
        self.ismem = mw != 0

    def regstr(self, n):
        """Return the operand text for register number *n*.

        A negative *n* selects the memory form instead of a register.
        """
        if n < 0:
            return mem_w(self.mw)
        return "%smm%d" % (self.reg, n)
107
class MMArg():
    """An MMX ``mm`` register operand, optionally also allowing memory.

    *mw* is the width in bits of the alternative memory form (32 or 64), or
    0 if the operand is register-only.
    """
    # Counted as a vector operand so that MMX forms are generated too
    # (InsnGenerator skips instructions with no isxmm operand).
    isxmm = True

    def __init__(self, mw):
        if mw not in [0, 32, 64]:
            raise Exception("Bad mem width: %s" % mw)
        self.mw = mw
        self.ismem = mw != 0

    def regstr(self, n):
        """Return the operand text for register number *n*.

        A negative *n* selects the memory form, as for the other operand
        classes; without this check ``-1 & 7`` silently produced ``mm7`` and
        the memory form was never tested.  Register numbers are reduced
        modulo 8 because there are only eight MMX registers.
        """
        if n < 0:
            return mem_w(self.mw)
        return "mm%d" % (n & 7)
117
def match(op, pattern):
    """Return whether mnemonic *op* matches the fnmatch-style *pattern*.

    A leading lowercase ``v`` in *pattern* means "with or without the VEX
    ``V`` prefix": the remainder is tried both bare and with ``V`` prepended.
    """
    if pattern[0] != 'v':
        return fnmatch(op, pattern)
    base = pattern[1:]
    return any(fnmatch(op, candidate) for candidate in (base, 'V' + base))
122
class ArgVSIB():
    """A VSIB memory operand of the form ``[rsi + <vreg> * <scale>]``.

    *reg* is the index-register prefix taken from the operand spelling
    (ArgGenerator passes its last character, e.g. ``x`` for ``vm32x``).
    *w* must be 32 or 64; it is validated and stored but not used when
    formatting.
    """
    isxmm = True
    ismem = False

    def __init__(self, reg, w):
        if w not in [32, 64]:
            raise Exception("Bad vsib width: %s" % w)
        self.w = w
        self.reg = reg

    def regstr(self, n):
        """Decode *n* as (index register << 2) | log2(scale).

        The low two bits give a scale of 1, 2, 4 or 8.
        """
        index, scale_log2 = divmod(n, 4)
        return "[rsi + %smm%d * %d]" % (self.reg, index, 2 ** scale_log2)
134
class ArgImm8u():
    """An unsigned 8-bit immediate operand.

    The immediates to test come from the first ``imask`` entry whose pattern
    matches the mnemonic; every submask of that entry's value is emitted.
    """
    isxmm = False
    ismem = False

    def __init__(self, op):
        for pattern, bits in imask.items():
            if match(op, pattern):
                self.mask = bits
                break
        else:
            raise Exception("Unknown immediate")

    def vals(self):
        """Yield every submask of ``self.mask`` in increasing order, from 0."""
        mask = self.mask
        value = 0
        while True:
            yield value
            # (value - mask) & mask steps to the next larger submask and
            # wraps back to 0 once the full mask has been produced.
            value = (value - mask) & mask
            if value == 0:
                return
153
class ArgRM():
    """A general-purpose register operand, optionally also allowing memory.

    *rw* is the register width in bits (8, 16, 32 or 64) and *mw* is the
    width of the alternative memory form, or 0 if register-only.  The
    register form is always the accumulator of width *rw* (see reg_w).
    """
    isxmm = False

    def __init__(self, rw, mw):
        if rw not in [8, 16, 32, 64]:
            raise Exception("Bad r/w width: %s" % rw)
        if mw not in [0, 8, 16, 32, 64]:
            raise Exception("Bad r/m width: %s" % mw)
        self.rw = rw
        self.mw = mw
        self.ismem = mw != 0

    def regstr(self, n):
        """Return the operand text; a negative *n* selects the memory form."""
        if n < 0:
            return mem_w(self.mw)
        return reg_w(self.rw)
169
class ArgMem():
    """A memory-only operand of *w* bits (8, 16, 32, 64, 128 or 256)."""
    isxmm = False
    ismem = True

    def __init__(self, w):
        if w in (8, 16, 32, 64, 128, 256):
            self.w = w
        else:
            raise Exception("Bad mem width: %s" % w)

    def regstr(self, n):
        """Return the memory operand text; the register number *n* is unused."""
        return mem_w(self.w)
179
class SkipInstruction(Exception):
    """Signals that an instruction has no vector operand and is not tested."""
182
def ArgGenerator(arg, op):
    """Build the operand object for one operand spelling from x86.csv.

    Args:
        arg: operand text such as ``xmm1``, ``ymm2/m256``, ``mm1/m64``,
            ``imm8``, ``r32``, ``r/m32``, ``r32/m16``, ``m128`` or ``vm32x``.
        op: the instruction mnemonic; only used to pick the immediate mask
            for ``imm8`` operands.

    Returns:
        An XMMArg, MMArg, ArgImm8u, ArgRM, ArgMem or ArgVSIB instance, or
        None for the implicit ``<XMM0>`` operand.

    Raises:
        Exception: if the spelling is malformed or not recognised.
    """
    def split_rm(spec):
        # Split "<reg>/m<width>" into the register part and the memory width.
        r, m = spec.split('/')
        if not m.startswith('m'):
            raise Exception("Expected /m: %s" % spec)
        return r, int(m[1:])

    if arg[:3] == 'xmm' or arg[:3] == 'ymm':
        if "/" in arg:
            _, mw = split_rm(arg)
            return XMMArg(arg[0], mw)
        return XMMArg(arg[0], 0)
    if arg[:2] == 'mm':
        if "/" in arg:
            _, mw = split_rm(arg)
            return MMArg(mw)
        return MMArg(0)
    if arg[:4] == 'imm8':
        return ArgImm8u(op)
    if arg == '<XMM0>':
        # Implicit operand: it is not written in the generated assembly.
        return None
    if arg[0] == 'r':
        if '/m' in arg:
            r, mw = split_rm(arg)
            # A bare "r" (as in "r/m32") takes the width of the memory form.
            rw = mw if r == 'r' else int(r[1:])
            return ArgRM(rw, mw)
        return ArgRM(int(arg[1:]), 0)
    if arg[0] == 'm':
        return ArgMem(int(arg[1:]))
    if arg[:2] == 'vm':
        return ArgVSIB(arg[-1], int(arg[2:-1]))
    raise Exception("Unrecognised arg: %s" % arg)
223
class InsnGenerator:
    """Expand one x86.csv instruction into concrete test instructions.

    ``optype`` is derived from the mnemonic suffix: ``'F16'`` for ``PH``,
    ``'F32'`` for ``PS``/``SS``, ``'F64'`` for ``PD``/``SD``, and ``'I'``
    otherwise.  ``gen()`` yields the assembly strings to test.
    """
    def __init__(self, op, args):
        """Parse the operand spellings *args* of mnemonic *op*.

        Raises:
            SkipInstruction: if no operand is a vector (isxmm) operand.
            Exception: if an operand cannot be parsed; the message is
                prefixed with *op*.
        """
        self.op = op
        if op[-2:] in ["PH", "PS", "PD", "SS", "SD"]:
            if op[-1] == 'H':
                self.optype = 'F16'
            elif op[-1] == 'S':
                self.optype = 'F32'
            else:
                self.optype = 'F64'
        else:
            self.optype = 'I'

        try:
            parsed = [ArgGenerator(a, op) for a in args]
            # Drop implicit operands (ArgGenerator returns None for <XMM0>)
            # BEFORE inspecting attributes: None has no ``isxmm``.
            self.args = [a for a in parsed if a is not None]
            if not any(x.isxmm for x in self.args):
                raise SkipInstruction
        except SkipInstruction:
            raise
        except Exception as e:
            raise Exception("Bad arg %s: %s" % (op, e)) from e

    def gen(self):
        """Yield the assembly text of every test instance of this instruction.

        Register operands are drawn from a small fixed pool (9 as the
        destination, 10-12 as sources) in several aliasing patterns; a
        register number of -1 selects the memory form of an operand.  When
        the last operand is an immediate, each register pattern is repeated
        for every value from ArgImm8u.vals().
        """
        regs = (10, 11, 12)
        dest = 9

        nreg = len(self.args)
        if nreg == 0:
            yield self.op
            return
        # A trailing immediate is not a register slot; it is appended below.
        if isinstance(self.args[-1], ArgImm8u):
            nreg -= 1
            immarg = self.args[-1]
        else:
            immarg = None
        # Index of the last operand that accepts a memory form, or -1.
        memarg = -1
        for n, arg in enumerate(self.args):
            if arg.ismem:
                memarg = n

        if (self.op.startswith("VGATHER") or self.op.startswith("VPGATHER")):
            # The VSIB operand's number packs (index register << 2) | scale;
            # try all four scales with index register 13 or 14.
            if "GATHERD" in self.op:
                ireg = 13 << 2
            else:
                ireg = 14 << 2
            regset = [
                (dest, ireg | 0, regs[0]),
                (dest, ireg | 1, regs[0]),
                (dest, ireg | 2, regs[0]),
                (dest, ireg | 3, regs[0]),
                ]
            if memarg >= 0:
                raise Exception("vsib with memory: %s" % self.op)
        elif nreg == 1:
            regset = [(regs[0],)]
            if memarg == 0:
                regset += [(-1,)]
        elif nreg == 2:
            regset = [
                (regs[0], regs[1]),
                (regs[0], regs[0]),
                ]
            if memarg == 0:
                regset += [(-1, regs[0])]
            elif memarg == 1:
                regset += [(dest, -1)]
        elif nreg == 3:
            regset = [
                (dest, regs[0], regs[1]),
                (dest, regs[0], regs[0]),
                (regs[0], regs[0], regs[1]),
                (regs[0], regs[1], regs[0]),
                (regs[0], regs[0], regs[0]),
                ]
            if memarg == 2:
                regset += [
                    (dest, regs[0], -1),
                    (regs[0], regs[0], -1),
                    ]
            elif memarg > 0:
                raise Exception("Memarg %d" % memarg)
        elif nreg == 4:
            regset = [
                (dest, regs[0], regs[1], regs[2]),
                (dest, regs[0], regs[0], regs[1]),
                (dest, regs[0], regs[1], regs[0]),
                (dest, regs[1], regs[0], regs[0]),
                (dest, regs[0], regs[0], regs[0]),
                (regs[0], regs[0], regs[1], regs[2]),
                (regs[0], regs[1], regs[0], regs[2]),
                (regs[0], regs[1], regs[2], regs[0]),
                (regs[0], regs[0], regs[0], regs[1]),
                (regs[0], regs[0], regs[1], regs[0]),
                (regs[0], regs[1], regs[0], regs[0]),
                (regs[0], regs[0], regs[0], regs[0]),
                ]
            if memarg == 2:
                regset += [
                    (dest, regs[0], -1, regs[1]),
                    (dest, regs[0], -1, regs[0]),
                    (regs[0], regs[0], -1, regs[1]),
                    (regs[0], regs[1], -1, regs[0]),
                    (regs[0], regs[0], -1, regs[0]),
                    ]
            elif memarg > 0:
                raise Exception("Memarg4 %d" % memarg)
        else:
            raise Exception("Too many regs: %s(%d)" % (self.op, nreg))

        for regv in regset:
            argstr = []
            for i in range(nreg):
                arg = self.args[i]
                argstr.append(arg.regstr(regv[i]))
            if immarg is None:
                yield self.op + ' ' + ','.join(argstr)
            else:
                for immval in immarg.vals():
                    yield self.op + ' ' + ','.join(argstr) + ',' + str(immval)
345
def split0(s):
    """Split *s* on commas, returning an empty list for the empty string."""
    return [] if s == '' else s.split(',')
350
def main():
    """Command-line entry point: ``test-avx.py x86.csv test-avx.h``.

    Reads the instruction table, expands every instruction whose CPUID
    feature (column index 6) is in ``archs`` and whose mnemonic is not in
    ``ignore``, and writes one numbered ``TEST(n, "insn", optype)`` line per
    generated instruction, followed by ``#undef TEST``.
    """
    if len(sys.argv) != 3:
        print("Usage: test-avx.py x86.csv test-avx.h", file=sys.stderr)
        sys.exit(1)

    # Generate everything before touching the output, so a malformed row
    # cannot leave a truncated header behind for the build to pick up.
    lines = ["// Generated by test-avx.py. Do not edit.\n"]
    n = 0
    with open(sys.argv[1], 'r', newline='') as csvfile:
        for row in csv.reader(strip_comments(csvfile)):
            if not row or not row[0].strip():
                continue
            insn = row[0].replace(',', '').split()
            if insn[0] in ignore:
                continue
            cpuid = row[6]
            if cpuid not in archs:
                continue
            try:
                g = InsnGenerator(insn[0], insn[1:])
            except SkipInstruction:
                continue
            for text in g.gen():
                lines.append('TEST(%d, "%s", %s)\n' % (n, text, g.optype))
                n += 1
    lines.append("#undef TEST\n")

    with open(sys.argv[2], "w") as outf:
        outf.writelines(lines)
374
if __name__ == '__main__':
    # Generate the header only when run as a script, never on import.
    main()
377