xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision d55595f0)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import shutil
9import tempfile
10import yaml
11
12from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
13
14
def c_upper(name):
    """Return *name* as an upper-case C identifier ('-' becomes '_')."""
    shouted = name.upper()
    return shouted.replace('-', '_')
17
18
def c_lower(name):
    """Return *name* as a lower-case C identifier ('-' becomes '_')."""
    quiet = name.lower()
    return quiet.replace('-', '_')
21
22
class BaseNlLib:
    """C snippets for driving the libmnl-based YNL user-space library."""

    def get_family_id(self):
        """Return the C expression used to reference the family ID."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C ``mnl_cb_run2()`` call that runs callback *cb* with *data*.

        Dumps pass 0 for the sequence number and port ID; other requests pass
        ``ys->seq`` and ``ys->portid``. *indent* is the number of extra tabs
        used on the wrapped continuation lines.
        """
        # Continuation: newline, two tabs, one tab per indent level, one space.
        sep = '\n\t\t' + '\t' * indent + ' '
        tail = f"{sep}ynl_cb_array, NLMSG_MIN_TYPE)"
        if not is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}{cb}, {data},{tail}"
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{tail}"
34
35
class Type(SpecAttr):
    """Base code-generation wrapper for one attribute of an attribute set.

    Subclasses specialise how the attribute is rendered in C: struct members,
    kernel policy, YNL typol entry, put/parse code and request setters.
    Hooks a subclass is expected to provide raise an Exception here.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # A nested set that shares its name with an entry of
            # family.consts gets a trailing '_' on its struct name.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # Never emit a C keyword as a member name.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``, the upper-case C constant for this attribute.

        A per-attribute 'name-prefix' overrides the attribute set's prefix.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return truthy when the attribute is stored as an array of values."""
        return None

    def is_scalar(self):
        """Return True for the fixed-width integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` struct member declaration, or None.

        Only emitted when this attribute's presence type equals *type_filter*
        and is 'bit' or 'len'. In 'user' space the type is spelled ``__u32``.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial struct member, or None."""
        return None

    def free_needs_iter(self):
        """Return True if freeing needs a loop index ``i`` in the generated C."""
        return False

    def free(self, ri, var, ref):
        """Emit the C that frees this member of ``var->ref``, if heap-held."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) used by this member's setter."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attribute's member declaration(s) inside a C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for *policy*."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's line of the kernel ``nla_policy`` array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the body of the YNL typol entry (subclasses must override)."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the YNL user-space typol table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check for bit/len types."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` call for this member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the C that adds this attribute to a message (must override)."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for parsing (must override).

        Each element may be a string, a list of strings, or None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if`` branch that parses this attribute.

        *first* selects ``if`` (first branch of the chain) over ``else if``.
        Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes are only counted here, so skip validation
        # and presence marking for them.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the setter argument (must override)."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this member of the request struct.

        *ref* is the list of enclosing member names when the attribute lives
        inside a nested struct.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # Walk from the outermost struct down to this member, marking every
        # level present for bit-tracked types. The last expression built is
        # what _setter_lines() receives.
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        # Setters that free but never allocate get a '__' prefixed name.
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
212
213
class TypeUnused(Type):
    """Attribute slot marked 'unused' in the spec.

    It has no struct member and no kernel policy entry. Its typol entry is
    YNL_PT_REJECT, and the parse branch returns MNL_CB_ERROR.
    """

    def attr_policy(self, cw):
        # Deliberately emits nothing.
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None

    def arg_member(self, ri):
        return []

    def presence_type(self):
        return ''
229
230
class TypePad(Type):
    """Padding attribute.

    No code is generated to put, parse, police or set it; its typol entry
    is YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The emitters below intentionally produce no output and return None.
    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
252
253
class TypeScalar(Type):
    """Integer attribute, optionally tied to an enum or flags definition."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Byte order is only surfaced as a C comment next to the member.
        self.byte_order_comment = (f" /* {attr['byte-order']} */"
                                   if 'byte-order' in attr else '')

        # Added by resolve():
        for late_attr in ('is_bitfield', 'type_name'):
            setattr(self, late_attr, None)
            delattr(self, late_attr)

    def resolve(self):
        """Work out ``is_bitfield`` and the C type name of the member."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # Non-bitfield values of a named enum use the enum as their C type.
        named_enum = (not self.is_bitfield and 'enum' in self.attr
                      and self.family.consts[self.attr['enum']].enum_name)
        if named_enum:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        # mnl does not have a helper for signed types: use the unsigned one
        # of the same width.
        return 'u' + self.type[1:] if self.type[0] == 's' else self.type

    def _attr_policy(self, policy):
        """Return the kernel policy, preferring mask > min > enum range > plain."""
        if self.is_bitfield:
            enum = self.family.consts[self.attr['enum']]
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            flags = self.family.consts[self.checks['flags-mask']]
            flag_cnt = len(flags['entries'])
            mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Typol only knows unsigned widths; s32 maps to YNL_PT_U32, etc.
        width = self.type[1:]
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        getter = f"mnl_attr_get_{self._mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
325
326
class TypeFlag(Type):
    """Zero-length flag attribute: only its presence is recorded."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value argument.
        return []

    def attr_put(self, ri, var):
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the generic parse branch marks the presence bit.
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
342
343
class TypeString(Type):
    """String attribute, held as a heap copy plus its length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append('.len = ' + str(self.checks['max-len']))
        return ', '.join(parts) + ', }'

    def attr_policy(self, cw):
        # Require a NUL terminator unless the spec allows unterminated strings.
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        # Copy at most the payload length and always NUL-terminate the copy.
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
391
392
class TypeBinary(Type):
    """Opaque byte-blob attribute, held as a heap copy plus its length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Like every other typol fragment, end with ', ': attr_typol() appends
        # '}' directly, so a missing space rendered as '...,}' in the C table.
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer; *policy* is not used.

        Only a lone 'min-len' check is supported (rendered as ``.len``);
        with no checks the policy is NLA_BINARY.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
434
435
class TypeNest(Type):
    """Attribute carrying one nested attribute set, rendered as a sub-struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        get_lines = [f"if ({nest}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # One setter per member of the nested struct, with our name on the path.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
465
466
class TypeMultiAttr(Type):
    """Attribute that may repeat; occurrences are collected into a C array.

    *base_type* is the Type object built for a single occurrence; kernel
    policy and typol rendering are delegated to it.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nested(self):
        """Return True when each occurrence is a nested attribute set."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _mnl_type(self):
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _complex_member_type(self, ri):
        if self._is_nested():
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nested()

    def free(self, ri, var, ref):
        # Test for nests first, as _complex_member_type() does: evaluating
        # self.attr['type'] before the "'type' not in self.attr" check would
        # raise KeyError and make that check unreachable.
        if self._is_nested():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Occurrences are only counted into a local n_<name> here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        # Same check order as free(): nests first.
        if self._is_nested():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attr tracks a count, not a presence bit: rewrite the trailing
        # '_present.<name>' of *presence* into 'n_<name>'.
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
535
536
class TypeArrayNest(Type):
    """'array-nest' attribute.

    Its nested entries are counted into ``n_<name>`` and stored as an array
    of 'sub-type' values; a missing sub-type means nested structs.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' in self.attr and self.attr['sub-type'] != 'nest':
            sub_type = self.attr['sub-type']
            if sub_type not in scalars:
                raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Record the attribute in attr_<name> and count its nested entries.
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
562
563
class TypeNestTypeValue(Type):
    """Nested attribute whose 'type-value' levels are decoded before parsing.

    Each listed level is unwrapped with mnl_attr_get_payload() and its type
    read with mnl_attr_get_type(). The level types are passed as extra
    arguments to the nested parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        tv_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = 'attr_' + level
            tv_args = ', ' + ', '.join(levels)

        get_lines.append(f"{nest}_parse(&parg, {cursor}{tv_args});")
        return get_lines, init_lines, local_vars
592
593
class Struct:
    """C struct rendered for an attribute set, or a chosen subset of it.

    Without *type_list* every attribute of the set is included and the
    struct is treated as nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        suffix = c_lower(space_name)
        if family.name == suffix:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{suffix}"
        self.struct_name = 'struct ' + self.render_name
        # Nested structs that share a name with a definition get a '_' suffix.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Remember the attribute with the highest value (the last one on ties).
        self.attr_max_val = None
        top = 0
        for _, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm *new_inherited* matches and store the sorted C names."""
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
646
647
class EnumEntry(SpecEnumEntry):
    """One enum/flags entry; its C constant name is computed in resolve()."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value differs from prev.value + 1 (or from 0 for the
        # first entry), and always for entries of a flags set.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
666
667
class EnumSet(SpecEnumSet):
    """Enum or flags definition from the spec, plus the names used in C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                # An explicitly empty 'enum-name' leaves enum_name as None.
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)`` covering every entry value.

        Raises:
            Exception: if the distinct values do not fill low..high.
            ValueError: if the enum has no entries (from min()).
        """
        values = {x.value for x in self.entries.values()}
        low = min(values)
        high = max(values)

        # Count *distinct* values. Comparing against the number of entries
        # let duplicated values accept a gapped enum (0, 0, 2) and reject a
        # contiguous one (0, 0, 1).
        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
695
696
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming (enum prefix, max name, c_name)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the naming of the set it is carved out of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        # The set sharing the family's C name gets an empty c_name.
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Build the Type subclass matching ``elem['type']``.

        When 'multi-attr' is set, the result is wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        if kind in scalars:
            attr_cls = TypeScalar
        else:
            attr_cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if attr_cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = attr_cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
751
752
class Operation(SpecOperation):
    """Spec operation plus the C names derived for it."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True when both the 'do' and the 'dump' modes define a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another operation names this one as its 'notify'."""
        self.has_ntf = True
778
779
780class Family(SpecFamily):
781    def __init__(self, file_name, exclude_ops):
782        # Added by resolve:
783        self.c_name = None
784        delattr(self, "c_name")
785        self.op_prefix = None
786        delattr(self, "op_prefix")
787        self.async_op_prefix = None
788        delattr(self, "async_op_prefix")
789        self.mcgrps = None
790        delattr(self, "mcgrps")
791        self.consts = None
792        delattr(self, "consts")
793        self.hooks = None
794        delattr(self, "hooks")
795
796        super().__init__(file_name, exclude_ops=exclude_ops)
797
798        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
799        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
800
801        if 'definitions' not in self.yaml:
802            self.yaml['definitions'] = []
803
804        if 'uapi-header' in self.yaml:
805            self.uapi_header = self.yaml['uapi-header']
806        else:
807            self.uapi_header = f"linux/{self.name}.h"
808
809    def resolve(self):
810        self.resolve_up(super())
811
812        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
813            raise Exception("Codegen only supported for genetlink")
814
815        self.c_name = c_lower(self.name)
816        if 'name-prefix' in self.yaml['operations']:
817            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
818        else:
819            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
820        if 'async-prefix' in self.yaml['operations']:
821            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
822        else:
823            self.async_op_prefix = self.op_prefix
824
825        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
826
827        self.hooks = dict()
828        for when in ['pre', 'post']:
829            self.hooks[when] = dict()
830            for op_mode in ['do', 'dump']:
831                self.hooks[when][op_mode] = dict()
832                self.hooks[when][op_mode]['set'] = set()
833                self.hooks[when][op_mode]['list'] = []
834
835        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
836        self.root_sets = dict()
837        # dict space-name -> set('request', 'reply')
838        self.pure_nested_structs = dict()
839
840        self._mark_notify()
841        self._mock_up_events()
842
843        self._load_root_sets()
844        self._load_nested_sets()
845        self._load_hooks()
846
847        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
848        if self.kernel_policy == 'global':
849            self._load_global_policy()
850
851    def new_enum(self, elem):
852        return EnumSet(self, elem)
853
854    def new_attr_set(self, elem):
855        return AttrSet(self, elem)
856
857    def new_operation(self, elem, req_value, rsp_value):
858        return Operation(self, elem, req_value, rsp_value)
859
860    def _mark_notify(self):
861        for op in self.msgs.values():
862            if 'notify' in op:
863                self.ops[op['notify']].mark_has_ntf()
864
865    # Fake a 'do' equivalent of all events, so that we can render their response parsing
866    def _mock_up_events(self):
867        for op in self.yaml['operations']['list']:
868            if 'event' in op:
869                op['do'] = {
870                    'reply': {
871                        'attributes': op['event']['attributes']
872                    }
873                }
874
875    def _load_root_sets(self):
876        for op_name, op in self.msgs.items():
877            if 'attribute-set' not in op:
878                continue
879
880            req_attrs = set()
881            rsp_attrs = set()
882            for op_mode in ['do', 'dump']:
883                if op_mode in op and 'request' in op[op_mode]:
884                    req_attrs.update(set(op[op_mode]['request']['attributes']))
885                if op_mode in op and 'reply' in op[op_mode]:
886                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
887            if 'event' in op:
888                rsp_attrs.update(set(op['event']['attributes']))
889
890            if op['attribute-set'] not in self.root_sets:
891                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
892            else:
893                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
894                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
895
    def _load_nested_sets(self):
        """Discover attribute sets reachable through nesting and prepare their Structs.

        Steps:
          1. Breadth-first walk from every root set (self.root_sets) along
             'nested-attributes' references. Each nested set that is not a
             root set gets a Struct in self.pure_nested_structs. A set used
             both as root and as nest raises.
          2. Record members the nested Struct inherits from its parent
             attribute: the 'type-value' list, or 'idx' for 'array-nest'.
          3. Mark direct children of root sets as request and/or reply
             according to where the parent attribute is used.
          4. Reorder self.pure_nested_structs so that a set appears after the
             nested sets it references (bounded to len**2 rounds).
          5. Propagate request/reply flags from parents to children.

        Raises:
            Exception: if a set is used both as a root and as a nested set.
        """
        # BFS over nested attribute sets, starting from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): this check looks unreachable - a root set
                    # used as nested already raised just above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    # array-nest elements carry their index as an inherited member
                    inherit.add('idx')
                # NOTE(review): called once per referencing attribute; whether a later
                # call replaces or merges is up to Struct.set_inherited (not visible here).
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Direct children of root sets inherit request/reply usage from the parent attr
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order: pop + re-insert moves this struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # Walking in reverse (parents before the children they reference) lets
        # flags flow transitively down the nesting chain.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
964
965    def _load_global_policy(self):
966        global_set = set()
967        attr_set_name = None
968        for op_name, op in self.ops.items():
969            if not op:
970                continue
971            if 'attribute-set' not in op:
972                continue
973
974            if attr_set_name is None:
975                attr_set_name = op['attribute-set']
976            if attr_set_name != op['attribute-set']:
977                raise Exception('For a global policy all ops must use the same set')
978
979            for op_mode in ['do', 'dump']:
980                if op_mode in op:
981                    req = op[op_mode].get('request')
982                    if req:
983                        global_set.update(req.get('attributes', []))
984
985        self.global_policy = []
986        self.global_policy_set = attr_set_name
987        for attr in self.attr_sets[attr_set_name]:
988            if attr in global_set:
989                self.global_policy.append(attr)
990
991    def _load_hooks(self):
992        for op in self.ops.values():
993            for op_mode in ['do', 'dump']:
994                if op_mode not in op:
995                    continue
996                for when in ['pre', 'post']:
997                    if when not in op[op_mode]:
998                        continue
999                    name = op[op_mode][when]
1000                    if name in self.hooks[when][op_mode]['set']:
1001                        continue
1002                    self.hooks[when][op_mode]['set'].add(name)
1003                    self.hooks[when][op_mode]['list'].append(name)
1004
1005
class RenderInfo:
    """Rendering context for one op in one mode (or for a bare attribute set).

    Attributes:
        family, cw, ku_space, op, op_mode: as passed in.
        nl: the writer's netlink library helper (cw.nlib).
        type_consistent: False when a non-'do' render cannot share the 'do'
            reply type because do and dump declare different replies.
        attr_set: attribute set name (explicit or the op's 'attribute-set').
        type_name: C-ified op name, or attr set name when there is no op.
        type_name_conflict: True when that attr set name is also a family const.
        struct: {'request'/'reply': Struct} for the directions present.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # 'do' and 'dump' response parsing is identical unless the two
        # modes declare different replies (or only one declares a reply).
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_has_reply = 'reply' in op['do']
            if do_has_reply != ('reply' in op['dump']):
                self.type_consistent = False
            elif do_has_reply and op['do']['reply'] != op['dump']['reply']:
                self.type_consistent = False

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            # Named after the attr set; flag a clash with a same-named constant
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # Notifications are rendered from the op's 'do' spec
        spec_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = dict()
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[spec_mode]:
                    dir_attrs = op[spec_mode][op_dir]['attributes']
                    self.struct[op_dir] = Struct(family, self.attr_set, type_list=dir_attrs)
        if spec_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1045
1046
class CodeWriter:
    """Line-oriented writer for tab-indented, kernel-style C source.

    Beyond plain indentation it handles a few layout details:
      - a closing '}' from block_end() without a suffix is delayed, so a
        following "else ..." line is joined as "} else ...";
      - blank lines requested with nl() are written only before the next line;
      - lines ending in ':' (labels) are out-dented one level;
      - the single statement following a brace-less if/while/for line that
        ends in ')' is indented one extra level.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # closing '}' pending (may merge with "else")
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0               # current indentation, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write *line* at the current indentation (plus *add_ind* tabs)."""
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # Labels (e.g. "err_free:") sit one level left of the code.
        # endswith() instead of line[-1] so that an empty line does not raise.
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Open a '{' block, optionally preceded by *line* on the same line."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block; *line* (e.g. ';') is appended after '}'.

        Without *line* the '}' is deferred so that an "else" can join it.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Word-wrap *doc* into ' *'-prefixed comment lines of under 79 columns.

        Continuation lines get two extra spaces of indent when *indent* is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, wrapping arguments to stay within 80 columns.

        Args:
            qual_ret: qualifiers + return type; no space is added after a
                trailing '*'.
            name: function name.
            args: list of argument declarations; defaults to ['void'].
            doc: optional one-line comment emitted above the prototype.
            suffix: appended after ')' (e.g. ';' for a declaration).
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Long return types go on their own line (kernel style)
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        # Continuation lines align with the opening '(' using tabs then spaces
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        *local_vars* may be a single string or a list of strings. A list is
        not modified; a sorted copy is used.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function: prototype, '{', locals, *body* lines, '}'.

        Local variables are declared inside the braces. Previously they were
        printed before the '{', which is not valid C for these prototypes.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit '#define NAME<tabs>VALUE' lines with values aligned on a tab stop.

        *defines* is a sequence of (name, value) pairs: int values are printed
        bare, str values quoted, and other values are left out.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit designated-initializer lines '.name<tabs>= value,' aligned on a tab stop.

        An empty *members* emits nothing; previously max() raised ValueError.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1204
1205
# Integer attribute types the generator treats as plain scalars
scalars = set('u8 u16 u32 u64 s32 s64'.split())

# Struct-name suffix per message direction ('' means no direction, e.g. nested types)
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Suffix of the wrapper type rendered for each op mode
op_mode_to_wrapper = dict(
    do='',
    dump='_list',
    notify='_ntf',
    event='',
)

# C keywords; generated member names are checked against this set (see Type.__init__)
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1257
1258
def rdir(direction):
    """Return the opposite message direction: 'reply' <-> 'request'.

    Any other value (e.g. '') is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1265
1266
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix "<family>_<type_name><suffix>" for *ri*.

    Suffix selection:
      - 'do' or no mode: the direction suffix (_req / _rsp / '');
      - other modes, request: _req_dump;
      - other modes, reply, do/dump replies differ: _rsp_dump with *deref*,
        else _rsp_list;
      - other modes, reply, consistent: the direction suffix with *deref*,
        else the mode's wrapper suffix (e.g. _list, _ntf).
    """
    base = f"{ri.family['name']}_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        return base + direction_to_suffix[direction]
    if direction == 'request':
        return base + '_req_dump'
    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')
    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[ri.op_mode]
1286
1287
def type_name(ri, direction, deref=False):
    """Return the C struct type "struct <op_prefix>" for *direction* of *ri*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1290
1291
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space call for ri.op in ri.op_mode.

    Dump calls get a "_dump" name suffix. The request struct pointer is an
    argument only if the mode has a 'request'. The return type is a pointer
    to the reply struct when the mode has a 'reply', otherwise int.
    *terminate* appends ';' (declaration rather than definition).
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1308
1309
def print_req_prototype(ri):
    """Emit the ';'-terminated request-call prototype, documented with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1312
1313
def print_dump_prototype(ri):
    """Emit the ';'-terminated dump-call prototype (no doc comment)."""
    return print_prototype(ri, "request")
1316
1317
def put_typol(cw, struct):
    """Emit <struct>_policy[] (one entry per member) and the <struct>_nest wrapper.

    The policy table is sized by the attribute set's max_name + 1. Each member
    renders its own entry via attr_typol().
    """
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {max_attr},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1333
1334
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit "<render_name>_str()", which looks up a name in the *map_name* table.

    The argument is typed "enum <render_name>" when an *enum* is given and it
    does not set an empty 'enum-name'; otherwise it is a plain int. For
    'flags' enums the value is first converted with ffs() - 1 (lowest set bit
    to zero-based index). Out-of-range values return NULL.
    """
    typed = enum and ('enum-name' not in enum or enum['enum-name'])
    arg_type = f'enum {render_name}' if typed else 'int'

    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1348
1349
def put_op_name_fwd(family, cw):
    """Emit the forward declaration of "<family>_op_str(int op)"."""
    fn_name = family.name + '_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1352
1353
def put_op_name(family, cw):
    """Emit <family>_op_strmap[] and the <family>_op_str() lookup function.

    The table is indexed by each message's reply command. The enum name is
    used when the request and reply values match, otherwise the raw
    rsp_value. Messages without a (truthy) rsp_value are skipped.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1367
1368
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the forward declaration of "<enum>_str()"; *family* is unused.

    The argument is a plain int when the enum sets an empty 'enum-name',
    otherwise "enum <render_name>".
    """
    untyped = 'enum-name' in enum and not enum['enum-name']
    arg = 'int value' if untyped else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1374
1375
def put_enum_to_str(family, cw, enum):
    """Emit <enum>_strmap[] (value -> entry name) and the <enum>_str() lookup.

    *family* is unused.
    """
    map_name = enum.render_name + '_strmap'
    table_rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for row in table_rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1385
1386
def put_req_nested(ri, struct):
    """Emit "<struct>_put()", which writes *obj* as a nest of type attr_type.

    The generated C opens the nest with mnl_attr_nest_start(), lets each
    member emit its own put code, closes the nest and returns 0.
    """
    cw = ri.cw
    proto_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    cw.write_func_prot('int', f'{struct.render_name}_put', proto_args)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1407
1408
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a netlink attribute parser for *struct*.

    Shared by nested-attribute parsers and top-level reply parsers. The caller
    has already printed the prototype; this emits the opening brace, locals,
    a loop over all attributes (each member renders its own handling via
    attr_get()), and then, for array-nest and multi-attr members, a second
    pass that allocates arrays of n_<member> elements and fills them.

    Args:
        ri: RenderInfo of the op being rendered (ri.cw is the writer).
        struct: Struct whose members are parsed.
        init_lines: C statements printed right after the locals; extended in
            place when a ynl_parse_arg is needed.
        local_vars: C local declarations; extended in place.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters, sized during the first attribute pass
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Refuse to parse into an object whose arrays are already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Allocate n_<member> elements; the bare member name is not a C variable
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            # Signed types are read with the unsigned getter of the same width
            t = aspec['type']
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1515
1516
def parse_rsp_nested(ri, struct):
    """Emit "<struct>_parse()" for a nested attribute set.

    Each inherited member becomes an extra __u32 argument. The body is
    produced by _multi_parse().
    """
    func_args = ['struct ynl_parse_arg *yarg',
                 'const struct nlattr *nested']
    func_args += [f'__u32 {inherited}' for inherited in struct.inherited]

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)
    _multi_parse(ri, struct, [], local_vars)
1530
1531
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply parser "<op_prefix(reply)>_parse()" for ri.

    Nothing is emitted when the mode has no 'reply' (events always get one).
    A reply struct without members gets a trivial body returning MNL_CB_OK.
    """
    if ri.op_mode != 'event' and 'reply' not in ri.op[ri.op_mode]:
        return

    dst_type = type_name(ri, "reply", deref=deref)
    local_vars = [f'{dst_type} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']

    parser_name = op_prefix(ri, "reply", deref=deref) + '_parse'
    ri.cw.write_func_prot('int', parser_name, ['const struct nlmsghdr *nlh', 'void *data'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to parse
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
1554
1555
def print_req(ri):
    """Emit the user-space request function for ri.op in ri.op_mode.

    The generated C builds the message with ynl_gemsg_start_req(), puts the
    request attributes and runs ynl_exec(). When the mode has a reply, it
    allocates the reply struct, parses into it and returns it (freeing it
    and returning NULL on error). Otherwise it returns 0, or -1 on error.
    NOTE(review): the generated calloc() result is not NULL-checked.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']
    ret_ok = 'rsp' if has_reply else '0'
    ret_err = 'NULL' if has_reply else '-1'

    cw = ri.cw
    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    parse_arg = "NULL"
    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
        parse_arg = '&yrs'
    cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p(f"return {ret_err};")

    cw.block_end()
1611
1612
def print_dump(ri):
    """Emit the user-space dump function for ri.op.

    The generated C fills a ynl_dump_state (allocation size, parse callback,
    expected reply cmd, reply policy), starts a dump message, puts the request
    attributes if the dump takes any, runs ynl_exec_dump() and returns the
    reply list head. On error the list is freed and NULL is returned.
    """
    direction = "request"
    cw = ri.cw
    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(decl)
    cw.nl()

    reply_type = type_name(ri, rdir(direction))
    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({reply_type});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1654
1655
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the matching "<prefix>_free()"."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1658
1659
def free_arg_name(direction):
    """Name of a _free() helper's pointer argument.

    Returns the direction suffix without its underscore (e.g. 'req' or
    'rsp'), or 'obj' when there is no direction (nested types).
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1664
1665
def print_alloc_wrapper(ri, direction):
    """Emit "static inline struct <prefix> *<prefix>_alloc(void)", which calloc()s one struct."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    cw = ri.cw
    cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof({struct_type}));')
    cw.block_end()
1672
1673
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype "void <prefix>_free(struct <type> *<arg>)".

    When ri.type_name_conflict is set, the struct type gets a trailing '_'
    while the function name keeps the plain prefix.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1681
1682
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct* in the given *direction*.

    The name is "<family>_<type_name><dir suffix>", plus '_' for a
    conflicting directionless type and '_dump' in dump mode. Presence
    members ('len' and 'bit' kinds) are grouped into an inner
    "struct { ... } _present;". Inherited members follow as __u32, then
    the regular members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    presence_lines = []
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                presence_lines.append(line)
    if presence_lines:
        cw.block_start(line="struct")
        for line in presence_lines:
            cw.p(line)
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1714
1715
def print_type(ri, direction):
    """Emit the struct definition for ri.struct[direction]."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1718
1719
def print_type_full(ri, struct):
    """Emit the struct definition for a directionless (nested) *struct*."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
1722
1723
def print_type_helpers(ri, direction, deref=False):
    """Emit the _free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1732
1733
def print_req_type_helpers(ri):
    """Emit the request struct's _alloc() wrapper, then its free prototype and setters."""
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
1737
1738
def print_rsp_type_helpers(ri):
    """Emit reply type helpers, but only when the current mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1743
1744
def print_parse_prototype(ri, direction, terminate=True):
    """Emit "void <op>_{rsp,req}_parse(const struct nlattr **tb, struct ... *req)".

    'reply' selects the _rsp names; any other direction selects _req.
    *terminate* appends ';'.
    """
    struct_name = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    ri.cw.write_func_prot('void', f"{struct_name}_parse",
                          ['const struct nlattr **tb', f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
1753
1754
def print_req_type(ri):
    """Emit the request struct definition for ri."""
    return print_type(ri, "request")
1757
1758
def print_req_free(ri):
    """Emit the request struct's _free() function if the mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1763
1764
def print_rsp_type(ri):
    """Emit the reply struct definition when there is one.

    Events always get one. 'do'/'dump' only when the mode declares a
    'reply'. Other modes get nothing.
    """
    mode = ri.op_mode
    wants_reply = mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode])
    if wants_reply:
        print_type(ri, 'reply')
1773
1774
def print_wrapped_type(ri):
    """Emit the reply struct wrapper plus its free() prototype.

    Dump wrappers gain a 'next' pointer of their own type; notify/event
    wrappers gain family, cmd, a generic 'next' link and a free callback.
    In every case the reply itself is embedded as 'obj', aligned to 8.
    """
    cw = ri.cw
    wrapped = type_name(ri, 'reply')
    cw.block_start(line=f"{wrapped}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapped} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;', '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;'):
            cw.p(member)
        cw.p(f"void (*free)({wrapped} *ntf);")
    inner = type_name(ri, 'reply', deref=True)
    cw.p(f"{inner} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1789
1790
1791def _free_type_members_iter(ri, struct):
1792    for _, attr in struct.member_list():
1793        if attr.free_needs_iter():
1794            ri.cw.p('unsigned int i;')
1795            ri.cw.nl()
1796            break
1797
1798
1799def _free_type_members(ri, var, struct, ref=''):
1800    for _, attr in struct.member_list():
1801        attr.free(ri, var, ref)
1802
1803
def _free_type(ri, direction, struct):
    """Emit a complete free function for *struct*.

    Members are always released.  The object itself is only passed to
    free() when *direction* is non-empty; free_rsp_nested() uses "" to get
    a members-only helper.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1815
1816
def free_rsp_nested(ri, struct):
    """Emit the members-only free helper for a nested struct."""
    _free_type(ri, direction="", struct=struct)
1819
1820
def print_rsp_free(ri):
    """Emit the reply free function, if the op mode has a 'reply'."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
1825
1826
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated C walks the entries until YNL_LIST_END.  For each entry
    it releases the members (through the 'obj.' prefix) and then the entry.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    entry_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{entry_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()  # while loop
    cw.block_end()  # function body
    cw.nl()
1845
1846
def print_ntf_type_free(ri):
    """Emit the free function for a notification reply wrapper.

    The reply members are released through the 'obj.' prefix, followed by
    free(rsp).
    """
    cw = ri.cw
    reply = ri.struct['reply']
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1855
1856
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the nla_policy array declaration or definition opener.

    terminate=True writes an 'extern ...;' forward declaration.  It is
    skipped entirely for per-op policies (ri given) that are static.
    terminate=False writes the '... = {' opener, prefixed with 'static '
    for per-op policies that are static.

    The array is named after the op (plus the op mode for dual-policy ops)
    when *ri* is given, otherwise after the struct.
    """
    static_policy = bool(ri and policy_should_be_static(struct.family))

    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        policy_name = struct.render_name
    elif ri.op.dual_policy:
        policy_name = ri.op.render_name + '_' + ri.op_mode
    else:
        policy_name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {policy_name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1879
1880
def print_req_policy(cw, struct, ri=None):
    """Write the complete nla_policy array for *struct*.

    print_req_policy_fwd() writes the array opener.  Each member then adds
    its own policy entry before the array is closed.
    """
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.nl()
1887
1888
def kernel_can_gen_family_struct(family):
    """Return True when the family's protocol is 'genetlink'.

    Callers use this to decide whether the struct genl_family is generated.
    """
    proto = family.proto
    return proto == 'genetlink'
1891
1892
def policy_should_be_static(family):
    """Return whether request policies are emitted as static.

    That is the case for 'split' kernel policies and for families whose
    struct genl_family is generated by this tool.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
1895
1896
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table opener (source) or its declarations (header).

    terminate=False: writes '<qual> struct <type> <family>_nl_ops[] =' and
    opens its block; the entries and the closing brace are written by
    print_kernel_op_table().

    terminate=True: writes an 'extern' declaration of the ops array when it
    is exported, then prototypes for the pre/post do/dump hooks and for the
    doit/dumpit handler of every non-async operation.
    """
    # The ops array is exported (non-static and declared extern in the
    # header) only when this generator does not emit the struct genl_family.
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        # Kernel ops struct type used for each kernel policy model.
        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Array length: left empty for static arrays; exported arrays get an
        # explicit size.  Split ops have one entry per do and per dump.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # NOTE(review): counts every op in family.ops, whereas
            # print_kernel_op_table() skips async ops; verify family.ops
            # never contains async ops.
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    # Source-file mode stops here; the prototypes below are header-only.
    if not terminate:
        return

    cw.nl()
    # Prototypes of the user-provided pre/post hooks for do and dump.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes of the doit/dumpit handlers referenced from the ops table.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1962
1963
def print_kernel_op_table_hdr(family, cw):
    """Write the header-side declarations that belong to the ops table."""
    print_kernel_op_table_fwd(family=family, cw=cw, terminate=True)
1966
1967
def print_kernel_op_table(family, cw):
    """Emit the kernel ops array definition for *family*.

    print_kernel_op_table_fwd() opens the array.  For the 'global' and
    'per-op' kernel policies one initializer is written per non-async
    operation; for 'split' one initializer is written per do/dump handler.
    The array block is then closed.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op has a 'do' with a 'request';
                # an op lacking either would raise KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the pre/post callback fields for each handler type.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags that apply to this handler type:
                    # drop dump-only flags for 'do' and 'strict' for 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do/dump policy arrays.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is tagged GENL_CMD_CAP_DO or _DUMP.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2043
2044
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group ids; nothing if there are no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2055
2056
def print_kernel_mcgrp_src(family, cw):
    """Define the genl_multicast_group array; nothing if there are no groups.

    Entries are designated initializers keyed by the NLGRP id constants.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        grp_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2068
2069
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated struct genl_family when one is generated."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2076
2077
def print_kernel_family_struct_src(family, cw):
    """Define the struct genl_family when one is generated.

    Ops are referenced as .ops/.n_ops for the 'per-op' kernel policy and as
    .split_ops/.n_split_ops for 'split'; other policies emit neither.
    Multicast groups are referenced only when the family defines some.
    """
    if not kernel_can_gen_family_struct(family):
        return

    ops = f"{family.name}_nl_ops"
    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
    for line in ('.name\t\t= ' + family.fam_key + ',',
                 '.version\t= ' + family.ver_key + ',',
                 '.netnsok\t= true,',
                 '.parallel_ops\t= true,',
                 '.module\t\t= THIS_MODULE,'):
        cw.p(line)
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {ops},')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({ops}),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {ops},')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({ops}),')
    if family.mcgrps['list']:
        grps = f"{family.name}_nl_mcgrps"
        cw.p(f'.mcgrps\t\t= {grps},')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({grps}),')
    cw.block_end(';')
2098
2099
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, named when the spec supplies a name.

    If *obj* has the *enum_name* key, its value names the enum; an empty
    value makes the enum anonymous.  Otherwise, the *ckey* value is used,
    prefixed with the family name.  With neither, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = 'enum ' + c_lower(explicit) if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = 'enum ' + family.name + '_' + c_lower(obj[ckey])
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2108
2109
def render_uapi(family, cw):
    """Render the uAPI header for *family* through *cw*.

    The header contains, in order:
      - an include guard;
      - the family name and version defines;
      - the spec 'definitions': enums/flags with kdoc, and constants as
        #defines;
      - one enum per attribute set that is not a subset of another;
      - the command enum, with notifications moved to a separate enum when
        the spec sets 'async-prefix';
      - the multicast group name defines.
    """
    # c_upper() also maps '-' to '_', keeping the guard a valid identifier
    # for family names that contain dashes.
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        # Flush constants collected so far before any non-const entry, so the
        # defines keep their position relative to the enums.
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' overrides the name of this particular constant,
            # so read it from the const entry itself (as multicast groups do
            # below), not from the family.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2254
2255
def _render_user_ntf_entry(ri, op):
    """Emit one '[<op enum>] = { ... },' entry of the ynl_ntf_info table.

    The entry carries the allocation size, the parse callback, the nested
    policy and the free callback for the notification.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    entry_fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in entry_fields:
        cw.p(field)
    cw.block_end(line=',')
2263
2264
def render_user_family(family, cw, prototype):
    """Emit the user-space ynl_family descriptor.

    With *prototype* set, only an extern declaration is written.  Otherwise
    a static ynl_ntf_info table is emitted first when the family has
    notifications, followed by the family definition, which references
    that table.

    Raises:
        Exception: for a notification entry with neither 'notify' nor
        'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    has_ntfs = bool(family.ntfs)
    if has_ntfs:
        table = f"{family['name']}_ntf_info"
        cw.block_start(line=f"static const struct ynl_ntf_info {table}[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        for event_op in family.ops.values():
            if 'event' in event_op:
                ri = RenderInfo(cw, family, "user", event_op, "event")
                _render_user_ntf_entry(ri, event_op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if has_ntfs:
        cw.p(f".ntf_info\t= {table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({table}),")
    cw.block_end(line=';')
2296
2297
def find_kernel_root(full_path):
    """Find the kernel source root that contains *full_path*.

    Walks up the parent directories of *full_path* until one contains a
    MAINTAINERS file.

    Returns:
        (root, sub_path): the directory holding MAINTAINERS and the path of
        *full_path* relative to it.

    Raises:
        FileNotFoundError: if the walk runs out of parent directories
        without finding MAINTAINERS.  For absolute paths this happens at
        '/'; for relative paths it happens at their leading component.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        maintainers = os.path.join(parent, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path carries a trailing separator from the first join.
            return parent, sub_path[:-1]
        # dirname() is a fixed point at '/' and at ''.  Stop here instead
        # of spinning forever on the same path.
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {start!r}; "
                "is the spec inside a kernel source tree?")
        full_path = parent
2306
2307
def _write_out_file(tmp_file, out_path):
    """Copy the output staged in *tmp_file* to *out_path*, if one was given."""
    if not out_path:
        return
    tmp_file.seek(0)
    with open(out_path, 'w+') as out_file:
        shutil.copyfileobj(tmp_file, out_file)


def main():
    """Command line entry point: generate C code from a YNL spec.

    --mode selects the output kind ('uapi', 'kernel' or 'user') and
    --header/--source selects which file is generated.  Output is written
    to stdout.  With -o, it is staged in a temporary file and copied to the
    given path only after generation has finished.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    # With -o the output path is only opened once generation completes, so
    # an error part-way through does not leave a truncated file behind.
    tmp_file = tempfile.TemporaryFile('w+') if args.out_file else os.sys.stdout

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Diagnostics go to stderr: without -o, stdout carries the generated code.
    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license, file=os.sys.stderr)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)',
                  file=os.sys.stderr)
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc, file=os.sys.stderr)
        os.sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation',
              file=os.sys.stderr)
        os.sys.exit(1)

    cw = CodeWriter(BaseNlLib(), tmp_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        # The early return must still publish the staged output to -o.
        _write_out_file(tmp_file, args.out_file)
        return

    # c_upper() maps '-' to '_' so dashed family names give a valid guard.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')

    _write_out_file(tmp_file, args.out_file)
2597
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
2600