xref: /openbmc/qemu/scripts/decodetree.py (revision b4b9a0e32f93c0700f46617524317b0580126592)
1#!/usr/bin/env python3
2# Copyright (c) 2018 Linaro Limited
3#
4# This library is free software; you can redistribute it and/or
5# modify it under the terms of the GNU Lesser General Public
6# License as published by the Free Software Foundation; either
7# version 2.1 of the License, or (at your option) any later version.
8#
9# This library is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12# Lesser General Public License for more details.
13#
14# You should have received a copy of the GNU Lesser General Public
15# License along with this library; if not, see <http://www.gnu.org/licenses/>.
16#
17
18#
19# Generate a decoding tree from a specification file.
20# See the syntax and semantics in docs/devel/decodetree.rst.
21#
22
23import io
24import os
25import re
26import sys
27import getopt
28
# Width in bits of an instruction word; used to pad hex output and to
# bound field positions.
insnwidth = 32
# Width suffix of the extract/deposit helper names emitted into C code.
bitop_width = 32
# Mask with one bit set for every bit of the instruction word.
insnmask = 0xffffffff
# When true, definitions shorter than insnwidth (in whole bytes) are
# padded out with undefined bits by parse_generic.
variablewidth = False
# Parsed objects, keyed by name.
fields = {}
arguments = {}
formats = {}
# Every Pattern created by parse_generic, in creation order.
allpatterns = []
# Set once any argument set has been declared with !extern.
anyextern = False

# Prefix and storage-class text used when declaring translator functions.
translate_prefix = 'trans'
translate_scope = 'static '
# Name of the file being parsed; used in diagnostics.
input_file = ''
# Output file name and stream; error_with_file closes and removes the
# output when both are set.
output_file = None
output_fd = None
# C type of the 'insn' parameter of generated extract functions.
insntype = 'uint32_t'
# Prefix for generated extract functions and inferred names.
decode_function = 'decode'

# An identifier for C.
re_C_ident = '[a-zA-Z][a-zA-Z0-9_]*'

# Identifiers for Arguments, Fields, Formats and Patterns.
re_arg_ident = '&[a-zA-Z0-9_]*'
re_fld_ident = '%[a-zA-Z0-9_]*'
re_fmt_ident = '@[a-zA-Z0-9_]*'
re_pat_ident = '[a-zA-Z0-9_]*'
55
def error_with_file(file, lineno, *args):
    """Print an error message from file:line and args and exit.

    FILE and LINENO are each added to the message prefix only when
    truthy.  If an output file has been opened it is closed and removed,
    so that a partial result is not left behind.  The process then exits
    with status 1; this function never returns.
    """
    global output_file
    global output_fd

    prefix = ''
    if file:
        prefix += f'{file}:'
    if lineno:
        prefix += f'{lineno}:'
    if prefix:
        prefix += ' '
    print(prefix, end='error: ', file=sys.stderr)
    print(*args, file=sys.stderr)

    if output_file and output_fd:
        output_fd.close()
        try:
            os.remove(output_file)
        except OSError:
            # Removing the partial output is best effort.  A failure here
            # (missing file, unremovable path) must not replace the
            # diagnostic above with a traceback.
            pass
    # Use sys.exit rather than the exit() helper, which is only installed
    # by the site module.
    sys.exit(1)
# end error_with_file
76
77
def error(lineno, *args):
    """Report ARGS as an error at LINENO of the current input file.

    Delegates to error_with_file, which exits the program.
    """
    source = input_file
    error_with_file(source, lineno, *args)
# end error
81
82
def output(*args):
    """Write each of ARGS, in order, to the current output stream."""
    global output_fd
    for text in args:
        output_fd.write(text)
87
88
def output_autogen():
    """Emit the comment marking the output as an autogenerated file."""
    banner = '/* This file is autogenerated by scripts/decodetree.py.  */\n\n'
    output(banner)
91
92
def str_indent(c):
    """Return a string made of C space characters."""
    return c * ' '
96
97
def str_fields(fields):
    """Return a string uniquely identifying FIELDS.

    The result is the sorted field names joined with underscores.
    """
    return '_'.join(sorted(fields.keys()))
104
105
def whex(val):
    """Return a hex string for val, zero-padded to insnwidth bits."""
    global insnwidth
    digits = insnwidth // 4
    return '0x' + format(val, f'0{digits}x')
110
111
def whexC(val):
    """Return a hex string for val padded for insnwidth,
       and with the proper suffix for a C constant."""
    if val >= 0x100000000:
        # Does not fit in 32 bits.
        suffix = 'ull'
    elif val >= 0x80000000:
        # Has bit 31 set.
        suffix = 'u'
    else:
        suffix = ''
    return whex(val) + suffix
121
122
def str_match_bits(bits, mask):
    """Return a string pretty-printing BITS/MASK.

    Bits are printed from bit insnwidth-1 down to bit 0: '.' where MASK
    is clear, otherwise the corresponding bit of BITS as '0' or '1'.
    """
    global insnwidth

    # A space is printed after each of bits 24, 16 and 8.
    space = 0x01010100
    chars = []
    bit = 1 << (insnwidth - 1)
    while bit != 0:
        if not bit & mask:
            chars.append('.')
        elif bit & bits:
            chars.append('1')
        else:
            chars.append('0')
        if bit & space:
            chars.append(' ')
        bit >>= 1
    return ''.join(chars)
142
143
def is_pow2(x):
    """Return true iff X has at most one bit set.

    That holds for every power of 2, and also for 0.
    """
    return not x & (x - 1)
147
148
def ctz(x):
    """Return the number of times 2 factors into X (X must be nonzero)."""
    assert x != 0
    # X & -X isolates the lowest set bit; its bit length is one more
    # than the number of zero bits below it.
    return (x & -x).bit_length() - 1
156
157
def is_contiguous(bits):
    """Return the position of the lowest set bit of BITS if its set
    bits form a single contiguous run, and -1 otherwise (including
    when BITS is zero)."""
    if bits == 0:
        return -1
    shift = ctz(bits)
    # With the run moved down to bit 0, adding one carries out of a
    # contiguous run and leaves a single bit set.
    return shift if is_pow2((bits >> shift) + 1) else -1
166
167
def eq_fields_for_args(flds_a, arg):
    """Return True if the field names of FLDS_A are exactly the fields
    of argument set ARG and every type in ARG is the default 'int'."""
    if len(flds_a) != len(arg.fields):
        return False
    # Only allow inference on default types
    if any(t != 'int' for t in arg.types):
        return False
    return all(name in arg.fields for name in flds_a.keys())
179
180
def eq_fields_for_fmts(flds_a, flds_b):
    """Return True if FLDS_A and FLDS_B have the same names, and each
    pair of same-named fields has the same class and does not compare
    unequal."""
    if len(flds_a) != len(flds_b):
        return False
    return all(
        name in flds_b
        and fld.__class__ == flds_b[name].__class__
        and not fld != flds_b[name]
        for name, fld in flds_a.items())
191
192
class Field:
    """Class representing a simple instruction field"""
    def __init__(self, sign, pos, len):
        # SIGN selects sign-extension; POS is the lowest bit of the
        # field and LEN its width in bits.
        self.sign = sign
        self.pos = pos
        self.len = len
        # LEN one-bits, shifted up to start at POS.
        self.mask = ((1 << len) - 1) << pos

    def __str__(self):
        marker = 's' if self.sign else ''
        return f'{self.pos}:{marker}{self.len}'

    def str_extract(self):
        """Return the C expression extracting this field from 'insn'."""
        global bitop_width
        prefix = 's' if self.sign else ''
        return f'{prefix}extract{bitop_width}(insn, {self.pos}, {self.len})'

    def __eq__(self, other):
        # Fields are equal when they cover the same bits with the
        # same signedness.
        return self.sign == other.sign and self.mask == other.mask

    def __ne__(self, other):
        return not self.__eq__(other)
# end Field
219
220
class MultiField:
    """Class representing a compound instruction field"""
    def __init__(self, subs, mask):
        self.subs = subs
        # The signedness of the whole is that of the first component.
        self.sign = subs[0].sign
        self.mask = mask

    def __str__(self):
        return str(self.subs)

    def str_extract(self):
        """Return the C expression assembling the field from its parts.

        The last component supplies the low bits; each earlier
        component is deposited above the bits accumulated so far.
        """
        global bitop_width
        expr = '0'
        offset = 0
        for sub in reversed(self.subs):
            piece = sub.str_extract()
            if offset == 0:
                expr = piece
            else:
                expr = (f'deposit{bitop_width}({expr}, {offset}, '
                        f'{bitop_width - offset}, {piece})')
            offset += sub.len
        return expr

    def __ne__(self, other):
        if len(self.subs) != len(other.subs):
            return True
        # Components are compared pairwise, in order.
        return any(a.__class__ != b.__class__ or a != b
                   for a, b in zip(self.subs, other.subs))

    def __eq__(self, other):
        return not self.__ne__(other)
# end MultiField
255
256
class ConstField:
    """Class representing an argument field with constant value"""
    def __init__(self, value):
        self.value = value
        # A constant occupies no bits of the instruction.
        self.mask = 0
        self.sign = value < 0

    def __str__(self):
        return str(self.value)

    def str_extract(self):
        """Return the C expression for this field: the constant itself."""
        return str(self.value)

    # Python 3 ignores __cmp__ for the comparison operators, so without
    # the methods below '==' and '!=' would compare object identity and
    # two constants with the same value would never be equal.
    def __eq__(self, other):
        return isinstance(other, ConstField) and self.value == other.value

    def __ne__(self, other):
        return not self.__eq__(other)

    # Defining __eq__ would otherwise make instances unhashable.
    def __hash__(self):
        return hash(self.value)

    # Retained for any explicit caller; unused by Python 3 operators.
    def __cmp__(self, other):
        return self.value - other.value
# end ConstField
273
274
class FunctionField:
    """Class representing a field passed through a function"""
    def __init__(self, func, base):
        # Mask and signedness are taken from the wrapped field.
        self.mask = base.mask
        self.sign = base.sign
        self.base = base
        self.func = func

    def __str__(self):
        return f'{self.func}({self.base})'

    def str_extract(self):
        """Return the C call passing ctx and the extracted base field."""
        inner = self.base.str_extract()
        return f'{self.func}(ctx, {inner})'

    def __eq__(self, other):
        return self.func == other.func and self.base == other.base

    def __ne__(self, other):
        return not self.__eq__(other)
# end FunctionField
295
296
class ParameterField:
    """Class representing a pseudo-field read from a function"""
    def __init__(self, func):
        # No instruction bits are consumed by this kind of field.
        self.mask = 0
        self.sign = 0
        self.func = func

    def __str__(self):
        return self.func

    def str_extract(self):
        """Return the C call producing the value; only ctx is passed."""
        return f'{self.func}(ctx)'

    def __eq__(self, other):
        return self.func == other.func

    def __ne__(self, other):
        return not self.__eq__(other)
# end ParameterField
316
317
class Arguments:
    """Class representing the extracted fields of a format"""
    def __init__(self, nm, flds, types, extern):
        # FLDS are the member names and TYPES their C types, pairwise.
        self.name = nm
        self.extern = extern
        self.fields = flds
        self.types = types

    def __str__(self):
        return f'{self.name} {self.fields}'

    def struct_name(self):
        """Return the name of the C struct type for this argument set."""
        return 'arg_' + self.name

    def output_def(self):
        """Emit the C typedef for the struct, unless it is extern."""
        if self.extern:
            return
        output('typedef struct {\n')
        for member, ctype in zip(self.fields, self.types):
            output(f'    {ctype} {member};\n')
        output('} ', self.struct_name(), ';\n\n')
# end Arguments
339
340
class General:
    """Common code between instruction formats and instruction patterns"""
    def __init__(self, name, lineno, base, fixb, fixm, udfm, fldm, flds, w):
        self.name = name
        # Remember where the definition came from, for diagnostics.
        self.file = input_file
        self.lineno = lineno
        self.base = base
        self.fixedbits = fixb
        self.fixedmask = fixm
        self.undefmask = udfm
        self.fieldmask = fldm
        self.fields = flds
        self.width = w

    def __str__(self):
        bits = str_match_bits(self.fixedbits, self.fixedmask)
        return f'{self.name} {bits}'

    def str1(self, i):
        """Return the string form of this object indented by I spaces."""
        return str_indent(i) + str(self)
# end General
361
362
class Format(General):
    """Class representing an instruction format"""

    def extract_name(self):
        """Return the name of the C function extracting this format."""
        global decode_function
        return f'{decode_function}_extract_{self.name}'

    def output_extract(self):
        """Emit the C function that fills the argument struct from insn."""
        output(f'static void {self.extract_name()}(DisasContext *ctx, '
               f'{self.base.struct_name()} *a, {insntype} insn)\n{{\n')
        # One assignment per field of the format.
        for member, fld in self.fields.items():
            output(f'    a->{member} = {fld.str_extract()};\n')
        output('}\n\n')
# end Format
377
378
class Pattern(General):
    """Class representing an instruction pattern"""

    def output_decl(self):
        """Emit the argument typedef and translator prototype."""
        global translate_scope
        global translate_prefix
        struct = self.base.base.struct_name()
        output(f'typedef {struct} arg_{self.name};\n')
        output(f'{translate_scope}bool {translate_prefix}_{self.name}'
               f'(DisasContext *ctx, arg_{self.name} *a);\n')

    def output_code(self, i, extracted, outerbits, outermask):
        """Emit the C code invoking the translator for this pattern.

        I is the indentation.  When EXTRACTED is false the format's
        extract function is called first.  OUTERBITS and OUTERMASK are
        accepted for interface compatibility and not used here.
        """
        global translate_prefix
        ind = str_indent(i)
        arg = self.base.base.name
        output(f'{ind}/* {self.file}:{self.lineno} */\n')
        if not extracted:
            output(f'{ind}{self.base.extract_name()}'
                   f'(ctx, &u.f_{arg}, insn);\n')
        # Fields set by the pattern itself are assigned after the
        # format's extraction.
        for member, fld in self.fields.items():
            output(f'{ind}u.f_{arg}.{member} = {fld.str_extract()};\n')
        output(f'{ind}if ({translate_prefix}_{self.name}'
               f'(ctx, &u.f_{arg})) return true;\n')

    # Normal patterns do not have children, so the tree-walking
    # passes have nothing to do.
    def build_tree(self):
        pass

    def prop_masks(self):
        pass

    def prop_format(self):
        pass

    def prop_width(self):
        pass

# end Pattern
414
415
class MultiPattern(General):
    """Class representing a set of instruction patterns"""

    def __init__(self, lineno):
        # Deliberately does not call General.__init__: a group has no
        # name or fields of its own.
        self.file = input_file
        self.lineno = lineno
        self.pats = []
        self.base = None
        self.fixedbits = 0
        self.fixedmask = 0
        self.undefmask = 0
        self.width = None

    def __str__(self):
        r = 'group'
        # prop_masks leaves fixedbits as None for a group with a
        # non-zero common mask but no members.
        if self.fixedbits is not None:
            r += ' ' + str_match_bits(self.fixedbits, self.fixedmask)
        return r

    def output_decl(self):
        """Emit the declarations of every member."""
        for p in self.pats:
            p.output_decl()

    def prop_masks(self):
        """Compute fixedbits/fixedmask/undefmask shared by all members."""
        global insnmask

        fixedmask = insnmask
        undefmask = insnmask

        # Collect fixedmask/undefmask for all of the children.
        for p in self.pats:
            p.prop_masks()
            fixedmask &= p.fixedmask
            undefmask &= p.undefmask

        # Widen fixedmask until all fixedbits match
        repeat = True
        fixedbits = 0
        while repeat and fixedmask != 0:
            fixedbits = None
            for p in self.pats:
                thisbits = p.fixedbits & fixedmask
                if fixedbits is None:
                    fixedbits = thisbits
                elif fixedbits != thisbits:
                    # Drop the bits on which two members disagree and
                    # rescan with the reduced mask.
                    fixedmask &= ~(fixedbits ^ thisbits)
                    break
            else:
                repeat = False

        self.fixedbits = fixedbits
        self.fixedmask = fixedmask
        self.undefmask = undefmask

    def build_tree(self):
        """Recursively build the decode trees of the members."""
        for p in self.pats:
            p.build_tree()

    def prop_format(self):
        """Recursively propagate formats through the members."""
        # This used to call p.build_tree(), a copy-paste error that was
        # masked by the matching error in ExcMultiPattern.build_tree.
        for p in self.pats:
            p.prop_format()

    def prop_width(self):
        """Compute the common width of the members; all must agree."""
        width = None
        for p in self.pats:
            p.prop_width()
            if width is None:
                width = p.width
            elif width != p.width:
                error_with_file(self.file, self.lineno,
                                'width mismatch in patterns within braces')
        self.width = width

# end MultiPattern


class IncMultiPattern(MultiPattern):
    """Class representing an overlapping set of instruction patterns"""

    def output_code(self, i, extracted, outerbits, outermask):
        """Emit code trying each member in order.

        A member whose fixedmask differs from OUTERMASK is guarded by
        a test of its fixed bits outside OUTERMASK.
        """
        global translate_prefix
        ind = str_indent(i)
        for p in self.pats:
            if outermask != p.fixedmask:
                innermask = p.fixedmask & ~outermask
                innerbits = p.fixedbits & ~outermask
                output(ind, f'if ((insn & {whexC(innermask)}) == {whexC(innerbits)}) {{\n')
                output(ind, f'    /* {str_match_bits(p.fixedbits, p.fixedmask)} */\n')
                p.output_code(i + 4, extracted, p.fixedbits, p.fixedmask)
                output(ind, '}\n')
            else:
                p.output_code(i, extracted, p.fixedbits, p.fixedmask)
#end IncMultiPattern


class Tree:
    """Class representing a node in a decode tree"""

    def __init__(self, fm, tm):
        # FM is the mask of all bits decided at or above this node,
        # TM the mask of the bits this node switches on.
        self.fixedmask = fm
        self.thismask = tm
        # List of (bits, child) pairs.
        self.subs = []
        # Format shared by everything below, if any.
        self.base = None

    def str1(self, i):
        """Return a multi-line description of the tree, indented by I."""
        ind = str_indent(i)
        r = ind + whex(self.fixedmask)
        # The shared format is kept in 'base'; no 'format' attribute
        # is ever set on a Tree.
        if self.base:
            r += ' ' + self.base.name
        r += ' [\n'
        for (b, s) in self.subs:
            r += ind + f'  {whex(b)}:\n'
            r += s.str1(i + 4) + '\n'
        r += ind + ']'
        return r

    def __str__(self):
        return self.str1(0)

    def output_code(self, i, extracted, outerbits, outermask):
        """Emit the C switch statement decoding this node."""
        ind = str_indent(i)

        # If we identified all nodes below have the same format,
        # extract the fields now.
        if not extracted and self.base:
            output(ind, self.base.extract_name(),
                   '(ctx, &u.f_', self.base.base.name, ', insn);\n')
            extracted = True

        # Attempt to aid the compiler in producing compact switch statements.
        # If the bits in the mask are contiguous, extract them.
        # A mask that is not contiguous (-1) or already starts at
        # bit 0 is switched on without shifting.
        sh = is_contiguous(self.thismask)
        if sh > 0:
            # Propagate SH down into the local functions.
            def str_switch(b, sh=sh):
                return f'(insn >> {sh}) & {b >> sh:#x}'

            def str_case(b, sh=sh):
                return hex(b >> sh)
        else:
            def str_switch(b):
                return f'insn & {whexC(b)}'

            def str_case(b):
                return whexC(b)

        output(ind, 'switch (', str_switch(self.thismask), ') {\n')
        for b, s in sorted(self.subs):
            assert (self.thismask & ~s.fixedmask) == 0
            innermask = outermask | self.thismask
            innerbits = outerbits | b
            output(ind, 'case ', str_case(b), ':\n')
            output(ind, '    /* ',
                   str_match_bits(innerbits, innermask), ' */\n')
            s.output_code(i + 4, extracted, innerbits, innermask)
            output(ind, '    break;\n')
        output(ind, '}\n')
# end Tree


class ExcMultiPattern(MultiPattern):
    """Class representing a non-overlapping set of instruction patterns"""

    def output_code(self, i, extracted, outerbits, outermask):
        # Defer everything to our decomposed Tree node
        self.tree.output_code(i, extracted, outerbits, outermask)

    @staticmethod
    def __build_tree(pats, outerbits, outermask):
        """Return a Tree distinguishing PATS by bits outside OUTERMASK.

        Reports an error and exits if several patterns cannot be told
        apart by any remaining fixed bit.
        """
        # Find the intersection of all remaining fixedmask.
        innermask = ~outermask & insnmask
        for i in pats:
            innermask &= i.fixedmask

        if innermask == 0:
            # Edge condition: One pattern covers the entire insnmask
            if len(pats) == 1:
                t = Tree(outermask, innermask)
                t.subs.append((0, pats[0]))
                return t

            text = 'overlapping patterns:'
            for p in pats:
                text += '\n' + p.file + ':' + str(p.lineno) + ': ' + str(p)
            error_with_file(pats[0].file, pats[0].lineno, text)

        fullmask = outermask | innermask

        # Sort each element of pats into the bin selected by the mask.
        bins = {}
        for i in pats:
            fb = i.fixedbits & innermask
            if fb in bins:
                bins[fb].append(i)
            else:
                bins[fb] = [i]

        # We must recurse if any bin has more than one element or if
        # the single element in the bin has not been fully matched.
        t = Tree(fullmask, innermask)

        for b, l in bins.items():
            s = l[0]
            if len(l) > 1 or s.fixedmask & ~fullmask != 0:
                s = ExcMultiPattern.__build_tree(l, b | outerbits, fullmask)
            t.subs.append((b, s))

        return t

    def build_tree(self):
        """Build the trees of the members, then this group's own tree."""
        # This used to call super().prop_format(), which only worked
        # because MultiPattern.prop_format in turn called build_tree.
        super().build_tree()
        self.tree = self.__build_tree(self.pats, self.fixedbits,
                                      self.fixedmask)

    @staticmethod
    def __prop_format(tree):
        """Propagate Format objects into the decode tree"""

        # Depth first search.
        for (b, s) in tree.subs:
            if isinstance(s, Tree):
                ExcMultiPattern.__prop_format(s)

        # If all entries in SUBS have the same format, then
        # propagate that into the tree.
        f = None
        for (b, s) in tree.subs:
            if f is None:
                f = s.base
                if f is None:
                    return
            if f is not s.base:
                return
        tree.base = f

    def prop_format(self):
        """Propagate formats through the members and into our tree.

        Reads self.tree, so build_tree() must have been called first.
        """
        super().prop_format()
        self.__prop_format(self.tree)

# end ExcMultiPattern
656
657
def parse_field(lineno, name, toks):
    """Parse one instruction field from TOKS at LINENO.

    The resulting field object is recorded in the global 'fields'
    table under NAME; errors are reported through error().
    """
    global fields
    global insnwidth

    # A "simple" field will have only one entry;
    # a "multifield" will have several.
    subs = []
    width = 0
    func = None
    for t in toks:
        if t.startswith('!function='):
            if func:
                error(lineno, 'duplicate function')
            func = t.split('=')[1]
            continue

        if re.fullmatch('[0-9]+:s[0-9]+', t):
            # Signed field extract
            sign = True
            subtoks = t.split(':s')
        elif re.fullmatch('[0-9]+:[0-9]+', t):
            # Unsigned field extract
            sign = False
            subtoks = t.split(':')
        else:
            error(lineno, f'invalid field token "{t}"')
        po = int(subtoks[0])
        le = int(subtoks[1])
        if po + le > insnwidth:
            error(lineno, f'field {t} too large')
        subs.append(Field(sign, po, le))
        width += le

    if width > insnwidth:
        error(lineno, 'field too large')

    if subs:
        if len(subs) == 1:
            f = subs[0]
        else:
            # The components of a multifield may not share any bit.
            mask = 0
            for s in subs:
                if mask & s.mask:
                    error(lineno, 'field components overlap')
                mask |= s.mask
            f = MultiField(subs, mask)
        if func:
            f = FunctionField(func, f)
    elif func:
        # A function with no bits to operate on.
        f = ParameterField(func)
    else:
        error(lineno, 'field with no value')

    if name in fields:
        error(lineno, 'duplicate field', name)
    fields[name] = f
# end parse_field
718
719
def parse_arguments(lineno, name, toks):
    """Parse one argument set from TOKS at LINENO.

    Each token is '!extern', 'member' (type 'int') or 'member:type'.
    The result is recorded in the global 'arguments' table under NAME.
    """
    global arguments
    global re_C_ident
    global anyextern

    flds = []
    types = []
    extern = False
    for tok in toks:
        if tok == '!extern':
            extern = True
            anyextern = True
            continue
        if re.fullmatch(re_C_ident + ':' + re_C_ident, tok):
            member, ctype = tok.split(':')
        elif re.fullmatch(re_C_ident, tok):
            member, ctype = tok, 'int'
        else:
            error(lineno, f'invalid argument set token "{tok}"')
        if member in flds:
            error(lineno, f'duplicate argument "{member}"')
        flds.append(member)
        types.append(ctype)

    if name in arguments:
        error(lineno, 'duplicate argument set', name)
    arguments[name] = Arguments(name, flds, types, extern)
# end parse_arguments
749
750
def lookup_field(lineno, name):
    """Return the field named NAME, reporting an error if undefined."""
    global fields
    try:
        return fields[name]
    except KeyError:
        error(lineno, 'undefined field', name)
756
757
def add_field(lineno, flds, new_name, f):
    """Store field F in FLDS as NEW_NAME and return FLDS.

    A name already present in FLDS is reported as an error.
    """
    already_present = new_name in flds
    if already_present:
        error(lineno, 'duplicate field', new_name)
    flds[new_name] = f
    return flds
763
764
def add_field_byname(lineno, flds, new_name, old_name):
    """Store the global field OLD_NAME in FLDS as NEW_NAME; return FLDS."""
    found = lookup_field(lineno, old_name)
    return add_field(lineno, flds, new_name, found)
767
768
def infer_argument_set(flds):
    """Return an argument set whose members are the names in FLDS.

    The first existing set accepted by eq_fields_for_args is reused;
    otherwise a new all-'int' set is created and registered.
    """
    global arguments
    global decode_function

    existing = next((candidate for candidate in arguments.values()
                     if eq_fields_for_args(flds, candidate)), None)
    if existing is not None:
        return existing

    # The generated name is numbered by the current count of sets.
    name = decode_function + str(len(arguments))
    created = Arguments(name, flds.keys(), ['int'] * len(flds), False)
    arguments[name] = created
    return created
781
782
def infer_format(arg, fieldmask, flds, width):
    """Return (format, const_flds) for a pattern with no explicit format.

    An existing format is reused when its argument set (if ARG is
    given), field mask, width and fields all match; otherwise a new
    one is created and registered in the global 'formats' table.
    """
    global arguments
    global formats
    global decode_function

    const_flds = {}
    var_flds = {}
    for n, c in flds.items():
        # NOTE(review): C is a field instance, so this identity test
        # against the class itself never holds.  CONST_FLDS therefore
        # stays empty and every field lands in VAR_FLDS.  An isinstance
        # test looks intended, but it would change which formats are
        # shared; confirm before changing.
        target = const_flds if c is ConstField else var_flds
        target[n] = c

    # Look for an existing format with the same argument set and fields
    for fmt in formats.values():
        same_args = not (arg and fmt.base != arg)
        if (same_args
                and fieldmask == fmt.fieldmask
                and width == fmt.width
                and eq_fields_for_fmts(flds, fmt.fields)):
            return (fmt, const_flds)

    # Nothing suitable: make a new format, numbered by the current count.
    name = decode_function + '_Fmt_' + str(len(formats))
    if not arg:
        arg = infer_argument_set(flds)

    fmt = Format(name, 0, arg, 0, 0, 0, fieldmask, var_flds, width)
    formats[name] = fmt

    return (fmt, const_flds)
# end infer_format
817
818
819def parse_generic(lineno, parent_pat, name, toks):
820    """Parse one instruction format from TOKS at LINENO"""
821    global fields
822    global arguments
823    global formats
824    global allpatterns
825    global re_arg_ident
826    global re_fld_ident
827    global re_fmt_ident
828    global re_C_ident
829    global insnwidth
830    global insnmask
831    global variablewidth
832
833    is_format = parent_pat is None
834
835    fixedmask = 0
836    fixedbits = 0
837    undefmask = 0
838    width = 0
839    flds = {}
840    arg = None
841    fmt = None
842    for t in toks:
843        # '&Foo' gives a format an explicit argument set.
844        if re.fullmatch(re_arg_ident, t):
845            tt = t[1:]
846            if arg:
847                error(lineno, 'multiple argument sets')
848            if tt in arguments:
849                arg = arguments[tt]
850            else:
851                error(lineno, 'undefined argument set', t)
852            continue
853
854        # '@Foo' gives a pattern an explicit format.
855        if re.fullmatch(re_fmt_ident, t):
856            tt = t[1:]
857            if fmt:
858                error(lineno, 'multiple formats')
859            if tt in formats:
860                fmt = formats[tt]
861            else:
862                error(lineno, 'undefined format', t)
863            continue
864
865        # '%Foo' imports a field.
866        if re.fullmatch(re_fld_ident, t):
867            tt = t[1:]
868            flds = add_field_byname(lineno, flds, tt, tt)
869            continue
870
871        # 'Foo=%Bar' imports a field with a different name.
872        if re.fullmatch(re_C_ident + '=' + re_fld_ident, t):
873            (fname, iname) = t.split('=%')
874            flds = add_field_byname(lineno, flds, fname, iname)
875            continue
876
877        # 'Foo=number' sets an argument field to a constant value
878        if re.fullmatch(re_C_ident + '=[+-]?[0-9]+', t):
879            (fname, value) = t.split('=')
880            value = int(value)
881            flds = add_field(lineno, flds, fname, ConstField(value))
882            continue
883
884        # Pattern of 0s, 1s, dots and dashes indicate required zeros,
885        # required ones, or dont-cares.
886        if re.fullmatch('[01.-]+', t):
887            shift = len(t)
888            fms = t.replace('0', '1')
889            fms = fms.replace('.', '0')
890            fms = fms.replace('-', '0')
891            fbs = t.replace('.', '0')
892            fbs = fbs.replace('-', '0')
893            ubm = t.replace('1', '0')
894            ubm = ubm.replace('.', '0')
895            ubm = ubm.replace('-', '1')
896            fms = int(fms, 2)
897            fbs = int(fbs, 2)
898            ubm = int(ubm, 2)
899            fixedbits = (fixedbits << shift) | fbs
900            fixedmask = (fixedmask << shift) | fms
901            undefmask = (undefmask << shift) | ubm
902        # Otherwise, fieldname:fieldwidth
903        elif re.fullmatch(re_C_ident + ':s?[0-9]+', t):
904            (fname, flen) = t.split(':')
905            sign = False
906            if flen[0] == 's':
907                sign = True
908                flen = flen[1:]
909            shift = int(flen, 10)
910            if shift + width > insnwidth:
911                error(lineno, f'field {fname} exceeds insnwidth')
912            f = Field(sign, insnwidth - width - shift, shift)
913            flds = add_field(lineno, flds, fname, f)
914            fixedbits <<= shift
915            fixedmask <<= shift
916            undefmask <<= shift
917        else:
918            error(lineno, f'invalid token "{t}"')
919        width += shift
920
921    if variablewidth and width < insnwidth and width % 8 == 0:
922        shift = insnwidth - width
923        fixedbits <<= shift
924        fixedmask <<= shift
925        undefmask <<= shift
926        undefmask |= (1 << shift) - 1
927
928    # We should have filled in all of the bits of the instruction.
929    elif not (is_format and width == 0) and width != insnwidth:
930        error(lineno, f'definition has {width} bits')
931
932    # Do not check for fields overlapping fields; one valid usage
933    # is to be able to duplicate fields via import.
934    fieldmask = 0
935    for f in flds.values():
936        fieldmask |= f.mask
937
938    # Fix up what we've parsed to match either a format or a pattern.
939    if is_format:
940        # Formats cannot reference formats.
941        if fmt:
942            error(lineno, 'format referencing format')
943        # If an argument set is given, then there should be no fields
944        # without a place to store it.
945        if arg:
946            for f in flds.keys():
947                if f not in arg.fields:
948                    error(lineno, f'field {f} not in argument set {arg.name}')
949        else:
950            arg = infer_argument_set(flds)
951        if name in formats:
952            error(lineno, 'duplicate format name', name)
953        fmt = Format(name, lineno, arg, fixedbits, fixedmask,
954                     undefmask, fieldmask, flds, width)
955        formats[name] = fmt
956    else:
957        # Patterns can reference a format ...
958        if fmt:
959            # ... but not an argument simultaneously
960            if arg:
961                error(lineno, 'pattern specifies both format and argument set')
962            if fixedmask & fmt.fixedmask:
963                error(lineno, 'pattern fixed bits overlap format fixed bits')
964            if width != fmt.width:
965                error(lineno, 'pattern uses format of different width')
966            fieldmask |= fmt.fieldmask
967            fixedbits |= fmt.fixedbits
968            fixedmask |= fmt.fixedmask
969            undefmask |= fmt.undefmask
970        else:
971            (fmt, flds) = infer_format(arg, fieldmask, flds, width)
972        arg = fmt.base
973        for f in flds.keys():
974            if f not in arg.fields:
975                error(lineno, f'field {f} not in argument set {arg.name}')
976            if f in fmt.fields.keys():
977                error(lineno, f'field {f} set by format and pattern')
978        for f in arg.fields:
979            if f not in flds.keys() and f not in fmt.fields.keys():
980                error(lineno, f'field {f} not initialized')
981        pat = Pattern(name, lineno, fmt, fixedbits, fixedmask,
982                      undefmask, fieldmask, flds, width)
983        parent_pat.pats.append(pat)
984        allpatterns.append(pat)
985
986    # Validate the masks that we have assembled.
987    if fieldmask & fixedmask:
988        error(lineno, 'fieldmask overlaps fixedmask ',
989              f'({whex(fieldmask)} & {whex(fixedmask)})')
990    if fieldmask & undefmask:
991        error(lineno, 'fieldmask overlaps undefmask ',
992              f'({whex(fieldmask)} & {whex(undefmask)})')
993    if fixedmask & undefmask:
994        error(lineno, 'fixedmask overlaps undefmask ',
995              f'({whex(fixedmask)} & {whex(undefmask)})')
996    if not is_format:
997        allbits = fieldmask | fixedmask | undefmask
998        if allbits != insnmask:
999            error(lineno, 'bits left unspecified ',
1000                  f'({whex(allbits ^ insnmask)})')
# end parse_generic
1002
1003
def parse_file(f, parent_pat):
    """Parse every definition and pattern found in the open file F.

    F is iterated line by line.  A line whose last token is a lone
    backslash is joined with the following line, text after '#' is
    discarded, and completely blank lines are skipped.  Each resulting
    logical line is dispatched on its first token:

      %name   -> parse_field
      &name   -> parse_arguments
      @name   -> parse_generic (with no parent pattern)
      name    -> parse_generic (with the current parent pattern)

    '{' and '[' open a nested IncMultiPattern or ExcMultiPattern
    respectively; the new group is appended to the current group's
    pats list and becomes the current group until the matching '}' or
    ']' is seen.  Each nesting level must be indented by exactly two
    more columns, and this is enforced for comment-only lines too.

    PARENT_PAT is the group that receives the top-level patterns.
    All problems are reported through error().
    """
    toks = []
    lineno = 0
    nesting = 0
    nesting_pats = []

    for line in f:
        lineno += 1

        # Expand and strip spaces, to find indent.
        line = line.rstrip()
        line = line.expandtabs()
        len1 = len(line)
        line = line.lstrip()
        len2 = len(line)

        # Discard comments
        end = line.find('#')
        if end >= 0:
            line = line[:end]

        t = line.split()
        if toks:
            # Next line after continuation
            toks.extend(t)
        else:
            # Allow completely blank lines.
            if len1 == 0:
                continue
            indent = len1 - len2
            # Empty line due to comment.
            if not t:
                # Indentation must be correct, even for comment lines.
                if indent != nesting:
                    error(lineno, 'indentation ', indent, ' != ', nesting)
                continue
            start_lineno = lineno
            toks = t

        # Continuation?
        if toks[-1] == '\\':
            toks.pop()
            continue

        name = toks[0]
        del toks[0]

        # End nesting?
        if name in ('}', ']'):
            if toks:
                error(start_lineno, 'extra tokens after close brace')

            if nesting_pats:
                # Make sure { } and [ ] nest properly.
                if (name == '}') != isinstance(parent_pat, IncMultiPattern):
                    error(lineno, 'mismatched close brace')
                parent_pat = nesting_pats.pop()
            else:
                # Nothing is open.  Check this before the brace kind so
                # that a stray '}' at top level is not misreported as a
                # mismatch against the top-level group.
                error(lineno, 'extra close brace')

            nesting -= 2
            if indent != nesting:
                error(lineno, 'indentation ', indent, ' != ', nesting)

            toks = []
            continue

        # Everything else should have current indentation.
        if indent != nesting:
            error(start_lineno, 'indentation ', indent, ' != ', nesting)

        # Start nesting?
        if name in ('{', '['):
            if toks:
                error(start_lineno, 'extra tokens after open brace')

            if name == '{':
                nested_pat = IncMultiPattern(start_lineno)
            else:
                nested_pat = ExcMultiPattern(start_lineno)
            parent_pat.pats.append(nested_pat)
            nesting_pats.append(parent_pat)
            parent_pat = nested_pat

            nesting += 2
            toks = []
            continue

        # Determine the type of object needing to be parsed.
        if re.fullmatch(re_fld_ident, name):
            parse_field(start_lineno, name[1:], toks)
        elif re.fullmatch(re_arg_ident, name):
            parse_arguments(start_lineno, name[1:], toks)
        elif re.fullmatch(re_fmt_ident, name):
            parse_generic(start_lineno, None, name[1:], toks)
        elif re.fullmatch(re_pat_ident, name):
            parse_generic(start_lineno, parent_pat, name, toks)
        else:
            error(lineno, f'invalid token "{name}"')
        toks = []

    # Tokens can only remain here when the last line ended with a
    # backslash; report that rather than silently dropping them.
    if toks:
        error(start_lineno, 'unterminated line continuation')
    if nesting != 0:
        error(lineno, 'missing close brace')
# end parse_file
1117
1118
class SizeTree:
    """Interior node of the instruction-size decode tree.

    mask  -- instruction bits that this node switches on
    subs  -- list of (bits, subtree) pairs, one per value of the masked bits
    base  -- initialised to None; not used by the methods of this class
    width -- number of instruction bits that must be loaded before
             the masked bits can be tested
    """

    def __init__(self, m, w):
        self.mask = m
        self.subs = []
        self.base = None
        self.width = w

    def str1(self, i):
        """Return a debug rendering of the subtree indented by I columns."""
        pad = str_indent(i)
        pieces = [pad + whex(self.mask) + ' [\n']
        for bits, sub in self.subs:
            pieces.append(pad + f'  {whex(bits)}:\n')
            pieces.append(sub.str1(i + 4) + '\n')
        pieces.append(pad + ']')
        return ''.join(pieces)

    def __str__(self):
        return self.str1(0)

    def output_code(self, i, extracted, outerbits, outermask):
        """Emit the C switch statement for this node at indent I.

        EXTRACTED is the number of instruction bits already loaded;
        OUTERBITS and OUTERMASK describe the bits matched by the
        enclosing nodes and are only used for the generated comments.
        """
        pad = str_indent(i)

        # Load more bytes first when this node tests bits not yet read.
        if extracted < self.width:
            have = extracted // 8
            need = self.width // 8
            output(pad, f'insn = {decode_function}_load_bytes',
                   f'(ctx, insn, {have}, {need});\n')
            extracted = self.width

        # Help the compiler produce compact switch statements: when the
        # mask bits are contiguous, shift them down to bit zero.
        shift = is_contiguous(self.mask)
        if shift > 0:
            switch_expr = f'(insn >> {shift}) & {self.mask >> shift:#x}'

            def case_label(bits):
                return hex(bits >> shift)
        else:
            switch_expr = f'insn & {whexC(self.mask)}'
            case_label = whexC

        # The accumulated mask is the same for every case of this node.
        innermask = outermask | self.mask

        output(pad, 'switch (', switch_expr, ') {\n')
        for bits, sub in sorted(self.subs):
            innerbits = outerbits | bits
            output(pad, 'case ', case_label(bits), ':\n')
            output(pad, '    /* ',
                   str_match_bits(innerbits, innermask), ' */\n')
            sub.output_code(i + 4, extracted, innerbits, innermask)
        output(pad, '}\n')
        output(pad, 'return insn;\n')
# end SizeTree
1177
class SizeLeaf:
    """Leaf of the instruction-size decode tree.

    mask  -- mask value recorded for this leaf (shown by str1)
    width -- number of instruction bits that must have been loaded
             when this leaf is reached
    """

    def __init__(self, m, w):
        self.mask = m
        self.width = w

    def str1(self, i):
        """Return the leaf's mask in hex, indented by I columns."""
        pad = str_indent(i)
        return pad + whex(self.mask)

    def __str__(self):
        return self.str1(0)

    def output_code(self, i, extracted, outerbits, outermask):
        """Emit the C code that finishes loading and returns the insn.

        EXTRACTED is the number of instruction bits already loaded.
        OUTERBITS and OUTERMASK are accepted for interface parity with
        SizeTree.output_code and are not used here.
        """
        pad = str_indent(i)

        # Load whatever bytes are still missing before returning.
        if extracted < self.width:
            have = extracted // 8
            need = self.width // 8
            output(pad, f'insn = {decode_function}_load_bytes',
                   f'(ctx, insn, {have}, {need});\n')
        output(pad, 'return insn;\n')
# end SizeLeaf
1202
1203
def build_size_tree(pats, width, outerbits, outermask):
    """Recursively build the tree used to decide instruction length.

    PATS is the list of patterns still under consideration; each is
    read for its fixedmask, fixedbits and width (plus name, file and
    lineno for diagnostics).  WIDTH selects the 8-bit group of the
    instruction being examined, counted in bits from the top of an
    insnwidth-bit word.  OUTERBITS and OUTERMASK are the bit values
    and mask already decided by the enclosing nodes.

    Returns a SizeLeaf when all patterns share one width, otherwise a
    SizeTree (possibly reached by recursing on the next 8-bit group).
    """
    # Bits of the current byte that are fixed in every pattern and
    # have not already been consumed by an enclosing node.
    innermask = (0xff << (insnwidth - width)) & ~outermask
    for pat in pats:
        innermask &= pat.fixedmask

    # If every pattern has the same width there is nothing to decide.
    widths = {pat.width for pat in pats}
    widest = max(widths) if widths else None
    if len(widths) <= 1:
        return SizeLeaf(innermask, widest)

    if innermask == 0:
        # Nothing to distinguish on in this byte; try the next one
        # while some pattern still extends that far.
        if width < widest:
            return build_size_tree(pats, width + 8, outerbits, outermask)

        pnames = [':'.join((pat.name, pat.file, str(pat.lineno)))
                  for pat in pats]
        error_with_file(pats[0].file, pats[0].lineno,
                        f'overlapping patterns size {width}:', pnames)

    # Group the patterns by the value of their common fixed bits.
    bins = {}
    for pat in pats:
        bins.setdefault(pat.fixedbits & innermask, []).append(pat)

    fullmask = outermask | innermask

    # A single group needs no switch here; move on to the next byte.
    if len(bins) == 1:
        (bits, group), = bins.items()
        return build_size_tree(group, width + 8, bits | outerbits, fullmask)

    node = SizeTree(innermask, width)
    for bits, group in bins.items():
        sub = build_size_tree(group, width, bits | outerbits, fullmask)
        node.subs.append((bits, sub))
    return node
# end build_size_tree
1254
1255
def prop_size(tree):
    """Propagate minimum widths up the decode size tree.

    For a SizeTree node, the width becomes the smallest width found
    among its subtrees; any other node is left untouched.  Returns the
    resulting width of TREE.
    """
    if not isinstance(tree, SizeTree):
        return tree.width

    narrowest = None
    for _bits, sub in tree.subs:
        sub_width = prop_size(sub)
        if narrowest is None or sub_width < narrowest:
            narrowest = sub_width

    # A node never needs fewer bits than it already tests itself.
    assert narrowest >= tree.width
    tree.width = narrowest
    return narrowest
# end prop_size
1271
1272
def main():
    """Command-line entry point: parse the inputs and emit the C decoder.

    Options handled:
      -o FILE, --output=FILE   write the generated code to FILE
                               (default: standard output)
      --decode=NAME            name of the decode function, made non-static
      --static-decode=NAME     name of the decode function, kept static
      --translate=PREFIX       prefix of the translate functions, which
                               are then declared non-static
      -w N, --insnwidth=N      instruction width in bits (16, 32 or 64)
      --varinsnwidth=N         as --insnwidth, and additionally emit the
                               <decode>_load function for variable-width
                               instructions

    The remaining arguments are input files; at least one is required.
    All of them are parsed into one top-level pattern group.
    """
    global arguments
    global formats
    global allpatterns
    global translate_scope
    global translate_prefix
    global output_fd
    global output_file
    global input_file
    global insnwidth
    global insntype
    global insnmask
    global decode_function
    global bitop_width
    global variablewidth
    global anyextern

    decode_scope = 'static '

    # Settings derived from the instruction width:
    #   width -> (insntype, insnmask, bitop_width)
    # Every entry sets all three, so repeating the width option cannot
    # leave a stale value behind from an earlier occurrence.
    width_config = {
        16: ('uint16_t', 0xffff, 32),
        32: ('uint32_t', 0xffffffff, 32),
        64: ('uint64_t', 0xffffffffffffffff, 64),
    }

    long_opts = ['decode=', 'translate=', 'output=', 'insnwidth=',
                 'static-decode=', 'varinsnwidth=']
    try:
        (opts, args) = getopt.gnu_getopt(sys.argv[1:], 'o:vw:', long_opts)
    except getopt.GetoptError as err:
        error(0, err)
    for o, a in opts:
        if o in ('-o', '--output'):
            output_file = a
        elif o == '--decode':
            decode_function = a
            decode_scope = ''
        elif o == '--static-decode':
            decode_function = a
        elif o == '--translate':
            translate_prefix = a
            translate_scope = ''
        elif o in ('-w', '--insnwidth', '--varinsnwidth'):
            if o == '--varinsnwidth':
                variablewidth = True
            try:
                insnwidth = int(a)
            except ValueError:
                error(0, f'invalid instruction width "{a}"')
            config = width_config.get(insnwidth)
            if config is None:
                error(0, 'cannot handle insns of width', insnwidth)
            else:
                insntype, insnmask, bitop_width = config
        else:
            # Reached for options that getopt accepts but that have no
            # handler above ('-v' is in the short option string).  Use
            # error() rather than assert so the check survives python -O.
            error(0, f'unhandled option {o}')

    if len(args) < 1:
        error(0, 'missing input file')

    toppat = ExcMultiPattern(0)

    for filename in args:
        input_file = filename
        with open(filename, 'rt', encoding='utf-8') as f:
            parse_file(f, toppat)

    # We do not want to compute masks for toppat, because those masks
    # are used as a starting point for build_tree.  For toppat, we must
    # insist that decode begins from naught.
    for i in toppat.pats:
        i.prop_masks()

    toppat.build_tree()
    toppat.prop_format()

    if variablewidth:
        for i in toppat.pats:
            i.prop_width()
        stree = build_size_tree(toppat.pats, 8, 0, 0)
        prop_size(stree)

    if output_file:
        output_fd = open(output_file, 'wt', encoding='utf-8')
    else:
        output_fd = io.TextIOWrapper(sys.stdout.buffer,
                                     encoding=sys.stdout.encoding,
                                     errors="ignore")

    output_autogen()
    for n in sorted(arguments.keys()):
        arguments[n].output_def()

    # A single translate function can be invoked for different patterns.
    # Make sure that the argument sets are the same, and declare the
    # function only once.
    #
    # If we're sharing formats, we're likely also sharing trans_* functions,
    # but we can't tell which ones.  Prevent issues from the compiler by
    # suppressing redundant declaration warnings.
    if anyextern:
        output("#pragma GCC diagnostic push\n",
               "#pragma GCC diagnostic ignored \"-Wredundant-decls\"\n",
               "#ifdef __clang__\n"
               "#  pragma GCC diagnostic ignored \"-Wtypedef-redefinition\"\n",
               "#endif\n\n")

    out_pats = {}
    for i in allpatterns:
        if i.name in out_pats:
            p = out_pats[i.name]
            if i.base.base != p.base.base:
                error(0, i.name, ' has conflicting argument sets')
        else:
            i.output_decl()
            out_pats[i.name] = i
    output('\n')

    if anyextern:
        output("#pragma GCC diagnostic pop\n\n")

    for n in sorted(formats.keys()):
        formats[n].output_extract()

    output(decode_scope, 'bool ', decode_function,
           '(DisasContext *ctx, ', insntype, ' insn)\n{\n')

    i4 = str_indent(4)

    if allpatterns:
        # One union member per argument set, shared by all patterns.
        output(i4, 'union {\n')
        for n in sorted(arguments.keys()):
            arg = arguments[n]
            output(i4, i4, arg.struct_name(), ' f_', arg.name, ';\n')
        output(i4, '} u;\n\n')
        toppat.output_code(4, False, 0, 0)

    output(i4, 'return false;\n')
    output('}\n')

    if variablewidth:
        output('\n', decode_scope, insntype, ' ', decode_function,
               '_load(DisasContext *ctx)\n{\n',
               '    ', insntype, ' insn = 0;\n\n')
        stree.output_code(4, 0, 0, 0)
        output('}\n')

    if output_file:
        output_fd.close()
    else:
        # The stdout wrapper is not closed (that would close stdout
        # itself), so push out anything it still buffers.
        output_fd.flush()
# end main
1421
1422
# Run the generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
1425