xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision 844f5ed5)
1#!/usr/bin/env python3
2
3import argparse
4import collections
5import os
6import yaml
7
8from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation
9
10
def c_upper(name):
    """Turn a spec identifier like 'foo-bar' into an upper-case C name 'FOO_BAR'."""
    upper = name.upper()
    return '_'.join(upper.split('-'))
13
14
def c_lower(name):
    """Turn a spec identifier like 'Foo-Bar' into a lower-case C name 'foo_bar'."""
    lower = name.lower()
    return '_'.join(lower.split('-'))
17
18
class BaseNlLib:
    """Renders libmnl-specific C expressions used by the generated code."""

    def get_family_id(self):
        """Return the C expression that holds the genetlink family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a ``mnl_cb_run2()`` call expression as C source text.

        Dump requests pass 0 for seq/portid; other requests pass
        ``ys->seq``/``ys->portid``.  Continuation lines are wrapped with two
        tabs plus ``indent`` extra tabs and one space.
        """
        sep = ''.join(['\n\t\t', '\t' * indent, ' '])
        tail = f"{sep}ynl_cb_array, NLMSG_MIN_TYPE)"
        if not is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}{cb}, {data},{tail}"
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{tail}"
30
31
class Type(SpecAttr):
    """Code-generation view of one attribute in an attribute set.

    The base class holds the shared rendering logic.  Subclasses override
    the small hooks (``presence_type``, ``arg_member``, ``_attr_typol``,
    ``_attr_policy``, ``attr_put``, ``_attr_get``, ``_setter_lines``) to
    emit type-specific C.  Hooks a subclass does not support raise Exception.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nest of the set named after the family renders as just the family name.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        self.c_name = c_lower(self.name)
        # Names listed in _C_KW (presumably C keywords) get a trailing
        # underscore so they are usable as struct member names.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(): declared here, then deleted, so the attribute
        # does not exist until resolve() has run.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Set ``enum_name``: the attr set's name prefix plus the attr name, upper-cased."""
        self.enum_name = c_upper(f"{self.attr_set.name_prefix}{self.name}")

    def is_multi_val(self):
        """True for types rendered as an array plus an ``n_<name>`` count."""
        return False

    def is_scalar(self):
        """True when the spec type is one of the fixed-width integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked in ``_present``: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration for this attr, or None.

        Only produces a line when ``presence_type()`` equals ``type_filter``.
        In 'user' space the ``__u32`` spelling is used.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a struct/array member, or None for simple types."""
        return None

    def free_needs_iter(self):
        """True when the generated free code needs a loop counter ``i``."""
        return False

    def free(self, ri, var, ref):
        """Emit C freeing this member of ``var`` (``ref`` is a member path prefix)."""
        # Arrays and length-tracked buffers are the members that get free()d here.
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attr to a setter."""
        member = self._complex_member_type(ri)
        if member:
            return [member + ' *' + self.c_name]
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attr's member(s) of the generated C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel nla_policy initializer for ``policy`` (an NLA_* name)."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry of the kernel nla_policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the body of the user-space YNL policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry of the user-space YNL policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` guarded by this attr's presence check, if it has one."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` of this member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit C that writes this attr into a request message."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(get_lines, init_lines, local_vars)`` for the parser arm."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit one ``if``/``else if`` arm of the response attribute parser.

        ``first`` selects ``if`` versus ``else if``.  Single-value types are
        validated and, for bit presence, marked present before the
        subclass-provided init and get lines are emitted.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Subclasses may return a single line as a bare string.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the setter argument into ``member``."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attr of the request struct.

        ``ref`` is the path of enclosing nest members.  Every level's
        ``_present`` bit is set for 'bit' presence types.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            # e.g. ref ['a', 'b'] -> req->_present.a, then req->a._present.b
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        ri.cw.write_func('static inline void',
                         f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}",
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
193
194
class TypeUnused(Type):
    """Placeholder attribute slot: rejected by YNL and absent from the kernel policy."""

    def attr_policy(self, cw):
        # Intentionally emits no kernel policy entry.
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # No _present tracking for unused slots.
        return ''
204
205
class TypePad(Type):
    """Padding attribute: rejected by YNL and absent from the kernel policy."""

    def attr_policy(self, cw):
        # Intentionally emits no kernel policy entry.
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # Padding is never tracked in _present.
        return ''
215
216
class TypeScalar(Type):
    """Integer attribute (signed or unsigned) rendered as a plain C member."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # The spec's byte order is only surfaced as a comment on the member.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a bit mask and pick the member's C type."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # Plain enums use their C enum type; masks and raw ints use __<type>.
        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the libmnl accessor suffix for this type (s* mapped to u*)."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, adding mask/range checks.

        Precedence: mask checks (``flags-mask`` or a bit-field enum), then an
        explicit ``min`` check, then the value range of a plain enum.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask()
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            # Use the entries' real values: 'value-start' and explicit 'value'
            # keys mean the valid range is not necessarily 0..count-1.
            values = [entry.value for entry in enum.entry_list]
            if values:
                low, high = min(values), max(values)
                if low == 0:
                    return f"NLA_POLICY_MAX({policy}, {high})"
                return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned YNL policy type of the same width.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
284
285
class TypeFlag(Type):
    """Zero-length attribute: its presence alone carries the information."""

    def arg_member(self, ri):
        # Setters take no value argument for flags.
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the base attr_get already records presence.
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        # Setting the presence bit (done by Type.setter) is all that is needed.
        return []
301
302
class TypeString(Type):
    """String attribute stored as a malloc'ed, NUL-terminated ``char *``."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        # Presence is tracked via the stored string length.
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        parts = ['{ .type = ', policy]
        if 'max-len' in self.checks:
            parts += [', .len = ', str(self.checks['max-len'])]
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy to NLA_STRING.
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, mnl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        # Length is bounded by the payload so a missing NUL cannot overrun.
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
350
351
class TypeBinary(Type):
    """Opaque binary attribute stored as a malloc'ed ``void *`` plus length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        # Presence is tracked via the stored buffer length.
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # Only "no checks" and a lone 'min-len' check are supported.
        if not self.checks:
            body = '.type = NLA_BINARY'
        elif len(self.checks) == 1 and 'min-len' in self.checks:
            body = '.len = ' + str(self.checks['min-len'])
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        put_call = (f"mnl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, mnl_attr_get_payload(attr), len);",
        ]
        return get_lines, ['len = mnl_attr_get_payload_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{member} = malloc({length});",
            f'memcpy({member}, {self.c_name}, {length});',
        ]
392
393
class TypeNest(Type):
    """Attribute carrying a nested attribute set, rendered as an embedded struct."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def free(self, ri, var, ref):
        # Delegate to the generated free helper of the nested struct.
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [f"{self.nested_render_name}_parse(&parg, attr);"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # No setter of its own: emit one setter per member of the nested set,
        # with this nest appended to the member path.
        nested_ref = [*(ref or []), self.c_name]
        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _name, member_attr in nested_struct.member_list():
            member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=nested_ref)
422
423
class TypeMultiAttr(Type):
    """Attribute that may repeat; all occurrences are collected into an array.

    Rendered as a pointer member plus an ``n_<name>`` count (see
    Type.struct_member).  Supported element types: nests and scalars.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        # Presence is the element count.
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f"struct {self.nested_render_name}"
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # Nested elements are freed one by one, which needs a loop counter.
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        """Emit C that frees the collected array (and each nested element first).

        The base Type.free releases multi-value arrays; this override must do
        the same after cleaning up nested elements, otherwise the array leaks.
        """
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_typol(self):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        elif self.attr['type'] in scalars:
            return f".type = YNL_PT_U{self.attr['type'][1:]}, "
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def _attr_get(self, ri, var):
        # First pass only counts occurrences.
        return f'{var}->n_{self.c_name}++;', None, None
458
459
class TypeArrayNest(Type):
    """Nest whose entries form an array (pointer member plus ``n_<name>`` count)."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        # Presence is the element count.
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return f"struct {self.nested_render_name}"
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the outer attribute (attr_<name>, presumably re-walked
        # later to fill the array) and count its nested entries.
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
485
486
class TypeNestTypeValue(Type):
    """Nest descended through one level per 'type-value' entry before parsing.

    At each level the payload's first attribute is taken and its attribute
    type is captured; the captured types become extra parse arguments.
    """

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
515
516
class Struct:
    """Describes the C struct generated for one attribute set ("space").

    ``type_list`` restricts the struct to the named attributes in that
    order; when it is None the struct covers every attribute of the set and
    is flagged ``nested``.  ``inherited`` is the collection of extra member
    names later confirmed (and C-lowered) by set_inherited().
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lower_space = c_lower(space_name)
        if family.name != lower_space:
            self.render_name = f"{family.name}_{lower_space}"
        else:
            self.render_name = f"{family.name}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # An empty type_list also falls back to all attributes of the set.
        names = type_list if type_list else self.attr_set
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        # Index by name and remember the attribute with the highest value
        # (the last one wins on ties).
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            self.attrs[name] = attr
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm the inherited member names; raise if they differ from before."""
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
567
568
class EnumEntry:
    """One entry of an enum or flags definition.

    ``yaml`` is either a bare name string or a dict with 'name' and optional
    'doc'/'value'.  Without an explicit value the entry follows ``prev``
    (previous value + 1), or starts at ``value_start`` when it is first.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        if isinstance(yaml, str):
            self.name = yaml
            yaml = {}
            self.doc = ''
        else:
            self.name = yaml['name']
            self.doc = yaml.get('doc', '')

        self.yaml = yaml
        self.enum_set = enum_set
        self.c_name = c_upper(enum_set.value_pfx + self.name)

        if 'value' in yaml:
            self.value = yaml['value']
        elif prev:
            self.value = prev.value + 1
        else:
            self.value = value_start

        # True when the value differs from C's implicit enumerator numbering
        # (0 for the first entry, previous + 1 afterwards); always True for
        # flags.  Computed uniformly so a first entry with an explicit
        # 'value' no longer leaves the attribute unset.
        if prev:
            self.value_change = self.value != prev.value + 1
        else:
            self.value_change = self.value != 0
        self.value_change = self.value_change or self.enum_set['type'] == 'flags'

    def __getitem__(self, key):
        return self.yaml[key]

    def __contains__(self, key):
        return key in self.yaml

    def has_doc(self):
        return bool(self.doc)

    # raw value, i.e. the id in the enum, unlike user value which is a mask for flags
    def raw_value(self):
        return self.value

    # user value, same as raw value for enums, for flags it's the mask
    def user_value(self):
        if self.enum_set['type'] == 'flags':
            return 1 << self.value
        else:
            return self.value
615
616
class EnumSet:
    """An enum or flags definition from the spec, with its entries in order."""

    def __init__(self, family, yaml):
        self.yaml = yaml
        self.family = family

        self.render_name = c_lower(family.name + '-' + yaml['name'])
        self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        self.type = yaml['type']

        prev_entry = None
        value_start = self.yaml.get('value-start', 0)
        self.entries = {}
        self.entry_list = []
        for entry in self.yaml['entries']:
            e = EnumEntry(self, entry, prev_entry, value_start)
            self.entries[e.name] = e
            self.entry_list.append(e)
            prev_entry = e

    def __getitem__(self, key):
        return self.yaml[key]

    def __contains__(self, key):
        return key in self.yaml

    def has_doc(self):
        """True if the set or any of its entries carries documentation."""
        return 'doc' in self.yaml or any(entry.has_doc() for entry in self.entry_list)

    def get_mask(self):
        """Return a mask with bit ``entry.value`` set for every entry.

        Uses each entry's resolved value, so 'value-start' and explicit
        per-entry 'value' keys are both honoured.
        """
        mask = 0
        for entry in self.entry_list:
            mask |= 1 << entry.raw_value()
        return mask
660
661
class AttrSet(SpecAttrSet):
    """Codegen attribute set: adds C naming and maps spec types to Type classes."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name != family.name:
                pfx = f"{family.name}-a-{self.name}-"
            else:
                pfx = family.name + '-a-'
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set ``c_name``; empty when it matches the family's C name."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the attribute spec ``elem``."""
        if 'multi-attr' in elem and elem['multi-attr']:
            return TypeMultiAttr(self.family, self, elem, value)

        kind = elem['type']
        if kind in scalars:
            return TypeScalar(self.family, self, elem, value)

        by_kind = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'nest': TypeNest,
            'array-nest': TypeArrayNest,
            'nest-type-value': TypeNestTypeValue,
        }
        cls = by_kind.get(kind)
        if cls is None:
            raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
713
714
class Operation(SpecOperation):
    """Codegen view of one operation (command or event) of the family."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = family.name + '_' + c_lower(self.name)

        def _has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # Separate do and dump request policies are needed only when both exist.
        self.dual_policy = _has_request('do') and _has_request('dump')

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the C enum name using the sync or async command prefix."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Record ``op`` as a notification sent by this operation.

        The first call creates the 'notify' section, reusing the 'do' reply.
        """
        if 'notify' not in self.yaml:
            self.yaml['notify'] = {'reply': self.yaml['do']['reply'], 'cmds': []}
        self.yaml['notify']['cmds'].append(op)
745
746
747class Family(SpecFamily):
748    def __init__(self, file_name):
749        # Added by resolve:
750        self.c_name = None
751        delattr(self, "c_name")
752        self.op_prefix = None
753        delattr(self, "op_prefix")
754        self.async_op_prefix = None
755        delattr(self, "async_op_prefix")
756        self.mcgrps = None
757        delattr(self, "mcgrps")
758        self.consts = None
759        delattr(self, "consts")
760        self.hooks = None
761        delattr(self, "hooks")
762
763        super().__init__(file_name)
764
765        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
766        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
767
768        if 'definitions' not in self.yaml:
769            self.yaml['definitions'] = []
770
771        if 'uapi-header' in self.yaml:
772            self.uapi_header = self.yaml['uapi-header']
773        else:
774            self.uapi_header = f"linux/{self.name}.h"
775
776    def resolve(self):
777        self.resolve_up(super())
778
779        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
780            raise Exception("Codegen only supported for genetlink")
781
782        self.c_name = c_lower(self.name)
783        if 'name-prefix' in self.yaml['operations']:
784            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
785        else:
786            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
787        if 'async-prefix' in self.yaml['operations']:
788            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
789        else:
790            self.async_op_prefix = self.op_prefix
791
792        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
793
794        self.consts = dict()
795
796        self.hooks = dict()
797        for when in ['pre', 'post']:
798            self.hooks[when] = dict()
799            for op_mode in ['do', 'dump']:
800                self.hooks[when][op_mode] = dict()
801                self.hooks[when][op_mode]['set'] = set()
802                self.hooks[when][op_mode]['list'] = []
803
804        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
805        self.root_sets = dict()
806        # dict space-name -> set('request', 'reply')
807        self.pure_nested_structs = dict()
808        self.all_notify = dict()
809
810        self._mock_up_events()
811
812        self._dictify()
813        self._load_root_sets()
814        self._load_nested_sets()
815        self._load_all_notify()
816        self._load_hooks()
817
818        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
819        if self.kernel_policy == 'global':
820            self._load_global_policy()
821
822    def new_attr_set(self, elem):
823        return AttrSet(self, elem)
824
825    def new_operation(self, elem, req_value, rsp_value):
826        return Operation(self, elem, req_value, rsp_value)
827
828    # Fake a 'do' equivalent of all events, so that we can render their response parsing
829    def _mock_up_events(self):
830        for op in self.yaml['operations']['list']:
831            if 'event' in op:
832                op['do'] = {
833                    'reply': {
834                        'attributes': op['event']['attributes']
835                    }
836                }
837
838    def _dictify(self):
839        for elem in self.yaml['definitions']:
840            if elem['type'] == 'enum' or elem['type'] == 'flags':
841                self.consts[elem['name']] = EnumSet(self, elem)
842            else:
843                self.consts[elem['name']] = elem
844
845        ntf = []
846        for msg in self.msgs.values():
847            if 'notify' in msg:
848                ntf.append(msg)
849        for n in ntf:
850            self.ops[n['notify']].add_notification(n)
851
852    def _load_root_sets(self):
853        for op_name, op in self.ops.items():
854            if 'attribute-set' not in op:
855                continue
856
857            req_attrs = set()
858            rsp_attrs = set()
859            for op_mode in ['do', 'dump']:
860                if op_mode in op and 'request' in op[op_mode]:
861                    req_attrs.update(set(op[op_mode]['request']['attributes']))
862                if op_mode in op and 'reply' in op[op_mode]:
863                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
864
865            if op['attribute-set'] not in self.root_sets:
866                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
867            else:
868                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
869                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
870
    def _load_nested_sets(self):
        """Build Struct objects for attribute sets reached via 'nested-attributes'.

        For every attribute of every root set that nests another set, the
        nested set gets an entry in self.pure_nested_structs (unless it is
        itself a root set), flagged .request / .reply according to whether the
        nesting attribute appears in the root set's request / reply members.
        Extra members inherited from the parent are recorded via
        set_inherited(): the 'type-value' list when present, or 'idx' for
        array-nest attributes.

        NOTE(review): a nested set reached from more than one attribute gets a
        fresh Struct on every visit, which discards .request/.reply flags set
        on an earlier visit.  A nested set that is also a root set is only
        looked up here, never created, so it must already be in
        pure_nested_structs.  Confirm both are intended.
        """
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    # Filled in below, after the Struct has been constructed
                    inherit = set()
                    nested = spec['nested-attributes']
                    if nested not in self.root_sets:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if 'type-value' in spec:
                        if nested in self.root_sets:
                            raise Exception("Inheriting members to a space used as root not supported")
                        inherit.update(set(spec['type-value']))
                    elif spec['type'] == 'array-nest':
                        # Array-nest entries inherit their position as 'idx'
                        inherit.add('idx')
                    self.pure_nested_structs[nested].set_inherited(inherit)
891
892    def _load_all_notify(self):
893        for op_name, op in self.ops.items():
894            if not op:
895                continue
896
897            if 'notify' in op:
898                self.all_notify[op_name] = op['notify']['cmds']
899
900    def _load_global_policy(self):
901        global_set = set()
902        attr_set_name = None
903        for op_name, op in self.ops.items():
904            if not op:
905                continue
906            if 'attribute-set' not in op:
907                continue
908
909            if attr_set_name is None:
910                attr_set_name = op['attribute-set']
911            if attr_set_name != op['attribute-set']:
912                raise Exception('For a global policy all ops must use the same set')
913
914            for op_mode in ['do', 'dump']:
915                if op_mode in op:
916                    global_set.update(op[op_mode].get('request', []))
917
918        self.global_policy = []
919        self.global_policy_set = attr_set_name
920        for attr in self.attr_sets[attr_set_name]:
921            if attr in global_set:
922                self.global_policy.append(attr)
923
924    def _load_hooks(self):
925        for op in self.ops.values():
926            for op_mode in ['do', 'dump']:
927                if op_mode not in op:
928                    continue
929                for when in ['pre', 'post']:
930                    if when not in op[op_mode]:
931                        continue
932                    name = op[op_mode][when]
933                    if name in self.hooks[when][op_mode]['set']:
934                        continue
935                    self.hooks[when][op_mode]['set'].add(name)
936                    self.hooks[when][op_mode]['list'].append(name)
937
938
class RenderInfo:
    """Everything the render helpers need to emit code for one op/mode pair.

    When op is falsy the info describes a bare attribute set instead:
    type_name is derived from attr_set and no request/reply structs are built.

    Attributes:
      family, ku_space, op, op_name, op_mode, cw -- as passed in.
      nl -- the CodeWriter's nlib helper.
      attr_set -- explicit attr_set, else op['attribute-set'].
      type_name -- C-ified op name (or attr_set name when op is falsy).
      struct -- maps 'request'/'reply' to the Struct rendered for that direction.
      type_consistent -- True when a dump's reply equals the do reply (so the
        same response type is reused), and always True for events.
    """
    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical.  Both sides must
        # actually have a reply; a dump without one used to raise KeyError.
        if op_mode != 'do' and 'dump' in op and 'do' in op and \
           'reply' in op['do'] and 'reply' in op['dump'] and \
           op["do"]["reply"] == op["dump"]["reply"]:
            self.type_consistent = True
        else:
            self.type_consistent = op_mode == 'event'

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op_name)
        else:
            self.type_name = c_lower(attr_set)

        self.cw = cw

        self.struct = dict()
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[op_mode]:
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             type_list=op[op_mode][op_dir]['attributes'])
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
973
974
class CodeWriter:
    """Write C source text with kernel-style tab indentation.

    nlib is only stored (RenderInfo reads it as cw.nlib).  Lines are written
    to out_file by p().  block_start()/block_end() manage the indent level.
    nl() requests a blank line that is only written when the next line is
    emitted, so a pending blank line with nothing after it is dropped.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False
        self._silent_block = False
        self._ind = 0
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Emit one line at the current indent level (+ add_ind).

        Lines ending in ':' (labels such as 'err_free:', 'case X:',
        'default:') are outdented one level.  The line following a brace-less
        'if (...)' / 'while (...)' / 'for (...)' is indented one extra level.
        An empty line is allowed.
        """
        if self._nl:
            self._out.write('\n')
            self._nl = False
        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit '<line> {' and indent the following lines."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Outdent and emit '}' followed by line ('};', '} while (...);', ...)."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit doc as ' * '-prefixed comment lines wrapped before column 79.

        Continuation lines get two extra spaces when indent is True.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, optionally preceded by a /* doc */ comment.

        No args means '(void)'.  If the one-line form would be 80 columns or
        more, arguments are wrapped and aligned under the opening parenthesis
        (tabs, then spaces), and a return type longer than 3 characters is
        put on its own line.  suffix is appended after ')' (e.g. ';').
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v.startswith('\t'):
                # Tabs count as one char in len() but render 8 wide
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        local_vars may be a single string or a list of strings.  The
        caller's list is left unmodified (it used to be sorted in place).
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function: prototype, locals, then body lines in braces."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit '#define NAME<tabs>VALUE' lines with values aligned on a tab stop.

        defines is a sequence of (name, value) pairs.  str values are quoted,
        int values are printed as-is.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit '.name<tabs>= value,' designated initializers, values aligned.

        members is a sequence of (name, value-string) pairs; empty emits nothing.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1116
1117
# Attribute types parsed as plain fixed-width integers
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Message direction -> suffix used in generated C type/function names
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Operation mode -> suffix of the wrapper type used for that mode's replies
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords an attribute name may collide with; Type.__init__ appends '_'
# to the C name of any attribute whose name is in this set
_C_KW = {
    'do'
}
1136
1137
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' are swapped; any other value is returned as-is.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1144
1145
def op_prefix(ri, direction, deref=False):
    """Return '<family>_<op>' plus the suffix naming @direction's C type.

    'do' (or no mode) uses the plain direction suffix.  Other modes use
    '_req_dump' for requests.  For replies, a type-consistent op uses the
    mode's wrapper suffix (or the direction suffix when deref is set);
    otherwise '_rsp_list' / '_rsp_dump' (deref).
    """
    tail = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        tail += direction_to_suffix[direction]
    elif direction == 'request':
        tail += '_req_dump'
    elif ri.type_consistent:
        tail += direction_to_suffix[direction] if deref else op_mode_to_wrapper[ri.op_mode]
    else:
        tail += '_rsp_dump' if deref else '_rsp_list'

    return ri.family['name'] + tail
1165
1166
def type_name(ri, direction, deref=False):
    """Return the C type ('struct <prefix>') for @direction of @ri."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1169
1170
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space call for @ri's op/mode.

    The call takes the socket plus, if the mode has a request, a pointer to
    the request type.  It returns a pointer to the reply type when the mode
    has a reply, else int.  Dumps get a '_dump' name suffix.  terminate
    appends ';'.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname = fname + '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1187
1188
def print_req_prototype(ri):
    """Emit the request call's prototype, preceded by the op's doc comment.

    Ops without a 'doc' key get no comment instead of aborting generation
    with a KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1191
1192
def print_dump_prototype(ri):
    """Emit the dump call's prototype (no doc comment)."""
    print_prototype(ri, direction="request")
1195
1196
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of @struct's ynl_policy_nest object."""
    nest_name = f"{struct.render_name}_nest"
    cw.p('extern struct ynl_policy_nest ' + nest_name + ';')
1199
1200
def put_typol(cw, struct):
    """Emit @struct's ynl_policy_attr table and the ynl_policy_nest wrapping it."""
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for init_line in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(init_line)
    cw.block_end(line=';')
    cw.nl()
1216
1217
def put_req_nested(ri, struct):
    """Emit <struct>_put(), which writes @struct's members inside one nest attribute."""
    params = ['struct nlmsghdr *nlh',
              'unsigned int attr_type',
              f'{struct.ptr_name}obj']
    cw = ri.cw

    cw.write_func_prot('int', f'{struct.render_name}_put', params)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1238
1239
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a response parse function for @struct.

    The caller has already emitted the prototype.  This emits:
      * local declarations (local_vars plus any this function adds) and init_lines;
      * copies of inherited arguments into dst;
      * one loop over the attributes calling each member's attr_get();
      * for array-nest and multi-attr members, a calloc of dst->n_<name>
        entries followed by a second pass filling them;
      * 'return 0;' for nested structs, 'return MNL_CB_OK;' otherwise.

    Both local_vars and init_lines are extended in place.

    NOTE(review): the generated calloc() results are not NULL-checked.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            # Presumably set to the nest attribute by attr_get() in the first pass
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)

    first = True
    for _, arg in struct.member_list():
        attr_get_first = first  # noqa: kept for readability of the call below
        arg.attr_get(ri, 'dst', first=first)
        first = False

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for array-nest members: allocate dst->n_<name> entries, then
    # parse each element of the saved nest, passing its type as the inherited idx
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: allocate, then re-walk all attributes
    # and store each occurrence of this attribute type in order
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # Signed types are read with the unsigned getter of the same width
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1331
1332
def parse_rsp_nested(ri, struct):
    """Emit <struct>_parse() for a nested attribute set.

    The function takes the parse arg, the nest attribute, and one __u32
    argument per inherited member.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend('__u32 ' + name for name in struct.inherited)

    declarations = ['const struct nlattr *attr;',
                    f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params)

    _multi_parse(ri, struct, [], declarations)
1346
1347
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply parse callback for @ri.

    Nothing is emitted when the mode has no reply, unless it is an event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    declarations = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                    'struct ynl_parse_arg *yarg = data;',
                    'const struct nlattr *attr;']
    fname = f'{op_prefix(ri, "reply", deref=deref)}_parse'

    ri.cw.write_func_prot('int', fname, ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], declarations)
1363
1364
def print_req(ri):
    """Emit the user-space 'do' request function for @ri.

    The generated C builds the request message, puts every request member,
    sends it and receives once.  If the op has a reply, it allocates the
    reply object and runs the reply parse callback on the received data.
    It then waits for the ack via ynl_recv_ack().  It returns the reply
    pointer (or 0 when there is no reply); on failure it frees the reply, if
    any, and returns NULL (or -1).

    NOTE(review): the generated calloc() of the reply is not NULL-checked.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct nlmsghdr *nlh;',
                  'int len, err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_parse_arg yarg = { .ys = ys, };']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yarg.data = rsp;')
        ri.cw.nl()
        ri.cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};")
        ri.cw.p('if (err < 0)')
        ri.cw.p('goto err_free;')
        ri.cw.nl()

    # The ack check's 'goto err_free' is emitted unconditionally, so the
    # label below is needed even for ops without a reply
    ri.cw.p('err = ynl_recv_ack(ys, err);')
    ri.cw.p('if (err)')
    ri.cw.p('goto err_free;')
    ri.cw.nl()
    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()
    ri.cw.p('err_free:')

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
    ri.cw.p(f"return {ret_err};")
    ri.cw.block_end()
1422
1423
def print_dump(ri):
    """Emit the user-space dump function for @ri.

    The generated C sets up a ynl_dump_state (reply element size, per-reply
    parse callback, reply policy) and sends the dump request with any
    request members.  It then receives and runs ynl_dump_trampoline until the
    callback result is no longer > 0.  It returns yds.first, the head of the
    reply list; on a receive/parse error it frees that list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int len, err;']

    # Printed directly (not via write_func_lvar) so the declaration order is kept
    for var in local_vars:
        ri.cw.p(f'{var}')
    ri.cw.nl()

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('return NULL;')
    ri.cw.nl()

    ri.cw.block_start(line='do')
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()
    ri.cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.block_end(line='while (err > 0);')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1472
1473
def call_free(ri, direction, var):
    """Return the C statement freeing @var with the generated *_free() for @direction."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1476
1477
def free_arg_name(direction):
    """Return the parameter name used by *_free(): 'rsp'/'req', or 'obj' for nests."""
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1482
1483
def print_free_prototype(ri, direction, suffix=';'):
    """Emit 'void <prefix>_free(struct <prefix> *<arg>)' followed by suffix."""
    struct_name = op_prefix(ri, direction)
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', struct_name + '_free', [param], suffix=suffix)
1488
1489
def _print_type(ri, direction, struct):
    """Emit the C struct definition for @struct.

    Presence-tracking declarations returned by presence_member() ('len' and
    'bit' kinds) are grouped in a 'struct { ... } _present;' member first.
    Inherited values follow as __u32 fields, then the regular members.
    """
    type_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.op_mode == 'dump':
        type_suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family['name']}{type_suffix}")

    presence_lines = []
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            decl = member.presence_member(ri.ku_space, kind)
            if decl:
                presence_lines.append(decl)
    if presence_lines:
        cw.block_start(line="struct")
        for decl in presence_lines:
            cw.p(decl)
        cw.block_end(line='_present;')
        cw.nl()

    for name in struct.inherited:
        cw.p(f"__u32 {name};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1519
1520
def print_type(ri, direction):
    """Emit the struct for @direction, taken from ri.struct[direction]."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1523
1524
def print_type_full(ri, struct):
    """Emit @struct with no direction suffix (used for nested sets)."""
    _print_type(ri, direction="", struct=struct)
1527
1528
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1536
1537
def print_req_type_helpers(ri):
    """Emit the request type's free() prototype and setters."""
    print_type_helpers(ri, direction="request")
1540
1541
def print_rsp_type_helpers(ri):
    """Emit the reply type's helpers, if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1546
1547
def print_parse_prototype(ri, direction, terminate=True):
    """Emit 'void <op>_{rsp,req}_parse(const struct nlattr **tb, struct ... *req)'."""
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
1556
1557
def print_req_type(ri):
    """Emit the request struct for @ri."""
    print_type(ri, direction="request")
1560
1561
def print_rsp_type(ri):
    """Emit the reply struct for events, and for do/dump modes that have a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
1570
1571
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, plus its free() prototype.

    Dumps get a 'next' list pointer.  Notifications and events get family,
    cmd and a free callback.  All wrap the reply in an 8-byte aligned 'obj'.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1585
1586
def _free_type_members_iter(ri, struct):
    """Declare the loop counter 'i' if any member's free code needs to iterate."""
    if any(attr.free_needs_iter() for _, attr in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1593
1594
def _free_type_members(ri, var, struct, ref=''):
    """Emit the free code for each member of @struct, accessed as var-><ref>member."""
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
1598
1599
def _free_type(ri, direction, struct):
    """Emit a *_free() function releasing @struct's members.

    For a request/reply direction the object itself is freed as well.  For
    nested structs (empty direction) only the members are freed; presumably
    the object's own memory is owned by its parent.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1611
1612
def free_rsp_nested(ri, struct):
    """Emit the member-only free function for a nested struct."""
    _free_type(ri, direction="", struct=struct)
1615
1616
def print_rsp_free(ri):
    """Emit the reply free function, if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1621
1622
def print_dump_type_free(ri):
    """Emit the free function for a dump's reply list: walk ->next, freeing each node."""
    node_type = type_name(ri, 'reply')
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while (next)')
    _free_type_members_iter(ri, ri.struct['reply'])
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1641
1642
def print_ntf_type_free(ri):
    """Emit the free function for a notification object: members, then the object."""
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply_struct = ri.struct['reply']
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1651
1652
def print_ntf_parse_prototype(family, cw, suffix=';'):
    """Emit the prototype of <family>_ntf_parse(struct ynl_sock *ys)."""
    fname = f"{family['name']}_ntf_parse"
    cw.write_func_prot('struct ynl_ntf_base_type *', fname,
                       ['struct ynl_sock *ys'], suffix=suffix)
1656
1657
def print_ntf_type_parse(family, cw, ku_mode):
    """Emit <family>_ntf_parse(), which receives and parses one notification.

    The generated C receives a message and rejects it if it is too short.
    It switches on genlh->cmd (one case per notification command and per
    event op) to pick the object size, parse callback, reply policy and free
    callback, then parses into the new object.  It returns the object, or
    NULL on receive/parse failure or an unknown command.
    """
    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))')
    cw.p('return NULL;')
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()
    cw.block_start(line='switch (genlh->cmd)')
    for ntf_op in sorted(family.all_notify.keys()):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
    cw.p('yarg.data = rsp->data;')
    cw.nl()
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p('rsp->family = nlh->nlmsg_type;')
    cw.p('rsp->cmd = genlh->cmd;')
    cw.p('return rsp;')
    cw.nl()
    cw.p('err_free:')
    # Release through the type's own free callback (set in the switch above)
    # so members allocated by a partially successful parse are freed too;
    # a plain free(rsp) leaked them.
    cw.p('rsp->free(rsp);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
1715
1716
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit a kernel nla_policy array declaration (terminate) or definition opener.

    The declaration is skipped for per-op policies of families whose
    family struct the kernel side generates itself; those policies are then
    defined 'static'.  Per-op names get the op mode appended for dual policies.
    """
    if terminate:
        if ri and kernel_can_gen_family_struct(struct.family):
            return
        prefix = 'extern '
        suffix = ';'
    else:
        make_static = kernel_can_gen_family_struct(struct.family) and ri
        prefix = 'static ' if make_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1739
1740
def print_req_policy(cw, struct, ri=None):
    """Emit a full kernel nla_policy array definition for @struct."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1746
1747
def kernel_can_gen_family_struct(family):
    """Return True if the kernel family struct is generated (genetlink families only)."""
    is_genetlink = family.proto == 'genetlink'
    return is_genetlink
1750
1751
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the kernel ops-table declaration, or the opening of its definition.

    With terminate=False, emit '<qual> struct <type> <family>_nl_ops[N] ='
    and open the initializer block.  With terminate=True, emit the extern
    declaration (only when the family struct is not generated), followed by
    prototypes for all pre/post hooks and for every non-async op's
    doit/dumpit handler.

    N counts only non-async ops, matching the entries print_kernel_op_table
    emits (it skips async ops).  Previously async (notification/event) ops
    were counted too, which left zero-filled entries at the end of the array.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if family.kernel_policy == 'split':
            # Split ops have one entry per do/dump handler
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1815
1816
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops-table declarations and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, True)
1819
1820
def _kernel_op_validate_member(op):
    """Return the ('validate', flags) struct member for ``op``, or None."""
    if 'dont-validate' not in op:
        return None
    flags = ' | '.join(c_upper('genl-dont-validate-' + x)
                       for x in op['dont-validate'])
    return ('validate', flags)


def print_kernel_op_table(family, cw):
    """Emit the ``<family>_nl_ops`` table definition (source side).

    The array header comes from print_kernel_op_table_fwd(). With the
    'global' or 'per-op' policy, one initializer is written per non-async
    op. With the 'split' policy, one initializer is written per non-async
    op *mode* (separate do and dump entries). The array is closed at the
    end.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            validate = _kernel_op_validate_member(op)
            if validate:
                members.append(validate)
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # genl_ops has a single policy per op. Take the request
                # attributes from 'do' when it has a request, otherwise
                # from 'dump'. Ops with no request at all get no policy
                # rather than failing with a KeyError.
                req_mode = next((mode for mode in ['do', 'dump']
                                 if mode in op and 'request' in op[mode]), None)
                if req_mode is not None:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[req_mode]['request']['attributes'])

                    name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                validate = _kernel_op_validate_member(op)
                if validate:
                    members.append(validate)
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Split ops always advertise which mode (do/dump) they implement.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
1887
1888
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an enum with one ``<FAMILY>_NLGRP_<NAME>`` entry per multicast group.

    Nothing is written when the family defines no multicast groups.
    """
    groups = family.mcgrps['list']
    if groups:
        cw.block_start('enum')
        for group in groups:
            cw.p(c_upper(f"{family.name}-nlgrp-{group['name']}") + ',')
        cw.block_end(';')
        cw.nl()
1899
1900
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``<family>_nl_mcgrps[]`` genl_multicast_group array.

    Each entry is indexed by its ``<FAMILY>_NLGRP_<NAME>`` enum value and
    carries the group name string. Nothing is written when there are no
    groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
1912
1913
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated ``<family>_nl_family`` object in the header.

    Skipped for families that do not get a generated family struct.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
1920
1921
def print_kernel_family_struct_src(family, cw):
    """Emit the ``<family>_nl_family`` genl_family definition.

    The ops table emitted by print_kernel_op_table() is wired into the
    member matching the kernel policy style:
      - 'global' -> ``.small_ops``, plus the family-wide policy and maxattr
      - 'per-op' -> ``.ops``
      - 'split'  -> ``.split_ops``
    The multicast group array is wired in when groups exist. Nothing is
    written for families that do not get a generated family struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'global':
        # With the global policy the ops table is a genl_small_ops array
        # carrying no per-op policy. The single policy printed by main()
        # ("<render_name>_nl_policy") has to be attached to the family,
        # otherwise neither the ops nor the policy are ever registered.
        struct = Struct(family, family.global_policy_set,
                        type_list=family.global_policy)
        cw.p(f'.small_ops\t= {family.name}_nl_ops,')
        cw.p(f'.n_small_ops\t= ARRAY_SIZE({family.name}_nl_ops),')
        cw.p(f'.policy\t\t= {struct.render_name}_nl_policy,')
        cw.p(f'.maxattr\t= {struct.attr_max_val.enum_name},')
    elif family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.name}_nl_mcgrps),')
    cw.block_end(';')
1942
1943
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, choosing its C tag from the spec object.

    Precedence for the tag:
      1. ``obj[enum_name]`` when the key is present. If its value is
         empty, the enum stays anonymous and ``ckey`` is NOT consulted.
      2. ``<family>_<obj[ckey]>`` when ``ckey`` is given and present.
      3. Otherwise the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        tag = c_lower(explicit) if explicit else None
    elif ckey and ckey in obj:
        tag = family.name + '_' + c_lower(obj[ckey])
    else:
        tag = None
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
1952
1953
def _render_uapi_definitions(family, cw):
    """Emit the spec's 'definitions': #defines for consts, enums for enum/flags."""
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # Flush pending #defines so the output follows spec order.
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entry_list:
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entry_list:
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                max_name = c_upper(name_pfx + 'max')
                cw.p('__' + max_name + ',')
                cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' is a per-constant override, so it is read from
            # the constant itself and not from the family.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()


def _render_uapi_attr_sets(family, cw, max_by_define):
    """Emit one attribute enum (with count and max entries) per attribute set."""
    for attr_set in family.attr_sets.values():
        # Subsets share the enum of the set they are a subset of.
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                # Only spell out values that break the implicit sequence.
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()


def _render_uapi_ops(family, cw, max_by_define):
    """Emit the command enum and, with 'async-prefix', a separate notification enum."""
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if 'value' in op:
            suffix = f" = {op['value']},"
        cw.p(op.enum_name + suffix)
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()


def _render_uapi_mcgrps(family, cw):
    """Emit a #define holding the name string of each multicast group."""
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()


def render_uapi(family, cw):
    """Render the complete uAPI header for ``family`` into ``cw``.

    Output order: header guard, family name/version defines, spec
    definitions, attribute enums, command (and notification) enums,
    multicast group name defines, closing guard.
    """
    # c_upper() also maps '-' to '_' so the guard is a valid macro name.
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    _render_uapi_definitions(family, cw)

    max_by_define = family.get('max-by-define', False)
    _render_uapi_attr_sets(family, cw, max_by_define)
    _render_uapi_ops(family, cw, max_by_define)
    _render_uapi_mcgrps(family, cw)

    cw.p(f'#endif /* {hdr_prot} */')
2089
2090
def find_kernel_root(full_path):
    """Locate the kernel tree containing ``full_path``.

    Walks up from ``full_path`` looking for a directory that holds a
    ``MAINTAINERS`` file.

    Returns:
        ``(root, sub_path)``. ``root`` is the directory containing
        MAINTAINERS (``''`` for the current directory when ``full_path``
        is relative). ``sub_path`` is ``full_path`` relative to ``root``.
        If no ancestor contains MAINTAINERS, returns
        ``(None, full_path)``. The previous version looped forever in
        that case.
    """
    orig_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() is a fixed point at '/' (absolute paths) and at ''
            # (relative paths): there is nothing left to search.
            return None, orig_path
        full_path = parent
2099
2100
def _render_kernel(args, parsed, cw):
    """Render the kernel-mode header or source (policies, ops table, family)."""
    if args.header:
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                cw.p('/* Common nested types */')
                break
        for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy_fwd(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy_fwd(cw, struct)
            cw.nl()

        if parsed.kernel_policy in {'per-op', 'split'}:
            for op_name, op in parsed.ops.items():
                if 'do' in op and 'event' not in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                    print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                    cw.nl()

        print_kernel_op_table_hdr(parsed, cw)
        print_kernel_mcgrp_hdr(parsed, cw)
        print_kernel_family_struct_hdr(parsed, cw)
    else:
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                cw.p('/* Common nested types */')
                break
        for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy(cw, struct)
            cw.nl()

        for op_name, op in parsed.ops.items():
            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_mode in ['do', 'dump']:
                    if op_mode in op and 'request' in op[op_mode]:
                        cw.p(f"/* {op.enum_name} - {op_mode} */")
                        ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode)
                        print_req_policy(cw, ri.struct['request'], ri=ri)
                        cw.nl()

        print_kernel_op_table(parsed, cw)
        print_kernel_mcgrp_src(parsed, cw)
        print_kernel_family_struct_src(parsed, cw)


def _render_user(args, parsed, cw):
    """Render the user-space (libmnl based) header or source."""
    has_ntf = False
    if args.header:
        cw.p('/* Common nested types */')
        for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
            ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
            print_type_full(ri, struct)

        for op_name, op in parsed.ops.items():
            cw.p(f"/* ============== {op.enum_name} ============== */")

            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                print_req_type(ri)
                print_req_type_helpers(ri)
                cw.nl()
                print_rsp_type(ri)
                print_rsp_type_helpers(ri)
                cw.nl()
                print_req_prototype(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump')
                if 'request' in op['dump']:
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                if not ri.type_consistent:
                    print_rsp_type(ri)
                print_wrapped_type(ri)
                print_dump_prototype(ri)
                cw.nl()

            if 'notify' in op:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                has_ntf = True
                if not ri.type_consistent:
                    raise Exception('Only notifications with consistent types supported')
                print_wrapped_type(ri)

            if 'event' in op:
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event')
                cw.p(f"/* {op.enum_name} - event */")
                print_rsp_type(ri)
                cw.nl()
                print_wrapped_type(ri)

        if has_ntf:
            cw.p('/* --------------- Common notification parsing --------------- */')
            print_ntf_parse_prototype(parsed, cw)
        cw.nl()
    else:
        cw.p('/* Policies */')
        for name, _ in parsed.attr_sets.items():
            struct = Struct(parsed, name)
            put_typol_fwd(cw, struct)
        cw.nl()

        for name, _ in parsed.attr_sets.items():
            struct = Struct(parsed, name)
            put_typol(cw, struct)

        cw.p('/* Common nested types */')
        for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
            ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)

            free_rsp_nested(ri, struct)
            if struct.request:
                put_req_nested(ri, struct)
            if struct.reply:
                parse_rsp_nested(ri, struct)

        for op_name, op in parsed.ops.items():
            cw.p(f"/* ============== {op.enum_name} ============== */")
            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                print_rsp_free(ri)
                parse_rsp_msg(ri)
                print_req(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump")
                if not ri.type_consistent:
                    parse_rsp_msg(ri, deref=True)
                print_dump_type_free(ri)
                print_dump(ri)
                cw.nl()

            if 'notify' in op:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                has_ntf = True
                if not ri.type_consistent:
                    raise Exception('Only notifications with consistent types supported')
                print_ntf_type_free(ri)

            if 'event' in op:
                cw.p(f"/* {op.enum_name} - event */")
                has_ntf = True

                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                parse_rsp_msg(ri)

                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event")
                print_ntf_type_free(ri)

        if has_ntf:
            cw.p('/* --------------- Common notification parsing --------------- */')
            print_ntf_type_parse(parsed, cw, args.mode)


def _render_output(args, parsed, out_file):
    """Write the file selected by ``args`` (mode, header/source) to ``out_file``."""
    cw = CodeWriter(BaseNlLib(), out_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi':
        cw.p('/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */')
    else:
        if args.header:
            cw.p('/* SPDX-License-Identifier: BSD-3-Clause */')
        else:
            cw.p('// SPDX-License-Identifier: BSD-3-Clause')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # c_upper() also maps '-' to '_' so the guard is a valid macro name.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The generated source includes its sibling header
                # (same base name, '.h' extension).
                base = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{base}.h"')
            cw.nl()
    headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <stdlib.h>")
            cw.p("#include <stdio.h>")
            cw.p("#include <string.h>")
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
        cw.nl()

    if args.mode == "kernel":
        _render_kernel(args, parsed, cw)

    if args.mode == "user":
        _render_user(args, parsed, cw)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def main():
    """Command-line entry point: parse arguments, load the spec, render output."""
    import sys  # the documented module, rather than the os.sys alias

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # Validate arguments and load the spec before opening the output file,
    # so a usage or YAML error does not truncate an existing output file.
    if args.header is None:
        parser.error("--header or --source is required")

    try:
        parsed = Family(args.spec)
    except yaml.YAMLError as exc:
        # stdout may be the generated output, so report on stderr.
        print(exc, file=sys.stderr)
        sys.exit(1)

    if not args.out_file:
        _render_output(args, parsed, sys.stdout)
        return

    # Close (and flush) the output file deterministically, even on error.
    with open(args.out_file, 'w') as out_file:
        _render_output(args, parsed, out_file)
2354
2355
if __name__ == '__main__':
    # Generate only when executed as a script, not when imported.
    main()
2358