xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision 13a9d0be)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
3
4import argparse
5import collections
6import os
7import yaml
8
9from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
10
11
def c_upper(name):
    """Turn a spec name like 'foo-bar' into an upper-case C token 'FOO_BAR'."""
    return name.replace('-', '_').upper()
14
15
def c_lower(name):
    """Turn a spec name like 'Foo-Bar' into a lower-case C identifier 'foo_bar'."""
    return name.replace('-', '_').lower()
18
19
class BaseNlLib:
    """Emits the libmnl-specific C expressions used by the generated code."""

    def get_family_id(self):
        """Return the C expression that holds the genetlink family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a ``mnl_cb_run2()`` call expression.

        Args:
            cb: C name of the per-message callback.
            data: C expression passed as the callback's data pointer.
            is_dump: dumps pass 0 for seq/portid instead of the socket's values.
            indent: extra tab level used when wrapping the argument list.
        """
        sep = '\n\t\t' + '\t' * indent + ' '
        tail = f"{sep}ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data}," + tail
        head = f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}"
        return head + f"{cb}, {data}," + tail
31
32
class Type(SpecAttr):
    """Base codegen representation of one attribute of an attribute set.

    Subclasses (one per spec attribute type, picked by AttrSet.new_attr())
    override the hooks below to emit the C struct member, the kernel policy
    entry, the user-space type-policy ("typol") entry, the put/parse code and
    the request setter for that attribute.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Raw spec dict of the attribute and the attribute set it lives in.
        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # Prefix of the generated C struct/helpers for the nested set; a set
            # named exactly like the family is rendered without a suffix.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        # C member name; names found in _C_KW (C keywords) get a trailing '_'.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # Declared and immediately deleted so that any use before resolve()
        # raises AttributeError instead of silently reading None.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the C enum constant name (set name prefix + attr name)."""
        self.enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(self.enum_name)

    def is_multi_val(self):
        """Return truthy if the attribute is stored as an array with a count."""
        return None

    def is_scalar(self):
        """Return True for fixed-width integer attribute types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked in ``_present``: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration for this attribute.

        Returns None when the attribute's presence type differs from
        ``type_filter``, or is neither 'bit' nor 'len'. For ``space == 'user'``
        the member is declared as ``__u32``, otherwise as ``u32``.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """Return the C type of a struct-typed member, or None for the default."""
        return None

    def free_needs_iter(self):
        """Return True if the emitted free code needs a loop counter ``i``."""
        return False

    def free(self, ri, var, ref):
        """Emit code releasing memory owned by this member of ``var``.

        Multi-value arrays and 'len'-tracked buffers are heap pointers and get
        a plain ``free()``; everything else owns nothing.
        """
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) used for this attr in setters."""
        member = self._complex_member_type(ri)
        if member:
            return [member + ' *' + self.c_name]
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the member declaration(s) of this attribute in a C struct."""
        # Arrays carry their element count in a separate n_<name> member.
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the policy initializer for ``policy``; subclasses add checks."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel ``nla_policy`` table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the body of the user-space type-policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user-space type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` guarded by the attribute's presence check (if any)."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` call for this member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code appending this attribute to ``nlh``; per-type override."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` of parse code.

        ``lines`` and ``init_lines`` may each be a single string, a list of
        strings or None; ``local_vars`` is a list or None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute in the attribute loop.

        ``first`` selects ``if`` for the first branch and ``else if`` after it.
        Single-value attributes are validated with ``ynl_attr_validate()`` and,
        for bit presence, get their presence bit set before the per-type lines.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Normalise single strings into one-element lists.
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()

    def _setter_lines(self, ri, member, presence):
        """Return C statements storing the setter argument into ``member``."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute of a request.

        ``ref`` is the path of enclosing nested members (outermost first). The
        function is named ``<op prefix>_set_<path>`` and, for bit presence,
        marks the presence bit of every level of the path.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # e.g. ref ['a', 'b'] -> "req->_present.a", then "req->a._present.b";
        # after the loop `presence` refers to this attribute itself.
        for i in range(0, len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        ri.cw.write_func('static inline void',
                         f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}",
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
194
195
class TypeUnused(Type):
    """Reserved attribute slot: rejected by user space, absent from policies."""

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # Never tracked in the _present struct.
        return ''

    def attr_policy(self, cw):
        # Deliberately emits nothing into the kernel policy table.
        return None
205
206
class TypePad(Type):
    """Padding attribute: rejected by user space, absent from policies."""

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # Padding carries no data, so it has no presence tracking.
        return ''

    def attr_policy(self, cw):
        # Deliberately emits nothing into the kernel policy table.
        return None
216
217
class TypeScalar(Type):
    """Fixed-width integer attribute (u8..u64, s32, s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Byte order is only documented in the generated member declaration.
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Added by resolve():
        # Declared then deleted so early access raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a flags bitfield and pick its C type."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr
        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            bitfield = True
        elif has_enum:
            bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            bitfield = False
        self.is_bitfield = bitfield

        # Plain enums get their enum type; flags and raw ints use __<type>.
        if has_enum and not bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the libmnl accessor suffix (signed types use the unsigned one)."""
        if self.type[0] == 's':
            return 'u' + self.type[1:]
        return self.type

    def _attr_policy(self, policy):
        """Build the kernel policy, adding mask/min/max checks where specified."""
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask()
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            n_flags = len(self.family.consts[self.checks['flags-mask']]['entries'])
            mask = (1 << n_flags) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            n_entries = len(self.family.consts[self.attr['enum']]['entries'])
            return f"NLA_POLICY_MAX({policy}, {n_entries - 1})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        getter = f"mnl_attr_get_{self._mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
285
286
class TypeFlag(Type):
    """Payload-less attribute whose presence alone carries the information."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value argument.
        return []

    def attr_put(self, ri, var):
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy out; only the presence bit is recorded.
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
302
303
class TypeString(Type):
    """String attribute stored as a heap-allocated, NUL-terminated ``char *``.

    Presence is tracked by length (``_present.<name>_len``), not by a bit.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        # An optional max-len check becomes the policy's .len field.
        len_part = ''
        if 'max-len' in self.checks:
            len_part = ', .len = ' + str(self.checks['max-len'])
        return '{ .type = ' + policy + len_part + ', }'

    def attr_policy(self, cw):
        # Require NUL termination unless the spec explicitly allows otherwise.
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        # Length is bounded by the payload, so unterminated input is safe.
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        plen = f"{presence}_len"
        return [f"free({member});",
                f"{plen} = strlen({self.c_name});",
                f"{member} = malloc({plen} + 1);",
                f"memcpy({member}, {self.c_name}, {plen});",
                f"{member}[{plen}] = 0;"]
351
352
class TypeBinary(Type):
    """Opaque byte-blob attribute stored as a heap-allocated ``void *``.

    Presence is tracked by length: ``_present.<name>_len`` holds the number
    of valid bytes and is also what attr_put() emits as the attribute length.
    """

    def arg_member(self, ri):
        # Setters receive the buffer and its size as two arguments.
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Kernel policy: plain NLA_BINARY, or a minimum length via .len.

        Raises:
            Exception: for any check other than a lone 'min-len'.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Store the caller-supplied length (the `size_t len` argument from
        # arg_member()) before using it. The allocation, the copy and
        # attr_put() all size themselves from <presence>_len, and the base
        # setter never assigns it for 'len' presence.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
393
394
class TypeNest(Type):
    """Attribute holding a nested attribute set, stored as an embedded struct."""

    def _complex_member_type(self, ri):
        return 'struct ' + self.nested_render_name

    def free(self, ri, var, ref):
        # The nested struct is embedded, so only its contents need freeing.
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        put = (f"{self.nested_render_name}_put(nlh, "
               f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"{self.nested_render_name}_parse(&parg, attr);"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit one setter per member of the nested set, prefixed by our path."""
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _name, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
423
424
class TypeMultiAttr(Type):
    """Attribute that may repeat; all occurrences are collected into an array.

    The struct member is a pointer to the elements plus an ``n_<name>`` count
    (see Type.struct_member(), since is_multi_val() is True).
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type: the nested struct or a scalar type.

        Raises:
            Exception: for sub-types other than nest or scalars.
        """
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f"struct {self.nested_render_name}"
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        """Emit code releasing the element array and whatever each element owns."""
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            # Nested elements own memory of their own; release it per element.
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
        # The array itself is a heap pointer (the base Type.free() frees
        # multi-value members). Without this it leaked for every sub-type.
        ri.cw.p(f"free({var}->{ref}{self.c_name});")

    def _attr_typol(self):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        elif self.attr['type'] in scalars:
            return f".type = YNL_PT_U{self.attr['type'][1:]}, "
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def _attr_get(self, ri, var):
        # The first pass only counts occurrences.
        return f'{var}->n_{self.c_name}++;', None, None
459
460
class TypeArrayNest(Type):
    """Nest whose sub-attributes form an array of nested structs or scalars."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type for the array.

        Raises:
            Exception: for sub-types other than nest or scalars.
        """
        # A missing 'sub-type' means the elements are nests.
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return f"struct {self.nested_render_name}"
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the outer attribute and count its nested entries.
        count_line = f'\t{var}->n_{self.c_name}++;'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     count_line]
        return get_lines, None, ['const struct nlattr *attr2;']
486
487
class TypeNestTypeValue(Type):
    """Nest whose attribute *types* carry values (listed under 'type-value')."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Walk one nest level per 'type-value' entry, then parse the innermost.

        Each level's attribute type is captured into a ``__u32`` local named
        after the entry, and those locals are passed to the nested parser.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(names) + ';')
            local_vars.append('__u32 ' + ', '.join(names) + ';')
            for name in names:
                get_lines.append(f'attr_{name} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{name} = mnl_attr_get_type(attr_{name});')
                cursor = f'attr_{name}'
            extra_args = ', ' + ', '.join(names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
516
517
class Struct:
    """A C struct rendered for (a subset of) one attribute set.

    Attributes:
        attr_list: ordered ``(name, attr)`` pairs of the struct members.
        attrs: the same pairs keyed by name.
        attr_max_val: member with the highest attribute value (last on ties).
        request / reply: whether the struct is used in requests / replies.
        inherited: c_lower'ed, sorted names of members inherited from parents.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        # Without an explicit member list the struct mirrors the whole set.
        self.nested = type_list is None
        if family.name == c_lower(space_name):
            render_name = f"{family.name}"
        else:
            render_name = f"{family.name}_{c_lower(space_name)}"
        self.render_name = render_name
        self.struct_name = 'struct ' + render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        member_names = type_list if type_list else self.attr_set
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # '>=' makes the last member wins ties for the highest value.
        highest = 0
        self.attr_max_val = None
        for _name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ordered ``(name, attr)`` member pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Freeze the inherited member names.

        Raises:
            Exception: if ``new_inherited`` differs from the set given at init.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
568
569
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry with the C naming used by the generator."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value does not simply follow the previous entry (or
        # start at 0), and always for flags. Presumably used to decide whether
        # to spell out "= value" in the rendered C enum.
        implicit_value = prev.value + 1 if prev else 0
        breaks_sequence = self.value != implicit_value
        self.value_change = breaks_sequence or self.enum_set['type'] == 'flags'

        # Added by resolve:
        # Declared then deleted so early access raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the upper-case C constant name (set prefix + entry name)."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
588
589
class EnumSet(SpecEnumSet):
    """Enum or flags definition with its rendered C names."""

    def __init__(self, family, yaml):
        set_name = yaml['name']
        # Names are set before the base constructor creates the entries.
        self.render_name = c_lower(family.name + '-' + set_name)
        self.enum_name = f"enum {self.render_name}"
        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{set_name}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Factory hook: build the codegen-specific entry type."""
        return EnumEntry(self, entry, prev_entry, value_start)
601
602
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming (enum prefix, max constant) it renders to."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the enum naming of the set they are carved from.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))

        # Added by resolve:
        # Declared then deleted so early access raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C name; a set named like the family gets an empty name."""
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        self.c_name = c_name
        if c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Factory hook: pick the Type subclass matching the attribute spec.

        Raises:
            Exception: for an attribute type with no codegen class.
        """
        if 'multi-attr' in elem and elem['multi-attr']:
            cls = TypeMultiAttr
        elif elem['type'] in scalars:
            cls = TypeScalar
        else:
            by_type = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }
            cls = by_type.get(elem['type'])
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
654
655
class Operation(SpecOperation):
    """Operation (command) with its rendered C names.

    Raises:
        Exception: from the constructor when request and reply values differ.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = family.name + '_' + c_lower(self.name)

        # Separate policies are needed when both do and dump take a request.
        do_has_req = 'do' in yaml and 'request' in yaml['do']
        dump_has_req = 'dump' in yaml and 'request' in yaml['dump']
        self.dual_policy = do_has_req and dump_has_req

        # Added by resolve:
        # Declared then deleted so early access raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the C command constant, using the async prefix for events."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Record ``op`` as a notification whose payload mirrors our do-reply."""
        if 'notify' not in self.yaml:
            notify = self.yaml['notify'] = dict()
            notify['reply'] = self.yaml['do']['reply']
            notify['cmds'] = []
        self.yaml['notify']['cmds'].append(op)
686
687
class Family(SpecFamily):
    """Codegen view of a netlink family spec.

    Extends SpecFamily with C naming (command/attribute prefixes, uapi
    header) and the derived tables the renderers consume: root and pure
    nested attribute-set structs, notifications, hooks and, optionally, a
    global policy.
    """

    def __init__(self, file_name):
        # Added by resolve:
        # Declared and deleted so that any use before resolve() raises
        # AttributeError instead of silently reading None.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"

    def resolve(self):
        """Derive C names and all codegen tables from the parsed spec.

        Raises:
            Exception: if the protocol is not a genetlink flavour.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct (with .request / .reply usage flags)
        self.pure_nested_structs = dict()
        self.all_notify = dict()

        self._mock_up_events()

        self._dictify()
        self._load_root_sets()
        self._load_nested_sets()
        self._load_all_notify()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _dictify(self):
        """Attach every message with a 'notify' key to the op it notifies for."""
        ntf = []
        for msg in self.msgs.values():
            if 'notify' in msg:
                ntf.append(msg)
        for n in ntf:
            self.ops[n['notify']].add_notification(n)

    def _load_root_sets(self):
        """Collect, per attribute set used by ops, the request/reply attrs."""
        for op_name, op in self.ops.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _load_nested_sets(self):
        """Create one Struct per attr set nested directly inside a root set.

        A nested set referenced from several attributes is created once. Its
        request/reply usage flags accumulate across all references, and every
        reference must agree on the inherited members (checked by
        Struct.set_inherited()).

        Raises:
            Exception: if a root attribute set is also used as a nested set.
        """
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                # Previously this fell through to a KeyError on
                # pure_nested_structs, which never contains root sets.
                if nested in self.root_sets:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                inherit = set()
                # Reuse the existing struct: recreating it discarded the
                # request/reply flags recorded by earlier references.
                if nested not in self.pure_nested_structs:
                    self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                struct = self.pure_nested_structs[nested]
                if attr in rs_members['request']:
                    struct.request = True
                if attr in rs_members['reply']:
                    struct.reply = True

                if 'type-value' in spec:
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                struct.set_inherited(inherit)

    def _load_all_notify(self):
        """Map each op name to the notification commands it emits."""
        for op_name, op in self.ops.items():
            if not op:
                continue

            if 'notify' in op:
                self.all_notify[op_name] = op['notify']['cmds']

    def _load_global_policy(self):
        """Build the single policy shared by all ops (kernel-policy: global).

        Raises:
            Exception: if the ops do not all use the same attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    global_set.update(op[op_mode].get('request', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the attribute-set order rather than set iteration order.
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
873
874
class RenderInfo:
    """Everything the render helpers need to emit code for one op in one mode.

    Attributes:
        family: the family being generated.
        nl: the netlink library adapter, taken from ``cw.nlib``.
        ku_space: which side is being rendered (helpers check for 'user').
        op, op_name, op_mode: the operation, its name and the mode being
            rendered ('do', 'dump', 'notify' or 'event').
        type_consistent: True for a non-'do' mode of an op whose 'do' and
            'dump' replies are identical, and for events. False otherwise,
            including for 'do' itself.
        attr_set: attribute set name; defaults to the op's 'attribute-set'.
        type_name: C-style name used to build type identifiers (from the op
            name, or from attr_set when there is no op).
        cw: the CodeWriter that receives the output.
        struct: maps 'request' / 'reply' to Struct objects for this mode.
    """

    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical.
        # A 'dump' without a 'reply' section compares unequal rather than
        # raising KeyError.
        if op_mode != 'do' and 'dump' in op and 'do' in op and 'reply' in op['do'] and \
           'reply' in op['dump'] and op["do"]["reply"] == op["dump"]["reply"]:
            self.type_consistent = True
        else:
            self.type_consistent = op_mode == 'event'

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op_name)
        else:
            self.type_name = c_lower(attr_set)

        self.cw = cw

        self.struct = dict()
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[op_mode]:
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             type_list=op[op_mode][op_dir]['attributes'])
        # Events carry their attribute list directly under 'event'.
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
909
910
class CodeWriter:
    """Emitter for tab-indented C source.

    Lines go to ``out_file`` one at a time through p(). The writer:
    - tracks the current indent level (raised by block_start(), lowered by
      block_end());
    - delays a requested blank line until the next line is written;
    - indents the single statement that follows a brace-less
      ``if``/``while``/``for (...)`` line.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # a blank line is pending (see nl())
        self._silent_block = False  # previous line was a brace-less conditional
        self._ind = 0               # current indent level, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if *line* starts with 'if', 'while' or 'for'."""
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write *line* at the current indentation.

        A pending blank line is flushed first. Adjustments to the indent:
        - lines ending in ':' (labels, ``case``/``default``) are outdented
          one level;
        - the line right after a brace-less conditional is indented one
          extra level;
        - *add_ind* adds further levels.
        """
        if self._nl:
            self._out.write('\n')
            self._nl = False
        ind = self._ind
        # endswith() rather than line[-1] so an empty line does not raise.
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line (repeats collapse)."""
        self._nl = True

    def block_start(self, line=''):
        """Write ``<line> {`` (or just ``{``) and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Outdent one level and write ``}`` followed by *line*.

        *line* is separated by a space unless it starts with ';' or ','.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines, wrapped before column 79.

        Continuation lines get two extra spaces when *indent* is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype ``<qual_ret> <name>(<args>)<suffix>``.

        No *args* renders as ``(void)``. A *doc* string is written as a block
        comment above the prototype. If the one-line form is 80 characters or
        longer:
        - a return type longer than 3 characters goes on its own line;
        - arguments are wrapped and aligned after the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation indent: as many tabs and spaces as the prefix is wide.
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        # Extra rendered width of the tabs in 'ind' (each tab counts as 8).
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single string or a list. The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a whole function: prototype, ``{``, locals, body lines, ``}``."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        # Local declarations belong inside the function body, after '{'.
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``#define NAME<tabs>VALUE`` lines with the values tab-aligned.

        *defines* is a sequence of (name, value) pairs. Ints are written as-is
        and strings are quoted; for any other value type only the name is
        written.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers ``.name<tabs>= value,``, tab-aligned.

        *members* is a sequence of (name, value) string pairs. An empty
        sequence writes nothing.
        """
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1052
1053
# Integer attribute types that map directly onto C scalars.
scalars = {
    'u8', 'u16', 'u32', 'u64',
    's32', 's64',
}

# Name suffix used for each message direction ('' = no direction).
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Name suffix of the reply wrapper type, per operation mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords; attribute C names that collide get a trailing '_'.
_C_KW = {
    'do',
}
1072
1073
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (such as '') is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1080
1081
def op_prefix(ri, direction, deref=False):
    """Build the C identifier ``<family>_<op><tail>`` for ri's op and *direction*.

    The tail depends on the op mode:
    - no mode or 'do': the plain direction suffix ('_req' / '_rsp' / '').
    - request of any other mode: '_req_dump'.
    - reply whose type is not shared with 'do': '_rsp_dump' when *deref*
      (the element type), '_rsp_list' otherwise (the wrapper).
    - reply whose type is shared: the direction suffix when *deref*,
      otherwise the mode's wrapper suffix.
    """
    mode = ri.op_mode
    if not mode or mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1101
1102
def type_name(ri, direction, deref=False):
    """Return the C type ``struct <op_prefix>`` for ri's op in *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1105
1106
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space request function for ri's op.

    The function is named after the op (with '_dump' appended for dumps).
    It takes the socket plus, when the mode has a request, a pointer to the
    request struct. It returns a pointer to the reply struct when the mode
    has a reply, and int otherwise. With *terminate* the prototype ends in ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    func_name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        param_name = direction_to_suffix[direction][1:]
        params.append(f"{type_name(ri, direction)} *{param_name}")

    if 'reply' in mode_spec:
        ret_type = f"{type_name(ri, rdir(direction))} *"
    else:
        ret_type = 'int'

    ri.cw.write_func_prot(ret_type, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1123
1124
def print_req_prototype(ri):
    """Emit the request prototype of ri's op, preceded by the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1127
1128
def print_dump_prototype(ri):
    """Emit the dump request prototype of ri's op (no doc comment)."""
    print_prototype(ri, direction="request")
1131
1132
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of the struct's ``ynl_policy_nest`` object."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern struct ynl_policy_nest {nest_name};')
1135
1136
def put_typol(cw, struct):
    """Emit the user-space attribute policy table for *struct* and its nest.

    Writes ``<name>_policy[MAX + 1]``, with one entry per member produced by
    the member itself, then ``<name>_nest`` pointing at that table.
    """
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for field in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1152
1153
def put_req_nested(ri, struct):
    """Emit ``<render_name>_put()``, which writes *struct* as a nested attribute.

    The generated function opens a nest of type ``attr_type`` and lets each
    member write its own attribute from ``obj``. It then closes the nest and
    returns 0.
    """
    cw = ri.cw
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    cw.write_func_prot('int', f'{struct.render_name}_put', params)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1174
1175
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for *struct*.

    The caller has already written the prototype. This emits, in order:
    - the opening brace, local variables and init lines;
    - a loop over the attributes in which each member emits its own code;
    - a second pass for 'array-nest' and 'multi-attr' members: an array of
      n_<member> entries is calloc()ed and filled by walking the attributes
      again.

    Args:
        ri: RenderInfo for the op being rendered.
        struct: Struct being parsed. struct.nested selects iteration over a
            nest rather than over a top-level message.
        init_lines: C statements emitted right after the local variables.
            Appended to in place when a parse argument is needed.
        local_vars: C local variable declarations. Appended to in place.
    """
    # Nested structs walk the attributes inside 'nested'; top-level messages
    # walk the attributes that follow the genetlink header.
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    # Classify members: array-nest and multi-attr members need the second
    # pass below, and any nested member needs a parse argument ('parg').
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited values arrive as extra function arguments (see
    # parse_rsp_nested) and are copied into the destination struct.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)

    # Main loop: each member emits its own handling code. 'first' tells the
    # member whether it is emitted first (presumably so it can open an
    # if / else-if chain -- TODO confirm in Type.attr_get).
    first = True
    for _, arg in struct.member_list():
        arg.attr_get(ri, 'dst', first=first)
        first = False

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for array-nest members. For each one:
    # - allocate n_<member> entries;
    # - parse every entry of the nest, passing the entry's attribute type as
    #   the extra argument.
    # sorted() makes the emission order deterministic.
    # NOTE(review): the generated code does not check calloc() for NULL.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members. For each one:
    # - allocate the array;
    # - walk the attributes again and store every instance of this member's
    #   type, either by nested parse or as a scalar.
    # Signed scalars are read with the unsigned getter of the same width.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; top-level message callbacks return MNL_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1267
1268
def parse_rsp_nested(ri, struct):
    """Emit ``<render_name>_parse()`` for a nested struct.

    The generated function takes the parse argument, the nest attribute and
    one ``__u32`` argument per inherited member.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend(f'__u32 {name}' for name in struct.inherited)

    lvars = ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params)
    _multi_parse(ri, struct, [], lvars)
1282
1283
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply parse callback for ri's op.

    Nothing is emitted when the current mode has no 'reply'. Events are the
    exception and always get a callback.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    lvars = [
        f'{type_name(ri, "reply", deref=deref)} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]
    prefix = op_prefix(ri, "reply", deref=deref)

    ri.cw.write_func_prot('int', f'{prefix}_parse',
                          ['const struct nlmsghdr *nlh', 'void *data'])
    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], lvars)
1299
1300
def print_req(ri):
    """Emit the user-space C function that performs ri's 'do' request.

    The generated function:
    - builds the message from the request struct's members;
    - sends it and receives one buffer;
    - when the op has a reply, calloc()s a response struct and parses the
      buffer into it;
    - finally calls ynl_recv_ack().

    It returns the response (or 0 when there is no reply) on success, and
    NULL (or -1) on failure.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct nlmsghdr *nlh;',
                  'int len, err;']

    # Ops with a reply return the parsed response pointer instead of an int.
    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_parse_arg yarg = { .ys = ys, };']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # Request policy goes on the socket; reply policy goes into the parse arg.
    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()

    # NOTE(review): the generated code does not check calloc() for NULL.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yarg.data = rsp;')
        ri.cw.nl()
        ri.cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};")
        ri.cw.p('if (err < 0)')
        ri.cw.p('goto err_free;')
        ri.cw.nl()

    ri.cw.p('err = ynl_recv_ack(ys, err);')
    ri.cw.p('if (err)')
    ri.cw.p('goto err_free;')
    ri.cw.nl()
    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()
    # Failure path: with a reply, release the response via the op's free helper.
    ri.cw.p('err_free:')

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
    ri.cw.p(f"return {ret_err};")
    ri.cw.block_end()
1358
1359
def print_dump(ri):
    """Emit the user-space C function that performs ri's dump request.

    The generated function:
    - sends the dump request, with request attributes if the mode has any;
    - keeps receiving buffers and running them through
      ynl_dump_trampoline while the callback run returns a positive value.

    It returns ``yds.first``. On error, whatever was collected in
    ``yds.first`` is freed and NULL is returned.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int len, err;']

    # Declarations are emitted in the order listed (write_func_lvar would
    # re-sort them by length).
    for var in local_vars:
        ri.cw.p(f'{var}')
    ri.cw.nl()

    # Dump state: element size, per-element parse callback and reply policy.
    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('return NULL;')
    ri.cw.nl()

    # Receive loop: continue while the callback run returns > 0.
    ri.cw.block_start(line='do')
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()
    ri.cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.block_end(line='while (err > 0);')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1408
1409
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the op's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1412
1413
def free_arg_name(direction):
    """Return the parameter name used by a generated free function.

    'req' / 'rsp' for a request / reply, and 'obj' when there is no direction.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1418
1419
def print_free_prototype(ri, direction, suffix=';'):
    """Emit ``void <prefix>_free(struct <prefix> *<arg>)`` followed by *suffix*."""
    prefix = op_prefix(ri, direction)
    param = f"struct {prefix} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{prefix}_free", [param], suffix=suffix)
1424
1425
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    Layout of the generated struct:
    - an optional anonymous ``_present`` sub-struct with the presence members
      (for each attribute, its 'len' member then its 'bit' member);
    - one ``__u32`` per inherited member;
    - the attribute members themselves.
    """
    header = f"struct {ri.family['name']}_{ri.type_name}{direction_to_suffix[direction]}"
    if ri.op_mode == 'dump':
        header += '_dump'
    ri.cw.block_start(line=header)

    presence_open = False
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            decl = member.presence_member(ri.ku_space, kind)
            if not decl:
                continue
            if not presence_open:
                ri.cw.block_start(line="struct")
                presence_open = True
            ri.cw.p(decl)
    if presence_open:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for name in struct.inherited:
        ri.cw.p(f"__u32 {name};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1455
1456
def print_type(ri, direction):
    """Emit the C struct for the *direction* ('request'/'reply') of ri's op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1459
1460
def print_type_full(ri, struct):
    """Emit the C struct for *struct* with no request/reply name suffix."""
    _print_type(ri, direction="", struct=struct)
1463
1464
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype for a type and, for user-space requests, setters."""
    print_free_prototype(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1472
1473
def print_req_type_helpers(ri):
    """Emit the helpers for the request type of ri's op."""
    print_type_helpers(ri, direction="request")
1476
1477
def print_rsp_type_helpers(ri):
    """Emit the helpers for the reply type, if the current mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, direction="reply")
1482
1483
def print_parse_prototype(ri, direction, terminate=True):
    """Emit ``void <op>_<rsp|req>_parse(const struct nlattr **tb, struct ... *req)``.

    The prototype ends in ';' when *terminate* is set.
    """
    base = f"{ri.op.render_name}{'_rsp' if direction == 'reply' else '_req'}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
1492
1493
def print_req_type(ri):
    """Emit the request struct of ri's op."""
    print_type(ri, direction="request")
1496
1497
def print_rsp_type(ri):
    """Emit the reply struct of ri's op, if there is one.

    Events always have one; 'do'/'dump' only when their section has a 'reply'.
    """
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
1506
1507
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, then its free prototype.

    Extra members depend on the mode:
    - dump replies get a ``next`` pointer;
    - notify/event replies get ``family``, ``cmd`` and a ``free`` callback.

    The wrapped reply is an 8-byte aligned ``obj`` member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    inner = type_name(ri, 'reply', deref=True)
    cw.p(f"{inner} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
1521
1522
def _free_type_members_iter(ri, struct):
    """Declare the loop counter ``i`` if any member's free code iterates."""
    if any(member.free_needs_iter() for _, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1529
1530
def _free_type_members(ri, var, struct, ref=''):
    """Let each member of *struct* emit the code that frees its storage in *var*.

    *ref* is a prefix (e.g. 'obj.') placed before member names by the members.
    """
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
1534
1535
def _free_type(ri, direction, struct):
    """Emit the free function for *struct*.

    With a direction, the generated function also frees the object itself
    (``free(req)`` / ``free(rsp)``). Without one (parameter ``obj``), only
    its members are freed.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1547
1548
def free_rsp_nested(ri, struct):
    """Emit the member-only free function of a nested struct."""
    _free_type(ri, direction="", struct=struct)
1551
1552
def print_rsp_free(ri):
    """Emit the reply free function when the current mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1557
1558
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated code walks the ``next`` chain. For each element it frees
    the members of ``obj`` and then the element itself.
    """
    elem_type = type_name(ri, 'reply')
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while (next)')
    _free_type_members_iter(ri, ri.struct['reply'])
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1577
1578
def print_ntf_type_free(ri):
    """Emit a notification's free function: the members of ``obj``, then the object."""
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1587
1588
def print_ntf_parse_prototype(family, cw, suffix=';'):
    """Emit the prototype of ``<family>_ntf_parse(struct ynl_sock *ys)``."""
    fname = f"{family['name']}_ntf_parse"
    cw.write_func_prot('struct ynl_ntf_base_type *', fname,
                       ['struct ynl_sock *ys'], suffix=suffix)
1592
1593
def print_ntf_type_parse(family, cw, ku_mode):
    """Emit ``<family>_ntf_parse()``, the user-space notification parser.

    The generated function:
    - receives one buffer, returning NULL if it is shorter than the netlink
      plus genetlink headers;
    - switches on the genetlink command. Every notify command and every event
      op selects its reply type, parse callback, policy and free helper;
      unknown commands are reported and yield NULL;
    - runs the parse callback over the message;
    - on success records family (nlmsg_type) and cmd and returns the object;
      on parse failure returns NULL.

    Args:
        family: the family being generated.
        cw: CodeWriter that receives the output.
        ku_mode: passed to RenderInfo as its ku_space.
    """
    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))')
    cw.p('return NULL;')
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()
    cw.block_start(line='switch (genlh->cmd)')
    # One case group per op with notifications (all of its notify cmds share it).
    # NOTE(review): the generated code does not check calloc() for NULL.
    for ntf_op in sorted(family.all_notify.keys()):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    # One case per event op.
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
    cw.p('yarg.data = rsp->data;')
    cw.nl()
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p('rsp->family = nlh->nlmsg_type;')
    cw.p('rsp->cmd = genlh->cmd;')
    cw.p('return rsp;')
    cw.nl()
    # NOTE(review): the failure path emits plain free(rsp), not rsp->free(rsp).
    # Members a partial parse already allocated may therefore leak; confirm
    # whether that is acceptable.
    cw.p('err_free:')
    cw.p('free(rsp);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
1651
1652
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit an nla_policy table declaration, or the opening of its definition.

    With *terminate*, writes an ``extern`` declaration ending in ';'. This is
    skipped entirely for per-op tables (ri given) when the family struct can
    be generated.

    Without *terminate*, writes the definition header ending in ' = {'. It is
    ``static`` for per-op tables of families whose struct can be generated.

    The table is named after the op (plus '_<mode>' for dual policies) when
    *ri* is given, else after *struct*.
    """
    if terminate:
        if ri and kernel_can_gen_family_struct(struct.family):
            return
        prefix = 'extern '
    elif kernel_can_gen_family_struct(struct.family) and ri:
        prefix = 'static '
    else:
        prefix = ''

    ending = ';' if terminate else ' = {'
    max_attr = struct.attr_max_val

    if not ri:
        table = struct.render_name
    elif ri.op.dual_policy:
        table = f"{ri.op.render_name}_{ri.op_mode}"
    else:
        table = ri.op.render_name

    cw.p(f"{prefix}const struct nla_policy {table}_nl_policy[{max_attr.enum_name} + 1]{ending}")
1675
1676
def print_req_policy(cw, struct, ri=None):
    """Emit a full nla_policy table: header, one entry per member, closing ``};``."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1682
1683
def kernel_can_gen_family_struct(family):
    """Return True when the family's protocol is plain 'genetlink'.

    Callers use this to decide whether generated kernel tables can be
    static rather than exported.
    """
    proto = family.proto
    return proto == 'genetlink'
1686
1687
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the kernel ops table declaration, or the opening of its definition.

    With terminate=False, writes a comment and
    ``<qual> struct <type> <family>_nl_ops[<cnt>] = {`` and returns. The
    caller emits the entries.

    With terminate=True, writes:
    - the ``extern`` declaration, only when the table is exported (the
      family struct cannot be generated);
    - prototypes of the pre/post hooks;
    - prototypes of every non-async op's doit/dumpit handler.

    The table element type depends on family.kernel_policy:
    'global' -> genl_small_ops, 'per-op' -> genl_ops,
    'split' -> genl_split_ops.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # The array size must match the entries print_kernel_op_table()
        # emits, and that function skips async ops entirely. Counting them
        # would leave zero-filled trailing entries in the table.
        if family.kernel_policy == 'split':
            # One entry per do / dump handler.
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # One entry per op.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1751
1752
def print_kernel_op_table_hdr(family, cw):
    """Emit the header (declaration) side of the kernel ops table."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
1755
1756
1757def print_kernel_op_table(family, cw):
1758    print_kernel_op_table_fwd(family, cw, terminate=False)
1759    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
1760        for op_name, op in family.ops.items():
1761            if op.is_async:
1762                continue
1763
1764            cw.block_start()
1765            members = [('cmd', op.enum_name)]
1766            if 'dont-validate' in op:
1767                members.append(('validate',
1768                                ' | '.join([c_upper('genl-dont-validate-' + x)
1769                                            for x in op['dont-validate']])), )
1770            for op_mode in ['do', 'dump']:
1771                if op_mode in op:
1772                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
1773                    members.append((op_mode + 'it', name))
1774            if family.kernel_policy == 'per-op':
1775                struct = Struct(family, op['attribute-set'],
1776                                type_list=op['do']['request']['attributes'])
1777
1778                name = c_lower(f"{family.name}-{op_name}-nl-policy")
1779                members.append(('policy', name))
1780                members.append(('maxattr', struct.attr_max_val.enum_name))
1781            if 'flags' in op:
1782                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
1783            cw.write_struct_init(members)
1784            cw.block_end(line=',')
1785    elif family.kernel_policy == 'split':
1786        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
1787                    'dump': {'pre': 'start', 'post': 'done'}}
1788
1789        for op_name, op in family.ops.items():
1790            for op_mode in ['do', 'dump']:
1791                if op.is_async or op_mode not in op:
1792                    continue
1793
1794                cw.block_start()
1795                members = [('cmd', op.enum_name)]
1796                if 'dont-validate' in op:
1797                    members.append(('validate',
1798                                    ' | '.join([c_upper('genl-dont-validate-' + x)
1799                                                for x in op['dont-validate']])), )
1800                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
1801                if 'pre' in op[op_mode]:
1802                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
1803                members.append((op_mode + 'it', name))
1804                if 'post' in op[op_mode]:
1805                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
1806                if 'request' in op[op_mode]:
1807                    struct = Struct(family, op['attribute-set'],
1808                                    type_list=op[op_mode]['request']['attributes'])
1809
1810                    if op.dual_policy:
1811                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
1812                    else:
1813                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
1814                    members.append(('policy', name))
1815                    members.append(('maxattr', struct.attr_max_val.enum_name))
1816                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
1817                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
1818                cw.write_struct_init(members)
1819                cw.block_end(line=',')
1820
1821    cw.block_end(line=';')
1822    cw.nl()
1823
1824
1825def print_kernel_mcgrp_hdr(family, cw):
1826    if not family.mcgrps['list']:
1827        return
1828
1829    cw.block_start('enum')
1830    for grp in family.mcgrps['list']:
1831        grp_id = c_upper(f"{family.name}-nlgrp-{grp['name']},")
1832        cw.p(grp_id)
1833    cw.block_end(';')
1834    cw.nl()
1835
1836
1837def print_kernel_mcgrp_src(family, cw):
1838    if not family.mcgrps['list']:
1839        return
1840
1841    cw.block_start('static const struct genl_multicast_group ' + family.name + '_nl_mcgrps[] =')
1842    for grp in family.mcgrps['list']:
1843        name = grp['name']
1844        grp_id = c_upper(f"{family.name}-nlgrp-{name}")
1845        cw.p('[' + grp_id + '] = { "' + name + '", },')
1846    cw.block_end(';')
1847    cw.nl()
1848
1849
1850def print_kernel_family_struct_hdr(family, cw):
1851    if not kernel_can_gen_family_struct(family):
1852        return
1853
1854    cw.p(f"extern struct genl_family {family.name}_nl_family;")
1855    cw.nl()
1856
1857
1858def print_kernel_family_struct_src(family, cw):
1859    if not kernel_can_gen_family_struct(family):
1860        return
1861
1862    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
1863    cw.p('.name\t\t= ' + family.fam_key + ',')
1864    cw.p('.version\t= ' + family.ver_key + ',')
1865    cw.p('.netnsok\t= true,')
1866    cw.p('.parallel_ops\t= true,')
1867    cw.p('.module\t\t= THIS_MODULE,')
1868    if family.kernel_policy == 'per-op':
1869        cw.p(f'.ops\t\t= {family.name}_nl_ops,')
1870        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.name}_nl_ops),')
1871    elif family.kernel_policy == 'split':
1872        cw.p(f'.split_ops\t= {family.name}_nl_ops,')
1873        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.name}_nl_ops),')
1874    if family.mcgrps['list']:
1875        cw.p(f'.mcgrps\t\t= {family.name}_nl_mcgrps,')
1876        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.name}_nl_mcgrps),')
1877    cw.block_end(';')
1878
1879
1880def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
1881    start_line = 'enum'
1882    if enum_name in obj:
1883        if obj[enum_name]:
1884            start_line = 'enum ' + c_lower(obj[enum_name])
1885    elif ckey and ckey in obj:
1886        start_line = 'enum ' + family.name + '_' + c_lower(obj[ckey])
1887    cw.block_start(line=start_line)
1888
1889
1890def render_uapi(family, cw):
1891    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
1892    cw.p('#ifndef ' + hdr_prot)
1893    cw.p('#define ' + hdr_prot)
1894    cw.nl()
1895
1896    defines = [(family.fam_key, family["name"]),
1897               (family.ver_key, family.get('version', 1))]
1898    cw.writes_defines(defines)
1899    cw.nl()
1900
1901    defines = []
1902    for const in family['definitions']:
1903        if const['type'] != 'const':
1904            cw.writes_defines(defines)
1905            defines = []
1906            cw.nl()
1907
1908        # Write kdoc for enum and flags (one day maybe also structs)
1909        if const['type'] == 'enum' or const['type'] == 'flags':
1910            enum = family.consts[const['name']]
1911
1912            if enum.has_doc():
1913                cw.p('/**')
1914                doc = ''
1915                if 'doc' in enum:
1916                    doc = ' - ' + enum['doc']
1917                cw.write_doc_line(enum.enum_name + doc)
1918                for entry in enum.entries.values():
1919                    if entry.has_doc():
1920                        doc = '@' + entry.c_name + ': ' + entry['doc']
1921                        cw.write_doc_line(doc)
1922                cw.p(' */')
1923
1924            uapi_enum_start(family, cw, const, 'name')
1925            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
1926            for entry in enum.entries.values():
1927                suffix = ','
1928                if entry.value_change:
1929                    suffix = f" = {entry.user_value()}" + suffix
1930                cw.p(entry.c_name + suffix)
1931
1932            if const.get('render-max', False):
1933                cw.nl()
1934                max_name = c_upper(name_pfx + 'max')
1935                cw.p('__' + max_name + ',')
1936                cw.p(max_name + ' = (__' + max_name + ' - 1)')
1937            cw.block_end(line=';')
1938            cw.nl()
1939        elif const['type'] == 'const':
1940            defines.append([c_upper(family.get('c-define-name',
1941                                               f"{family.name}-{const['name']}")),
1942                            const['value']])
1943
1944    if defines:
1945        cw.writes_defines(defines)
1946        cw.nl()
1947
1948    max_by_define = family.get('max-by-define', False)
1949
1950    for _, attr_set in family.attr_sets.items():
1951        if attr_set.subset_of:
1952            continue
1953
1954        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
1955        max_value = f"({cnt_name} - 1)"
1956
1957        val = 0
1958        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
1959        for _, attr in attr_set.items():
1960            suffix = ','
1961            if attr.value != val:
1962                suffix = f" = {attr.value},"
1963                val = attr.value
1964            val += 1
1965            cw.p(attr.enum_name + suffix)
1966        cw.nl()
1967        cw.p(cnt_name + ('' if max_by_define else ','))
1968        if not max_by_define:
1969            cw.p(f"{attr_set.max_name} = {max_value}")
1970        cw.block_end(line=';')
1971        if max_by_define:
1972            cw.p(f"#define {attr_set.max_name} {max_value}")
1973        cw.nl()
1974
1975    # Commands
1976    separate_ntf = 'async-prefix' in family['operations']
1977
1978    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
1979    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
1980    max_value = f"({cnt_name} - 1)"
1981
1982    uapi_enum_start(family, cw, family['operations'], 'enum-name')
1983    val = 0
1984    for op in family.msgs.values():
1985        if separate_ntf and ('notify' in op or 'event' in op):
1986            continue
1987
1988        suffix = ','
1989        if op.value != val:
1990            suffix = f" = {op.value},"
1991            val = op.value
1992        cw.p(op.enum_name + suffix)
1993        val += 1
1994    cw.nl()
1995    cw.p(cnt_name + ('' if max_by_define else ','))
1996    if not max_by_define:
1997        cw.p(f"{max_name} = {max_value}")
1998    cw.block_end(line=';')
1999    if max_by_define:
2000        cw.p(f"#define {max_name} {max_value}")
2001    cw.nl()
2002
2003    if separate_ntf:
2004        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2005        for op in family.msgs.values():
2006            if separate_ntf and not ('notify' in op or 'event' in op):
2007                continue
2008
2009            suffix = ','
2010            if 'value' in op:
2011                suffix = f" = {op['value']},"
2012            cw.p(op.enum_name + suffix)
2013        cw.block_end(line=';')
2014        cw.nl()
2015
2016    # Multicast
2017    defines = []
2018    for grp in family.mcgrps['list']:
2019        name = grp['name']
2020        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
2021                        f'{name}'])
2022    cw.nl()
2023    if defines:
2024        cw.writes_defines(defines)
2025        cw.nl()
2026
2027    cw.p(f'#endif /* {hdr_prot} */')
2028
2029
2030def find_kernel_root(full_path):
2031    sub_path = ''
2032    while True:
2033        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2034        full_path = os.path.dirname(full_path)
2035        maintainers = os.path.join(full_path, "MAINTAINERS")
2036        if os.path.exists(maintainers):
2037            return full_path, sub_path[:-1]
2038
2039
2040def main():
2041    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2042    parser.add_argument('--mode', dest='mode', type=str, required=True)
2043    parser.add_argument('--spec', dest='spec', type=str, required=True)
2044    parser.add_argument('--header', dest='header', action='store_true', default=None)
2045    parser.add_argument('--source', dest='header', action='store_false')
2046    parser.add_argument('--user-header', nargs='+', default=[])
2047    parser.add_argument('-o', dest='out_file', type=str)
2048    args = parser.parse_args()
2049
2050    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout
2051
2052    if args.header is None:
2053        parser.error("--header or --source is required")
2054
2055    try:
2056        parsed = Family(args.spec)
2057    except yaml.YAMLError as exc:
2058        print(exc)
2059        os.sys.exit(1)
2060        return
2061
2062    cw = CodeWriter(BaseNlLib(), out_file)
2063
2064    _, spec_kernel = find_kernel_root(args.spec)
2065    if args.mode == 'uapi':
2066        cw.p('/* SPDX-License-Identifier: (GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause */')
2067    else:
2068        if args.header:
2069            cw.p('/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */')
2070        else:
2071            cw.p('// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause')
2072    cw.p("/* Do not edit directly, auto-generated from: */")
2073    cw.p(f"/*\t{spec_kernel} */")
2074    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2075    cw.nl()
2076
2077    if args.mode == 'uapi':
2078        render_uapi(parsed, cw)
2079        return
2080
2081    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
2082    if args.header:
2083        cw.p('#ifndef ' + hdr_prot)
2084        cw.p('#define ' + hdr_prot)
2085        cw.nl()
2086
2087    if args.mode == 'kernel':
2088        cw.p('#include <net/netlink.h>')
2089        cw.p('#include <net/genetlink.h>')
2090        cw.nl()
2091        if not args.header:
2092            if args.out_file:
2093                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
2094            cw.nl()
2095    headers = [parsed.uapi_header]
2096    for definition in parsed['definitions']:
2097        if 'header' in definition:
2098            headers.append(definition['header'])
2099    for one in headers:
2100        cw.p(f"#include <{one}>")
2101    cw.nl()
2102
2103    if args.mode == "user":
2104        if not args.header:
2105            cw.p("#include <stdlib.h>")
2106            cw.p("#include <stdio.h>")
2107            cw.p("#include <string.h>")
2108            cw.p("#include <libmnl/libmnl.h>")
2109            cw.p("#include <linux/genetlink.h>")
2110            cw.nl()
2111            for one in args.user_header:
2112                cw.p(f'#include "{one}"')
2113        else:
2114            cw.p('struct ynl_sock;')
2115        cw.nl()
2116
2117    if args.mode == "kernel":
2118        if args.header:
2119            for _, struct in sorted(parsed.pure_nested_structs.items()):
2120                if struct.request:
2121                    cw.p('/* Common nested types */')
2122                    break
2123            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2124                if struct.request:
2125                    print_req_policy_fwd(cw, struct)
2126            cw.nl()
2127
2128            if parsed.kernel_policy == 'global':
2129                cw.p(f"/* Global operation policy for {parsed.name} */")
2130
2131                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2132                print_req_policy_fwd(cw, struct)
2133                cw.nl()
2134
2135            if parsed.kernel_policy in {'per-op', 'split'}:
2136                for op_name, op in parsed.ops.items():
2137                    if 'do' in op and 'event' not in op:
2138                        ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2139                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2140                        cw.nl()
2141
2142            print_kernel_op_table_hdr(parsed, cw)
2143            print_kernel_mcgrp_hdr(parsed, cw)
2144            print_kernel_family_struct_hdr(parsed, cw)
2145        else:
2146            for _, struct in sorted(parsed.pure_nested_structs.items()):
2147                if struct.request:
2148                    cw.p('/* Common nested types */')
2149                    break
2150            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2151                if struct.request:
2152                    print_req_policy(cw, struct)
2153            cw.nl()
2154
2155            if parsed.kernel_policy == 'global':
2156                cw.p(f"/* Global operation policy for {parsed.name} */")
2157
2158                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2159                print_req_policy(cw, struct)
2160                cw.nl()
2161
2162            for op_name, op in parsed.ops.items():
2163                if parsed.kernel_policy in {'per-op', 'split'}:
2164                    for op_mode in ['do', 'dump']:
2165                        if op_mode in op and 'request' in op[op_mode]:
2166                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2167                            ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode)
2168                            print_req_policy(cw, ri.struct['request'], ri=ri)
2169                            cw.nl()
2170
2171            print_kernel_op_table(parsed, cw)
2172            print_kernel_mcgrp_src(parsed, cw)
2173            print_kernel_family_struct_src(parsed, cw)
2174
2175    if args.mode == "user":
2176        has_ntf = False
2177        if args.header:
2178            cw.p('/* Common nested types */')
2179            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2180                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
2181                print_type_full(ri, struct)
2182
2183            for op_name, op in parsed.ops.items():
2184                cw.p(f"/* ============== {op.enum_name} ============== */")
2185
2186                if 'do' in op and 'event' not in op:
2187                    cw.p(f"/* {op.enum_name} - do */")
2188                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2189                    print_req_type(ri)
2190                    print_req_type_helpers(ri)
2191                    cw.nl()
2192                    print_rsp_type(ri)
2193                    print_rsp_type_helpers(ri)
2194                    cw.nl()
2195                    print_req_prototype(ri)
2196                    cw.nl()
2197
2198                if 'dump' in op:
2199                    cw.p(f"/* {op.enum_name} - dump */")
2200                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump')
2201                    if 'request' in op['dump']:
2202                        print_req_type(ri)
2203                        print_req_type_helpers(ri)
2204                    if not ri.type_consistent:
2205                        print_rsp_type(ri)
2206                    print_wrapped_type(ri)
2207                    print_dump_prototype(ri)
2208                    cw.nl()
2209
2210                if 'notify' in op:
2211                    cw.p(f"/* {op.enum_name} - notify */")
2212                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
2213                    has_ntf = True
2214                    if not ri.type_consistent:
2215                        raise Exception('Only notifications with consistent types supported')
2216                    print_wrapped_type(ri)
2217
2218                if 'event' in op:
2219                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event')
2220                    cw.p(f"/* {op.enum_name} - event */")
2221                    print_rsp_type(ri)
2222                    cw.nl()
2223                    print_wrapped_type(ri)
2224
2225            if has_ntf:
2226                cw.p('/* --------------- Common notification parsing --------------- */')
2227                print_ntf_parse_prototype(parsed, cw)
2228            cw.nl()
2229        else:
2230            cw.p('/* Policies */')
2231            for name, _ in parsed.attr_sets.items():
2232                struct = Struct(parsed, name)
2233                put_typol_fwd(cw, struct)
2234            cw.nl()
2235
2236            for name, _ in parsed.attr_sets.items():
2237                struct = Struct(parsed, name)
2238                put_typol(cw, struct)
2239
2240            cw.p('/* Common nested types */')
2241            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2242                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
2243
2244                free_rsp_nested(ri, struct)
2245                if struct.request:
2246                    put_req_nested(ri, struct)
2247                if struct.reply:
2248                    parse_rsp_nested(ri, struct)
2249
2250            for op_name, op in parsed.ops.items():
2251                cw.p(f"/* ============== {op.enum_name} ============== */")
2252                if 'do' in op and 'event' not in op:
2253                    cw.p(f"/* {op.enum_name} - do */")
2254                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2255                    print_rsp_free(ri)
2256                    parse_rsp_msg(ri)
2257                    print_req(ri)
2258                    cw.nl()
2259
2260                if 'dump' in op:
2261                    cw.p(f"/* {op.enum_name} - dump */")
2262                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump")
2263                    if not ri.type_consistent:
2264                        parse_rsp_msg(ri, deref=True)
2265                    print_dump_type_free(ri)
2266                    print_dump(ri)
2267                    cw.nl()
2268
2269                if 'notify' in op:
2270                    cw.p(f"/* {op.enum_name} - notify */")
2271                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
2272                    has_ntf = True
2273                    if not ri.type_consistent:
2274                        raise Exception('Only notifications with consistent types supported')
2275                    print_ntf_type_free(ri)
2276
2277                if 'event' in op:
2278                    cw.p(f"/* {op.enum_name} - event */")
2279                    has_ntf = True
2280
2281                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2282                    parse_rsp_msg(ri)
2283
2284                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event")
2285                    print_ntf_type_free(ri)
2286
2287            if has_ntf:
2288                cw.p('/* --------------- Common notification parsing --------------- */')
2289                print_ntf_type_parse(parsed, cw, args.mode)
2290
2291    if args.header:
2292        cw.p(f'#endif /* {hdr_prot} */')
2293
2294
2295if __name__ == "__main__":
2296    main()
2297