xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision e855afd715656a9f25cf62fa68d99c33213b83b7)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import shutil
9import tempfile
10import yaml
11
12from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
13
14
def c_upper(name):
    """Turn a spec name such as ``foo-bar`` into a C macro-style ``FOO_BAR``."""
    return name.replace('-', '_').upper()
17
18
def c_lower(name):
    """Turn a spec name such as ``Foo-Bar`` into a C identifier ``foo_bar``."""
    return name.replace('-', '_').lower()
21
22
class BaseNlLib:
    """Renders the libmnl-specific C snippets used by the generated code."""

    def get_family_id(self):
        """Return the C expression that holds the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Render a ``mnl_cb_run2()`` call that feeds ``ys->rx_buf`` to ``cb``.

        Continuation lines are wrapped with two tabs, ``indent`` extra tabs
        and a single space. Dumps pass 0 for both the sequence number and
        the port id; other requests pass ``ys->seq`` / ``ys->portid``.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        if not is_dump:
            return (f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,"
                    f"{wrap}{cb}, {data},{wrap}ynl_cb_array, NLMSG_MIN_TYPE)")
        return (f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},"
                f"{wrap}ynl_cb_array, NLMSG_MIN_TYPE)")
34
35
class Type(SpecAttr):
    """Base class for rendering one spec attribute into C code.

    Subclasses override the small hooks (``presence_type``,
    ``_complex_member_type``, ``_attr_policy``, ``_attr_typol``,
    ``_attr_get``, ``_setter_lines``, ...). The public methods here use those
    hooks to emit struct members, kernel policy entries, YNL type-policy
    entries, put/parse code and setter helpers through ``ri.cw`` / ``cw``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nested set named after the family gets no extra suffix.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # Sets whose name is also present in family.consts get a trailing
            # '_' on the struct name (the same rule Struct applies to its name).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # Suffix C keywords so the name is usable as a struct member / argument.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``, the C enumerator naming this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return truthy if the value is stored as an array plus a count."""
        return None

    def is_scalar(self):
        """Return True for the fixed-width integer spec types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked: 'bit', 'len', 'count' or '' (none)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Render this attribute's member of the ``_present`` struct.

        Returns None unless ``presence_type()`` equals ``type_filter``. Only
        'bit' (a 1-bit field) and 'len' (a length field) produce a member;
        ``space == 'user'`` selects the ``__u32`` spelling.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type for non-trivial members (nests, arrays); None otherwise."""
        return None

    def free_needs_iter(self):
        """Whether ``free()`` emits a ``for (i = ...)`` loop needing ``i`` in scope."""
        return False

    def free(self, ri, var, ref):
        """Emit code releasing heap memory owned by ``var->ref<c_name>``."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) the setter takes.

        Raises:
            Exception: if the subclass provides no complex member type and
                does not override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct field(s) holding this attribute's value."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Build the policy initializer for ``policy`` (an ``NLA_*`` name)."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel ``nla_policy`` array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the YNL type-policy fields; subclasses must override."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the YNL type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line``, guarded by the attribute's presence check if any."""
        presence = self.presence_type()
        if presence == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif presence == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` of the member value."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code serializing this attribute; subclasses must override."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing; override."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if`` / ``else if`` branch that parses this attribute.

        ``first`` selects ``if`` over ``else if``. ``_attr_get()`` supplies
        the body lines, optional init lines and optional local declarations;
        a bare string is accepted wherever a list of lines is expected.
        Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes are only counted here, so they skip the
        # per-attribute validation and presence bit.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the setter body lines; subclasses must override."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute on ``req``.

        ``ref`` is the path of enclosing member names (used for nested
        structs). When presence is bit-tracked, every level's presence bit
        along that path is set.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        # Setters that free the old value but never allocate a copy get a
        # "__" prefix. NOTE(review): presumably because they take ownership
        # of the caller's pointer — confirm.
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
212
213
class TypeUnused(Type):
    """Attribute id marked 'unused' in the spec.

    It adds no struct or setter members and no kernel policy entry. Its
    YNL type-policy is YNL_PT_REJECT, and the generated parser branch
    returns MNL_CB_ERROR when it sees the attribute.
    """

    def attr_policy(self, cw):
        # Deliberately emits nothing into the kernel policy table.
        pass

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        # Result layout: (lines, init_lines, local_vars).
        reject = ['return MNL_CB_ERROR;']
        return reject, None, None

    def arg_member(self, ri):
        return []

    def presence_type(self):
        # No presence tracking at all.
        return ''
229
230
class TypePad(Type):
    """Padding attribute: no data, no members, no policy, no put/parse code."""

    def _attr_typol(self):
        # The only output for pads: an IGNORE entry in the YNL type-policy.
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    # The hooks below intentionally emit nothing for pad attributes.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
252
253
class TypeScalar(Type):
    """Fixed-width integer attribute (u8/u16/u32/u64/s32/s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Byte order is only surfaced as a C comment next to the argument.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a bitfield and pick its C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # A non-flags enum that has a C enum name is typed as that enum;
        # everything else uses the raw __<type> kernel type.
        maybe_enum = not self.is_bitfield and 'enum' in self.attr
        if maybe_enum and self.family.consts[self.attr['enum']].enum_name:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Type suffix for libmnl put/get helpers."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Build the kernel policy, adding mask / min / range validation."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                flags = self.family.consts[self.attr['enum']]
            else:
                flags = self.family.consts[self.checks['flags-mask']]
            # Derive the mask from the flags set itself, as the bitfield path
            # always did. The old (1 << entry_count) - 1 form assumed the
            # flags occupy bits 0..n-1, which breaks for explicit or offset
            # values. NOTE(review): relies on get_mask(as_flags=True) setting
            # 1 << value for each entry (lib nlspec) — confirm.
            mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned YNL parser type of the same width.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
325
326
class TypeFlag(Type):
    """Presence-only flag attribute (carries no payload)."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # No value argument: the setter only sets the presence bit.
        return []

    def _setter_lines(self, ri, member, presence):
        return []

    def attr_put(self, ri, var):
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy; the generic parser branch records presence.
        return [], None, None
342
343
class TypeString(Type):
    """String attribute, heap-copied and NUL-terminated in the parsed struct."""

    def presence_type(self):
        # Presence is tracked through the stored string length.
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        # Unterminated strings are accepted only when the spec opts in.
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        # The length is bounded by the payload size via strnlen().
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
391
392
class TypeBinary(Type):
    """Opaque binary blob; the stored length doubles as its presence marker."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # Supported: no checks, or a single 'min-len' check, which renders as
        # a bare .len without an explicit .type.
        n_checks = len(self.checks)
        if n_checks == 0:
            return '{ .type = NLA_BINARY, }'
        if n_checks == 1 and 'min-len' in self.checks:
            return '{ .len = ' + str(self.checks['min-len']) + ', }'
        raise Exception('One or more of binary type checks not implemented, yet')

    def attr_put(self, ri, var):
        size = f"{var}->_present.{self.c_name}_len"
        put = f"mnl_attr_put(nlh, {self.enum_name}, {size}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, mnl_attr_get_payload(attr), len);"]
        return get_lines, ['len = mnl_attr_get_payload_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
433
434
class TypeNest(Type):
    """Attribute wrapping a nested attribute set, stored as an embedded struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        # The nested set's own policy table replaces the generic NLA_* type.
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def free(self, ri, var, ref):
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def attr_put(self, ri, var):
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # One setter per member of the nested struct, with this attribute's
        # C name appended to the member path.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
464
465
class TypeMultiAttr(Type):
    """Attribute that may repeat; occurrences are collected into an array.

    ``base_type`` is the Type built for a single occurrence. It supplies the
    kernel policy and the YNL type-policy entries.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest_type(self):
        # Nest occurrences (also the fallback when no 'type' is given).
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        # Tracked by an element count (n_<name>) rather than a presence bit.
        return 'count'

    def _mnl_type(self):
        # mnl does not have a helper for signed types
        kind = self.type
        return 'u' + kind[1:] if kind[0] == 's' else kind

    def _complex_member_type(self, ri):
        if self._is_nest_type():
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest_type()

    def free(self, ri, var, ref):
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            ri.cw.p(f"free({array});")
        elif self._is_nest_type():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
            ri.cw.p(f"free({array});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(loop)
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif self._is_nest_type():
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence: rewrite the trailing
        # "_present.<c_name>" of the presence path into "n_<c_name>".
        cut = len('_present.') + len(self.c_name)
        count = presence[:-cut] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
534
535
class TypeArrayNest(Type):
    """Nest whose children are array entries, tracked by an element count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # A missing 'sub-type' means the entries are nests.
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count its children.
        count_line = f'\t{var}->n_{self.c_name}++;'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     count_line]
        return get_lines, None, ['const struct nlattr *attr2;']
561
562
class TypeNestTypeValue(Type):
    """Nest parsed through a chain of 'type-value' levels.

    At each listed level, the parser descends into the payload and reads the
    attribute *type* number. The numbers are passed as extra arguments to the
    nested set's parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = 'attr_' + level
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
591
592
class Struct:
    """C struct rendered for an attribute set, or for a subset of it.

    ``type_list`` names the attributes to include. None means "the whole
    set" and also marks the struct as nested; an explicit list (even an
    empty one) includes exactly those attributes. ``inherited`` is kept
    aside and later checked/normalized by ``set_inherited()``.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = inherited if inherited is not None else []
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{c_lower(space_name)}"
        self.struct_name = 'struct ' + self.render_name
        # Nested structs whose set name is also in family.consts get a '_' suffix.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # Compare with None, like self.nested above. The old truthiness test
        # made an explicit empty list pull in every attribute of the set.
        names = type_list if type_list is not None else self.attr_set
        self.attr_list = [(t, self.attr_set[t]) for t in names]

        # Track the member with the highest attribute value; on ties the
        # later one wins. Stays None for an empty struct.
        self.attrs = dict()
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        """Iterate over the member attribute names."""
        yield from self.attrs

    def __getitem__(self, key):
        """Return the attribute object for member ``key``."""
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Check ``new_inherited`` matches the constructor's list and store the
        sorted, C-lowered names in ``self.inherited``.

        Raises:
            Exception: if the two lists differ.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
645
646
class EnumEntry(SpecEnumEntry):
    """Enum entry carrying its C enumerator name and an explicit-value hint."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value C would assign implicitly: previous + 1, or 0 for the first.
        implicit = prev.value + 1 if prev else 0
        # value_change: the value must be spelled out (always for flags).
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
665
666
class EnumSet(SpecEnumSet):
    """Enum / flags definition with its C type name and enumerator prefix."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # 'enum-name' absent -> default name; present but empty -> anonymous.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        # Naming is set up before the base constructor runs; NOTE(review):
        # presumably because entries created there read value_pfx — confirm.
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)`` of the entry values.

        Raises:
            Exception: if the values do not form a contiguous range.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)
        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")
        return low, high
694
695
class AttrSet(SpecAttrSet):
    """Attribute set with resolved C enumerator prefix and max name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enumerator naming of the set it subsets.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        # Avoid C keywords.
        if c_name in _C_KW:
            c_name += '_'
        # The set named after the family gets no extra name component.
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Build the Type subclass matching ``elem['type']``.

        Attributes flagged 'multi-attr' are wrapped in TypeMultiAttr, which
        keeps the per-occurrence object as its base type.
        """
        kind = elem['type']
        if kind in scalars:
            type_cls = TypeScalar
        else:
            type_cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if type_cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = type_cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
750
751
class Operation(SpecOperation):
    """Operation with the C names the generator needs."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both 'do' and 'dump' define their own request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Flipped by mark_has_ntf() when a message names this op as 'notify'.
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async ops use the family's async prefix; others the command prefix.
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some message uses this operation as its notify target."""
        self.has_ntf = True
777
778
779class Family(SpecFamily):
780    def __init__(self, file_name, exclude_ops):
781        # Added by resolve:
782        self.c_name = None
783        delattr(self, "c_name")
784        self.op_prefix = None
785        delattr(self, "op_prefix")
786        self.async_op_prefix = None
787        delattr(self, "async_op_prefix")
788        self.mcgrps = None
789        delattr(self, "mcgrps")
790        self.consts = None
791        delattr(self, "consts")
792        self.hooks = None
793        delattr(self, "hooks")
794
795        super().__init__(file_name, exclude_ops=exclude_ops)
796
797        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
798        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
799
800        if 'definitions' not in self.yaml:
801            self.yaml['definitions'] = []
802
803        if 'uapi-header' in self.yaml:
804            self.uapi_header = self.yaml['uapi-header']
805        else:
806            self.uapi_header = f"linux/{self.name}.h"
807
808    def resolve(self):
809        self.resolve_up(super())
810
811        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
812            raise Exception("Codegen only supported for genetlink")
813
814        self.c_name = c_lower(self.name)
815        if 'name-prefix' in self.yaml['operations']:
816            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
817        else:
818            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
819        if 'async-prefix' in self.yaml['operations']:
820            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
821        else:
822            self.async_op_prefix = self.op_prefix
823
824        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
825
826        self.hooks = dict()
827        for when in ['pre', 'post']:
828            self.hooks[when] = dict()
829            for op_mode in ['do', 'dump']:
830                self.hooks[when][op_mode] = dict()
831                self.hooks[when][op_mode]['set'] = set()
832                self.hooks[when][op_mode]['list'] = []
833
834        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
835        self.root_sets = dict()
836        # dict space-name -> set('request', 'reply')
837        self.pure_nested_structs = dict()
838
839        self._mark_notify()
840        self._mock_up_events()
841
842        self._load_root_sets()
843        self._load_nested_sets()
844        self._load_hooks()
845
846        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
847        if self.kernel_policy == 'global':
848            self._load_global_policy()
849
850    def new_enum(self, elem):
851        return EnumSet(self, elem)
852
853    def new_attr_set(self, elem):
854        return AttrSet(self, elem)
855
856    def new_operation(self, elem, req_value, rsp_value):
857        return Operation(self, elem, req_value, rsp_value)
858
859    def _mark_notify(self):
860        for op in self.msgs.values():
861            if 'notify' in op:
862                self.ops[op['notify']].mark_has_ntf()
863
864    # Fake a 'do' equivalent of all events, so that we can render their response parsing
865    def _mock_up_events(self):
866        for op in self.yaml['operations']['list']:
867            if 'event' in op:
868                op['do'] = {
869                    'reply': {
870                        'attributes': op['event']['attributes']
871                    }
872                }
873
874    def _load_root_sets(self):
875        for op_name, op in self.msgs.items():
876            if 'attribute-set' not in op:
877                continue
878
879            req_attrs = set()
880            rsp_attrs = set()
881            for op_mode in ['do', 'dump']:
882                if op_mode in op and 'request' in op[op_mode]:
883                    req_attrs.update(set(op[op_mode]['request']['attributes']))
884                if op_mode in op and 'reply' in op[op_mode]:
885                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
886            if 'event' in op:
887                rsp_attrs.update(set(op['event']['attributes']))
888
889            if op['attribute-set'] not in self.root_sets:
890                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
891            else:
892                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
893                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
894
    def _load_nested_sets(self):
        """Create Structs for attribute sets that are only reachable by nesting.

        Starting from the root sets (see _load_root_sets), walks breadth-first
        over 'nested-attributes' references. For every referenced set that is
        not a root set, it creates a Struct in self.pure_nested_structs. Then:
          * marks each nested struct used directly by a root set's request
            and/or reply members,
          * reorders self.pure_nested_structs so that, as far as possible, a
            struct comes after the nested structs it references,
          * propagates the request/reply flags from parents to children.

        Raises:
            Exception: if an attribute set is used both as a root and nested.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        # Breadth-first walk over nested attribute sets
        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Members the nested struct receives from its parent rather than
                # parsing itself: the 'type-value' list, or the array index for
                # array-nest attributes.
                if 'type-value' in spec:
                    # Cannot trigger in practice: nested root sets were already
                    # rejected by the raise above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Mark nested structs referenced directly by request / reply members
        # of the root sets.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        # A struct whose nested references have not all been placed yet is
        # moved to the end of the dict and retried. The number of rounds is
        # bounded, so a reference cycle ends the loop rather than spinning
        # forever.
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # Walk in reverse (dependents before dependencies) so parent flags are
        # final before being copied into their children.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
963
964    def _load_global_policy(self):
965        global_set = set()
966        attr_set_name = None
967        for op_name, op in self.ops.items():
968            if not op:
969                continue
970            if 'attribute-set' not in op:
971                continue
972
973            if attr_set_name is None:
974                attr_set_name = op['attribute-set']
975            if attr_set_name != op['attribute-set']:
976                raise Exception('For a global policy all ops must use the same set')
977
978            for op_mode in ['do', 'dump']:
979                if op_mode in op:
980                    global_set.update(op[op_mode].get('request', []))
981
982        self.global_policy = []
983        self.global_policy_set = attr_set_name
984        for attr in self.attr_sets[attr_set_name]:
985            if attr in global_set:
986                self.global_policy.append(attr)
987
988    def _load_hooks(self):
989        for op in self.ops.values():
990            for op_mode in ['do', 'dump']:
991                if op_mode not in op:
992                    continue
993                for when in ['pre', 'post']:
994                    if when not in op[op_mode]:
995                        continue
996                    name = op[op_mode][when]
997                    if name in self.hooks[when][op_mode]['set']:
998                        continue
999                    self.hooks[when][op_mode]['set'].add(name)
1000                    self.hooks[when][op_mode]['list'].append(name)
1001
1002
class RenderInfo:
    """Context for emitting the C code of one op in one mode.

    Bundles the code writer, the family, the target space (@ku_space) and the
    Struct objects describing the request and/or reply of @op in @op_mode
    ('do', 'dump', 'notify', 'event'). When @op is falsy, the info instead
    describes the bare attribute set @attr_set.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # 'do' and 'dump' response parsing is identical only if both modes
        # either lack a reply or declare exactly the same reply
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_has_reply = 'reply' in op['do']
            dump_has_reply = 'reply' in op["dump"]
            if do_has_reply != dump_has_reply:
                self.type_consistent = False
            elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            # a const with the same name forces a '_' suffix on the struct
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # notifications reuse the layout of the 'do' reply
        spec_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = dict()
        for op_dir in ('request', 'reply'):
            if op and op_dir in op[spec_mode]:
                attrs = op[spec_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if spec_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1042
1043
class CodeWriter:
    """Emit C source text with tab indentation.

    Tracks the current indentation depth and delays printing a block's
    closing '}' until the next line is known, so that a following 'else' is
    joined as '} else'. A line after a brace-less if/while/for condition
    gets one extra level of indentation.

    Args:
        nlib: netlink library helper, exposed as ``self.nlib``.
        out_file: writable text stream that receives the generated code.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # a '}' is pending (may merge with 'else')
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0               # current indentation depth, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write @line at the current indentation (plus @add_ind levels).

        Flushes a pending '}' first, merging it as '} else...' when @line
        starts with 'else', then any pending blank line. Lines ending in
        ':' (labels) are outdented by one level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() instead of line[-1] so an empty line does not IndexError
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit '<line> {' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no @line the '}' is delayed so a following 'else' can join it.
        Otherwise '}' is printed immediately followed by @line (e.g. ';').
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit @doc as ' *'-prefixed comment lines wrapped before 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, wrapping arguments to fit in 80 columns.

        @suffix is appended after ')' (e.g. ';' for a declaration). @doc, if
        given, is emitted as a block comment before the prototype.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        Accepts one declaration string or a list of them. The caller's list
        is left unmodified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a whole function: prototype, '{', locals, @body lines, '}'."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        # Local declarations belong inside the braces; emitting them before
        # '{' would produce K&R-style parameter declarations (invalid here).
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit '#define NAME VALUE' lines with tab-aligned values.

        @defines is a sequence of (name, value) pairs. int values are printed
        as-is and str values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit designated initializers '.name = value,' with aligned '='.

        @members is a sequence of (name, value-string) pairs. An empty
        sequence emits nothing.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1201
1202
# Netlink attribute types rendered as plain C integers
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Struct-name suffix for each message direction ('' = bare nested struct)
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Suffix of the wrapper type used for the reply in each operation mode
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords (plus 'bool'); presumably used to detect identifiers that would
# clash with them - confirm against the users of this set
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1254
1255
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (e.g. '') is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1262
1263
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix for the types/functions of @ri.

    The result is "<family>_<type_name>" plus a suffix chosen from the op
    mode and @direction. The suffix is one of '_req'/'_rsp' (do or bare
    struct), '_req_dump' (dump request), the mode wrapper suffix such as
    '_list'/'_ntf', or '_rsp_dump'/'_rsp_list' when do and dump replies
    differ. @deref selects the inner per-object type instead of the
    wrapper.
    """
    mode = ri.op_mode
    if not mode or mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]
    return f"{ri.family['name']}_{ri.type_name}{tail}"
1283
1284
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <prefix>") for @ri and @direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1287
1288
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing do/dump call for @ri.

    The function takes the request struct when the op mode has a request.
    It returns a pointer to the reply type when the mode has a reply, and
    int otherwise. Dump calls get a '_dump' name suffix. The prototype ends
    with ';' unless @terminate is false.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1305
1306
def print_req_prototype(ri):
    """Emit the declaration of the 'do' request call for @ri.

    When the op spec has a 'doc' string, it is emitted as a comment above
    the prototype.
    """
    # 'doc' is optional in the spec; indexing it directly raised KeyError
    # for ops without one
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1309
1310
def print_dump_prototype(ri):
    """Emit the declaration of the dump call for @ri."""
    print_prototype(ri, direction="request")
1313
1314
def put_typol(cw, struct):
    """Emit the ynl_policy_attr table and ynl_policy_nest for @struct.

    The table has max_name + 1 entries, one filled by each member's
    attr_typol(). The nest object points at that table.
    """
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for line in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1330
1331
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', which looks a value up in @map_name.

    The argument is typed as 'enum <render_name>' when @enum is given and
    does not set an empty 'enum-name'; otherwise it is a plain int. For
    'flags' enums the value is first turned into a bit index with ffs().
    Values outside the map return NULL.
    """
    use_enum_type = enum and ('enum-name' not in enum or enum['enum-name'])
    if use_enum_type:
        arg_decl = f'enum {render_name} {arg_name}'
    else:
        arg_decl = f'int {arg_name}'

    cw.write_func_prot('const char *', f'{render_name}_str', [arg_decl])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1345
1346
def put_op_name_fwd(family, cw):
    """Emit the declaration of '<family>_op_str(int op)'."""
    fname = family.name + '_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1349
1350
def put_op_name(family, cw):
    """Emit the op-id -> op-name string table and '<family>_op_str()'.

    Only messages with a rsp_value get an entry. The entry is indexed by
    the op's enum name when request and response share the same value, and
    by rsp_value otherwise.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for msg_name, msg in family.msgs.items():
        if not msg.rsp_value:
            continue
        index = msg.enum_name if msg.req_value == msg.rsp_value else msg.rsp_value
        cw.p(f'[{index}] = "{msg_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1364
1365
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of '<enum>_str()'.

    The argument is typed as the C enum unless the spec sets an empty
    'enum-name', in which case it is a plain int.
    """
    if 'enum-name' in enum and not enum['enum-name']:
        arg_decl = 'int value'
    else:
        arg_decl = f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg_decl], suffix=';')
1371
1372
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for @enum and its '_str()' lookup."""
    map_name = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1382
1383
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes @struct's members as a nest.

    The generated C function opens a nest of type attr_type, lets every
    member write itself from 'obj', closes the nest and returns 0.
    """
    cw = ri.cw
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    cw.write_func_prot('int', f'{struct.render_name}_put', params)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1403    ri.cw.nl()
1404
1405
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for @struct.

    The caller must already have written the function prototype. This emits:
      * the local variables (@local_vars, extended here) and @init_lines,
      * the inherited members copied from the function arguments,
      * a guard failing if an array/multi-attr member was already filled,
      * one loop over the attributes (nested ones when @struct.nested,
        otherwise those after the genetlink header), handing each to the
        members' attr_get(),
      * for array-nest and multi-attr members, a second pass that allocates
        the destination array and parses/copies each element into it.

    Args:
        ri: RenderInfo of the code being generated.
        struct: Struct whose members are parsed into 'dst'.
        init_lines: C statements emitted after the locals; may be extended.
        local_vars: C declarations; extended in place with what is needed.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters for array members; presumably incremented by the
    # code each member's attr_get() emits in the main loop.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Allocate by the element counter; the bare member name is not a
        # C variable in the generated function.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # the third argument passes the element's attribute type as the
        # inherited 'idx' member
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # libmnl only has unsigned getters; signed values share the bits
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1512
1513
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()' for a nested attribute set.

    Inherited members are passed in as extra __u32 arguments.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    local_vars = ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], local_vars)
1527
1528
def parse_rsp_msg(ri, deref=False):
    """Emit the parse callback for a reply message of @ri.

    Nothing is emitted for non-event modes without a reply. @deref selects
    the per-object type name used for dump/notify wrappers.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    local_vars = [
        f'{reply_type} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]

    parse_fn = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parse_fn, ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], local_vars)
1544
1545
def print_req(ri):
    """Emit the body of the blocking 'do' request call for @ri.

    The generated code builds the request message, fills it from 'req' and
    runs it with ynl_exec(). When the op has a reply, it also allocates the
    response, installs the parse callback and frees the response on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars.extend([f'{type_name(ri, rdir(direction))} *rsp;',
                           'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };'])

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()
        parse_arg = '&yrs'
    else:
        parse_arg = "NULL"
    ri.cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {'rsp' if has_reply else '0'};")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p("return NULL;")

    ri.cw.block_end()
1601
1602
def print_dump(ri):
    """Emit the body of the dump call for @ri.

    The generated code configures a ynl_dump_state (object size, parse
    callback, expected reply cmd and policy), sends the dump request and
    returns the list of parsed objects, freeing it on failure.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(decl)
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);', 'if (err < 0)', 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1644
1645
def call_free(ri, direction, var):
    """Return the C statement freeing @var with the free helper for @direction."""
    prefix = op_prefix(ri, direction)
    return f"{prefix}_free({var});"
1648
1649
def free_arg_name(direction):
    """Return the argument name of a free helper: 'req'/'rsp', or 'obj'.

    'obj' is used when @direction is empty (bare nested structs).
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1654
1655
def print_alloc_wrapper(ri, direction):
    """Emit an inline '<prefix>_alloc()' returning a zeroed struct."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_type}));')
    ri.cw.block_end()
1662
1663
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<prefix>_free()' for @direction.

    The struct type gets a trailing '_' when its name collides with a
    const (ri.type_name_conflict).
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1671
1672
def _print_type(ri, direction, struct):
    """Emit the C struct definition for @struct.

    Presence bits/lengths reported by the members go into a leading
    '_present' sub-struct. These are followed by the inherited __u32
    members and then the regular members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    presence_lines = []
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                presence_lines.append(line)
    if presence_lines:
        cw.block_start(line="struct")
        for line in presence_lines:
            cw.p(line)
        cw.block_end(line='_present;')
        cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1704
1705
def print_type(ri, direction):
    """Emit the struct definition for the @direction message of @ri."""
    _print_type(ri, direction, struct=ri.struct[direction])
1708
1709
def print_type_full(ri, struct):
    """Emit the struct definition of a bare (direction-less) @struct."""
    _print_type(ri, direction="", struct=struct)
1712
1713
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1722
1723
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and type helpers of the request struct."""
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
1727
1728
def print_rsp_type_helpers(ri):
    """Emit the reply struct's type helpers if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1733
1734
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of '<op>_rsp_parse()' / '<op>_req_parse()'.

    The function takes an attribute table and the struct to fill. The
    prototype ends with ';' unless @terminate is false.
    """
    base = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
1743
1744
def print_req_type(ri):
    """Emit the request struct definition for @ri."""
    print_type(ri, direction="request")
1747
1748
def print_req_free(ri):
    """Emit the request struct's free() if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1753
1754
def print_rsp_type(ri):
    """Emit the reply struct definition.

    This happens for do/dump modes that have a reply, and always for events.
    """
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if has_reply or mode == 'event':
        print_type(ri, 'reply')
1763
1764
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, plus its free prototype.

    Dump wrappers get a 'next' list pointer. Notification/event wrappers
    get family, cmd, a 'next' pointer and a free callback. The wrapped
    object itself is stored 8-byte aligned in 'obj'.
    """
    reply_type = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=reply_type)
    if ri.op_mode == 'dump':
        cw.p(f"{reply_type} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for line in ('__u16 family;',
                     '__u8 cmd;',
                     'struct ynl_ntf_base_type *next;',
                     f"void (*free)({reply_type} *ntf);"):
            cw.p(line)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1778    ri.cw.nl()
1779
1780
1781def _free_type_members_iter(ri, struct):
1782    for _, attr in struct.member_list():
1783        if attr.free_needs_iter():
1784            ri.cw.p('unsigned int i;')
1785            ri.cw.nl()
1786            break
1787
1788
1789def _free_type_members(ri, var, struct, ref=''):
1790    for _, attr in struct.member_list():
1791        attr.free(ri, var, ref)
1792
1793
def _free_type(ri, direction, struct):
    """Emit a complete C free() function for *struct*.

    With an empty *direction* (as passed by free_rsp_nested()), only the
    members are released. No trailing ``free(<var>)`` of the object itself
    is emitted in that case.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
1805
1806
def free_rsp_nested(ri, struct):
    """Emit the member-only free() helper for a nested, directionless struct."""
    return _free_type(ri, "", struct)
1809
1810
def print_rsp_free(ri):
    """Emit the free() implementation for the reply type, if there is one."""
    has_reply = 'reply' in ri.op[ri.op_mode]
    if has_reply:
        _free_type(ri, 'reply', ri.struct['reply'])
1815
1816
def print_dump_type_free(ri):
    """Emit the free() helper for a dump reply list.

    The generated C walks the nodes through ``next`` until it reaches the
    YNL_LIST_END sentinel. For each node it frees the members of the
    embedded ``obj`` and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply_struct = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    # Advance the cursor first, then release the node it just left.
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply_struct)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()  # while loop
    cw.block_end()  # function body
    cw.nl()
1835
1836
def print_ntf_type_free(ri):
    """Emit the free() helper for a notification/event wrapper object.

    Frees the members of the embedded ``obj`` and then the wrapper itself.
    """
    reply_struct = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1845
1846
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration line of an ``nla_policy`` array.

    Args:
        cw: CodeWriter to emit into.
        struct: attribute Struct the policy covers. Its attr_max_val sizes
            the array.
        ri: RenderInfo when the policy belongs to an op. The op's
            render_name (plus ``_<op_mode>`` for dual-policy ops) names the
            array. Without it, struct.render_name is used.
        terminate: True emits an ``extern`` forward declaration, and nothing
            at all for static op policies. False emits the opening line of
            the definition.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)
    if terminate and is_static:
        # Static policies are file-local, so there is nothing to forward-declare.
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    else:
        prefix, suffix = ('static ' if is_static else ''), ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1869
1870
def print_req_policy(cw, struct, ri=None):
    """Emit a complete ``nla_policy`` array definition for *struct*.

    The opening line comes from print_req_policy_fwd(terminate=False). Each
    member then contributes its own policy entry, and the initializer is closed.
    """
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    members = (arg for _, arg in struct.member_list())
    for member in members:
        member.attr_policy(cw)
    cw.p("};")
    cw.nl()
1877
1878
def kernel_can_gen_family_struct(family):
    """Return True when the ``genl_family`` struct itself can be generated.

    Only families using the plain 'genetlink' protocol qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
1881
1882
def policy_should_be_static(family):
    """Return True if per-op policies should be file-local (``static``).

    This applies to the 'split' policy model, and to any family whose
    genl_family struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
1885
1886
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration/opener and handler prototypes.

    With ``terminate=False`` (kernel source), this opens the initializer of
    the ``<family>_nl_ops[]`` array. The entries and closing brace come from
    print_kernel_op_table().

    With ``terminate=True`` (kernel header), this emits an ``extern``
    declaration of the table when it is exported. It then emits prototypes
    for the spec's pre/post hooks and for each non-async op's doit/dumpit
    handler.

    Args:
        family: parsed spec family.
        cw: CodeWriter to emit into.
        terminate: True for header declarations, False for the source opener.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Static tables are left unsized. Exported tables get an explicit
        # count, which must equal the number of entries that
        # print_kernel_op_table() emits. That function skips async ops, and
        # under 'split' it emits one entry per do/dump mode. Counting async
        # ops here would pad the array with zeroed ops.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Hook prototypes: pre/post for doit get the split-ops context; dump
    # hooks (start/done) only receive the netlink callback.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1952
1953
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops-table declaration and handler prototypes."""
    return print_kernel_op_table_fwd(family, cw, terminate=True)
1956
1957
def print_kernel_op_table(family, cw):
    """Emit the kernel ``<family>_nl_ops[]`` array definition.

    The array opener comes from print_kernel_op_table_fwd(). This function
    adds one initializer per non-async op (for the 'global'/'per-op'
    policies), or one per (op, do|dump) pair (for the 'split' policy), and
    then closes the array.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Async ops (presumably notifications/events) get no table entry.
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this reads op['do']['request'] unconditionally,
                # so a dump-only op (or a do without a request) raises KeyError
                # under the 'per-op' policy. Confirm such specs are rejected
                # earlier, or guard this.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Maps the spec's 'pre'/'post' hook keys to genl_split_ops member
        # names for each mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop flags that only apply to the other mode: dump flags
                    # on the do entry, and 'strict' on the dump entry.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # 'name' is reused here for the policy symbol; the handler
                    # name was already appended above.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry advertises its mode via GENL_CMD_CAP_DO/DUMP.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2033
2034
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an enum of multicast group IDs, or nothing if there are no groups."""
    groups = family.mcgrps['list']
    if groups:
        cw.block_start('enum')
        for grp in groups:
            cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
        cw.block_end(';')
        cw.nl()
2045
2046
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group-ID enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2058
2059
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the generated genl_family, if any."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2066
2067
def print_kernel_family_struct_src(family, cw):
    """Emit the ``struct genl_family`` definition for generatable families."""
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    fields = [f'.name\t\t= {family.fam_key},',
              f'.version\t= {family.ver_key},',
              '.netnsok\t= true,',
              '.parallel_ops\t= true,',
              '.module\t\t= THIS_MODULE,']
    if family.kernel_policy == 'per-op':
        fields += [f'.ops\t\t= {fam}_nl_ops,',
                   f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),']
    elif family.kernel_policy == 'split':
        fields += [f'.split_ops\t= {fam}_nl_ops,',
                   f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),']
    # NOTE(review): the 'global' policy wires no ops/small_ops (nor
    # policy/maxattr) into the family here. Confirm that is intended.
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {fam}_nl_mcgrps,',
                   f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),']
    for field in fields:
        cw.p(field)
    cw.block_end(';')
2088
2089
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block, naming it from the spec where possible.

    If *obj* has the *enum_name* key, it wins: a non-empty value names the
    enum, and an empty value gives an anonymous enum. Otherwise, if *ckey*
    is given and present in *obj*, the enum is named ``<family>_<obj[ckey]>``.
    Failing both, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2098
2099
def render_uapi(family, cw):
    """Render the uAPI header (``--mode uapi``) for *family*.

    Emits, in order:

    - the include guard;
    - the family name and version defines;
    - the spec's ``definitions``: enums/flags as C enums with kdoc, and
      consts as ``#define``s;
    - one attribute enum per non-subset attribute set;
    - the command enum, plus a separate notification enum when the spec's
      operations set ``async-prefix``;
    - the multicast group name defines;
    - the closing include guard.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched and flushed as one group of
    # #defines whenever a non-const definition (or the end of the list) is hit.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # entry.value_change presumably marks entries whose value C
                # would not assign implicitly; only those get ' = <value>'.
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                # flags: a MASK member holding all bits; enums: a __MAX
                # counter plus MAX = __MAX - 1.
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not on
            # this const entry (the multicast groups below use grp.get).
            # Confirm which level is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # When set, the MAX values are emitted as #defines after each enum
    # instead of as enum members.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        # Subset attribute sets get no enum of their own here.
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # 'val' tracks the value C would assign implicitly; '= N' is only
        # written when an attribute's value departs from it.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With 'async-prefix', notify/event messages are left out of the main
    # command enum and rendered into their own enum below.
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2244
2245
def _render_user_ntf_entry(ri, op):
    """Emit one designated initializer of the ynl_ntf_info table for *op*.

    The entry records the object size, the parse callback, the attribute
    policy and the free callback for the notification.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    for field in (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    ):
        cw.p(field)
    cw.block_end(line=',')
2253
2254
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_family`` descriptor, or its extern declaration.

    Args:
        family: parsed spec family.
        cw: CodeWriter to emit into.
        prototype: when True, only the ``extern`` declaration is emitted.
            Otherwise the notification info table (if the family has
            notifications) and the descriptor definition are emitted.

    Raises:
        Exception: a notification entry has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        ntf_table = f"{family['name']}_ntf_info"
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                # A 'notify' entry reuses the types of the op it refers to.
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        event_ops = (op for op in family.ops.values() if 'event' in op)
        for op in event_ops:
            _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2286
2287
def find_kernel_root(full_path):
    """Locate the kernel source root that contains *full_path*.

    Walks up from *full_path* one directory at a time until it finds a
    directory containing a ``MAINTAINERS`` file.

    Args:
        full_path: path (absolute, or relative to the CWD) of a file inside
            the kernel tree, e.g. a YAML spec.

    Returns:
        Tuple ``(root, sub_path)``. *root* is the directory holding
        ``MAINTAINERS`` (``''`` means the CWD for relative input), and
        *sub_path* is *full_path* relative to *root*.

    Raises:
        FileNotFoundError: no ancestor directory contains ``MAINTAINERS``.
            Without this check the walk would spin forever once it reached
            ``/`` (or ``''`` for relative paths).
    """
    start_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        # dirname() is a fixed point at '/' and at '' (relative paths), so
        # once it stops changing there is nowhere left to look.
        at_top = parent == full_path
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            return full_path, sub_path[:-1]
        if at_top:
            raise FileNotFoundError(
                f"cannot find kernel root (MAINTAINERS) above {start_path!r}")
2296
2297
def _write_out_file(tmp_file, path):
    """Copy everything rendered into *tmp_file* to the output file *path*.

    The destination is only opened (and truncated) here, after rendering
    has finished. It is closed deterministically through ``with``, and the
    temporary buffer is released afterwards.
    """
    tmp_file.seek(0)
    with open(path, 'w') as out_file:
        shutil.copyfileobj(tmp_file, out_file)
    tmp_file.close()


def main():
    """Command-line entry point: render kernel, user or uAPI code from a spec.

    Output goes to stdout. When ``-o`` is given, output is first rendered
    into a temporary file and then copied to the requested path, so a
    failure during generation leaves the destination untouched.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    tmp_file = tempfile.TemporaryFile('w+') if args.out_file else os.sys.stdout

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    cw = CodeWriter(BaseNlLib(), tmp_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        # Returning early must still flush to -o, otherwise the rendered
        # uAPI header is lost with the temporary file.
        if args.out_file:
            _write_out_file(tmp_file, args.out_file)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')

    if args.out_file:
        _write_out_file(tmp_file, args.out_file)
2587
if __name__ == "__main__":
    # Script entry point; all argument handling and rendering lives in main().
    main()
2590