xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision d35ac6ac)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import yaml
9
10from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
11
12
def c_upper(name):
    """Return *name* upper-cased with every dash replaced by an underscore."""
    return name.replace('-', '_').upper()
15
16
def c_lower(name):
    """Return *name* lower-cased with every dash replaced by an underscore."""
    return name.replace('-', '_').lower()
19
20
class BaseNlLib:
    """Builds the C expression strings for netlink library calls in generated code."""

    def get_family_id(self):
        """Return the C expression used to read the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return the C call to ``mnl_cb_run2()`` as a string.

        cb and data are spliced in as the callback and its argument.  For a
        dump the sequence number and port id arguments are rendered as 0,
        otherwise ``ys->seq`` and ``ys->portid`` are used.  indent controls
        how many extra tabs precede the continuation lines.
        """
        # Line break plus alignment inserted where the call is wrapped.
        wrap = '\n\t\t' + indent * '\t' + ' '
        tail = f"{cb}, {data},{wrap}ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            return "mnl_cb_run2(ys->rx_buf, len, 0, 0, " + tail
        return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}" + tail
32
33
class Type(SpecAttr):
    """Base class of the per-attribute C code generation helpers.

    Sub-classes override the hooks (``_attr_typol``, ``_attr_get``,
    ``_setter_lines``, ``attr_put``, ...) for the attribute kinds they
    support; the defaults here raise for anything not implemented.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # A trailing underscore is appended when the nested set shares
            # its name with an entry in the family's consts.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name`` from the attribute's or the set's name prefix."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return a true value if the member holds several values (default: no)."""
        return None

    def is_scalar(self):
        """Return True if the attribute type is one of the fixed-width integers."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked; 'bit' unless a sub-class overrides it."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration of the presence member, or None.

        None is returned when this attribute's presence type differs from
        type_filter or is neither 'bit' nor 'len'.  Members get a ``__``
        type prefix when space is 'user'.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member, None if there is none."""
        return None

    def free_needs_iter(self):
        """Return True if freeing the member requires a loop variable."""
        return False

    def free(self, ri, var, ref):
        """Emit a ``free()`` of the member for multi-value and 'len' types."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the list of C argument declarations for the setter.

        Raises an Exception when the type has no complex member type and
        the sub-class did not override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the policy initializer for the given NLA type name."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type-policy fields; must be provided by sub-classes."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the type policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit line, guarded by a presence check for 'bit' and 'len' types."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a ``mnl_attr_put_<put_type>()`` call for the member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code adding the attribute to a message (sub-class hook)."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for parsing (sub-class hook)."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if (type == <enum>)`` block parsing the attribute.

        first selects between ``if`` and ``else if``.  Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may return a single string instead of a list of lines.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the value (sub-class hook)."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the ``static inline`` setter function for this attribute.

        ref lists the names of the enclosing members when the attribute is
        reached through nesting; this attribute's own name is appended.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # Walk the path from the outermost member inwards; after the loop
        # presence names the presence member of this attribute itself.
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        # Setters that free the old value without allocating a new one get
        # a '__' name prefix.
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
210
211
class TypeUnused(Type):
    """Attribute of type 'unused': no member, no policy, rejected on parse."""

    def presence_type(self):
        """No presence tracking for this type."""
        return ''

    def arg_member(self, ri):
        """No setter arguments."""
        return list()

    def _attr_get(self, ri, var):
        """Parsing the attribute makes the generated callback return an error."""
        lines = ['return MNL_CB_ERROR;']
        return lines, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing."""
227
228
class TypePad(Type):
    """Attribute of type 'pad': only a type policy entry is rendered."""

    def presence_type(self):
        """No presence tracking for this type."""
        return ''

    def arg_member(self, ri):
        """No setter arguments."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing; returns None, unlike the base class which returns True."""

    def attr_policy(self, cw):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
250
251
class TypeScalar(Type):
    """Integer attribute, optionally tied to an enum or a flags mask."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # C comment appended to the member declaration when a byte order is given.
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out whether the value is a bitfield and which C type carries it."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr
        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        else:
            self.is_bitfield = has_enum and \
                self.family.consts[self.attr['enum']]['type'] == 'flags'

        # Use the enum's C type only for non-bitfield enums which have a name.
        named_enum = has_enum and not self.is_bitfield and \
            self.family.consts[self.attr['enum']].enum_name
        if named_enum:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the type suffix of the mnl helper to use."""
        # mnl does not have a helper for signed types
        if self.type[0] == 's':
            return 'u' + self.type[1:]
        return self.type

    def _attr_policy(self, policy):
        """Return the policy initializer, honouring masks, minimum and enum range."""
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"

        if 'flags-mask' in self.checks:
            # One bit per entry of the referenced flags definition.
            entries = self.family.consts[self.checks['flags-mask']]['entries']
            mask = (1 << len(entries)) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"

        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"

        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"

        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        declaration = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [declaration]

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        getter = f"mnl_attr_get_{self._mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
323
324
class TypeFlag(Type):
    """Attribute without payload; only its presence bit carries information."""

    def arg_member(self, ri):
        """No setter arguments."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a put of the attribute with a zero-length, NULL payload."""
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """No parse lines of its own."""
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        """No store statements of its own."""
        return list()
340
341
class TypeString(Type):
    """String attribute, stored as a malloc()ed, NUL-terminated copy."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        """Presence is tracked through a ``<name>_len`` member."""
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the policy initializer, with ``.len`` if 'max-len' is checked."""
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the policy entry; NLA_STRING only when 'unterminated-ok' is set."""
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'

        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        """Return the lines copying the string out of the attribute."""
        member = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{member} = malloc(len + 1);",
                     f"memcpy({member}, mnl_attr_get_str(attr), len);",
                     f"{member}[len] = 0;"]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing the stored string with a copy of the argument."""
        len_field = f"{presence}_len"
        return [f"free({member});",
                f"{len_field} = strlen({self.c_name});",
                f"{member} = malloc({len_field} + 1);",
                f'memcpy({member}, {self.c_name}, {len_field});',
                f'{member}[{len_field}] = 0;']
389
390
class TypeBinary(Type):
    """Binary blob attribute, stored as a malloc()ed copy plus a length."""

    def arg_member(self, ri):
        """The setter takes the data pointer and its length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Presence is tracked through a ``<name>_len`` member."""
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the policy initializer.

        Only two shapes are supported: no checks at all, or 'min-len' as
        the sole check.  Anything else raises an Exception.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        """Emit a put of the stored buffer using the stored length."""
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        """Return the lines copying the payload out of the attribute."""
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing the stored buffer with a copy of the argument."""
        # Record the caller-supplied length first: the allocation and copy
        # below, as well as attr_put(), all read it from the presence member.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
431
432
class TypeNest(Type):
    """Attribute carrying a nested attribute set, stored as an embedded struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """Emit a call to the nested struct's ``_free()`` helper."""
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        """Emit a call to the nested struct's ``_put()`` helper."""
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Return the lines pointing parg at the member and parsing into it."""
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every member of the nested struct.

        This attribute's name is appended to ref so the nested setters
        address their members through it.
        """
        path = (ref or []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
462
463
class TypeMultiAttr(Type):
    """Attribute that may occur several times; stored as an array plus a count."""

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type object for a single occurrence; policies are delegated to it.
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        """Presence is tracked through an ``n_<name>`` count."""
        return 'count'

    def _mnl_type(self):
        """Return the type suffix of the mnl helper to use."""
        # mnl does not have a helper for signed types
        if self.type[0] == 's':
            return 'u' + self.type[1:]
        return self.type

    def _is_nested_kind(self):
        """Return True if the elements are nests (or no type is given)."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _complex_member_type(self, ri):
        """Return the C element type; scalars get ``__`` in user space."""
        if self._is_nested_kind():
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nested_kind()

    def free(self, ri, var, ref):
        """Emit the code freeing the array, and each element first for nests."""
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            ri.cw.p(f"free({array});")
            return
        if not self._is_nested_kind():
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

        ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
        ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
        ri.cw.p(f"free({array});")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Parsing only counts the occurrences."""
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting every element of the array."""
        loop_head = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(loop_head)
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
            return
        if not self._is_nested_kind():
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

        ri.cw.p(loop_head)
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                            f"{self.enum_name}, &{var}->{self.c_name}[i])")

    def _setter_lines(self, ri, member, presence):
        """Return the lines taking over the caller's array and its count."""
        # For multi-attr we have a count, not presence, hack up the presence
        tail_len = len('_present.') + len(self.c_name)
        count = presence[:-tail_len] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
532
533
class TypeArrayNest(Type):
    """Attribute of type 'array-nest'; stored as an array plus a count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        """Presence is tracked through an ``n_<name>`` count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type; scalars get ``__`` in user space."""
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type

        sub_type = self.attr['sub-type']
        if sub_type not in scalars:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + sub_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines remembering the attribute and counting its entries."""
        counting = [f'attr_{self.c_name} = attr;',
                    'mnl_attr_for_each_nested(attr2, attr)',
                    f'\t{var}->n_{self.c_name}++;']
        declarations = ['const struct nlattr *attr2;']
        return counting, None, declarations
559
560
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value'."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines unwrapping the 'type-value' levels and parsing the nest.

        For each 'type-value' name one level of nesting is entered and the
        attribute type found there is stored in a variable of that name;
        those variables are passed on to the nested parse function.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(levels)
            value_decls = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {value_decls};')

            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({innermost});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                innermost = 'attr_' + level

            extra_args = f", {value_decls}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
589
590
class Struct:
    """A C struct rendered from (a selection of) one attribute set.

    Iterating yields the attribute names; indexing by name returns the
    attribute object.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        # No explicit member selection marks the struct as nested.
        self.nested = type_list is None

        if family.name == c_lower(space_name):
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{c_lower(space_name)}"

        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # Members come from type_list when it is non-empty, otherwise from
        # the whole attribute set.
        selected = type_list if type_list else self.attr_set
        self.attr_list = [(t, self.attr_set[t]) for t in selected]
        self.attrs = dict(self.attr_list)

        # Track the member with the highest value (the last one on ties).
        highest = 0
        self.attr_max_val = None
        for _, member in self.attr_list:
            if member.value >= highest:
                highest = member.value
                self.attr_max_val = member

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attribute) pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited member names in C form, sorted.

        Raises an Exception if new_inherited does not compare equal to the
        value given at construction time.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(member) for member in sorted(self._inherited)]
643
644
class EnumEntry(SpecEnumEntry):
    """Enum entry with the extra state needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" when it does not simply continue the previous
        # entry (or start at 0), and always for enums of type 'flags'.
        expected = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected) or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C name from the enum set's value prefix."""
        self.resolve_up(super())

        prefixed = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(prefixed)
663
664
class EnumSet(SpecEnumSet):
    """Enum definition with its C naming."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # 'enum-name' absent: derive the name; present but empty: no name.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        default_pfx = f"{family.name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) entry value.

        Raises an Exception if the number of entries does not match the
        size of that range, i.e. the values are not contiguous.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
692
693
class AttrSet(SpecAttrSet):
    """Attribute set with its C naming; creates the typed attribute objects."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets take the naming of the set they are a subset of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C name; it is empty when equal to the family's C name."""
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Create the Type object matching the attribute's 'type'.

        Attributes marked 'multi-attr' are wrapped in TypeMultiAttr.
        Raises an Exception for an unknown type.
        """
        kind = elem['type']
        if kind in scalars:
            type_cls = TypeScalar
        else:
            type_cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if type_cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = type_cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
748
749
class Operation(SpecOperation):
    """Operation with its C naming."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both 'do' and 'dump' have a 'request' section.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the enum name, using the async prefix for async operations."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Flag the operation as having a notification."""
        self.has_ntf = True
775
776
777class Family(SpecFamily):
778    def __init__(self, file_name, exclude_ops):
779        # Added by resolve:
780        self.c_name = None
781        delattr(self, "c_name")
782        self.op_prefix = None
783        delattr(self, "op_prefix")
784        self.async_op_prefix = None
785        delattr(self, "async_op_prefix")
786        self.mcgrps = None
787        delattr(self, "mcgrps")
788        self.consts = None
789        delattr(self, "consts")
790        self.hooks = None
791        delattr(self, "hooks")
792
793        super().__init__(file_name, exclude_ops=exclude_ops)
794
795        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
796        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
797
798        if 'definitions' not in self.yaml:
799            self.yaml['definitions'] = []
800
801        if 'uapi-header' in self.yaml:
802            self.uapi_header = self.yaml['uapi-header']
803        else:
804            self.uapi_header = f"linux/{self.name}.h"
805
806    def resolve(self):
807        self.resolve_up(super())
808
809        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
810            raise Exception("Codegen only supported for genetlink")
811
812        self.c_name = c_lower(self.name)
813        if 'name-prefix' in self.yaml['operations']:
814            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
815        else:
816            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
817        if 'async-prefix' in self.yaml['operations']:
818            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
819        else:
820            self.async_op_prefix = self.op_prefix
821
822        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
823
824        self.hooks = dict()
825        for when in ['pre', 'post']:
826            self.hooks[when] = dict()
827            for op_mode in ['do', 'dump']:
828                self.hooks[when][op_mode] = dict()
829                self.hooks[when][op_mode]['set'] = set()
830                self.hooks[when][op_mode]['list'] = []
831
832        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
833        self.root_sets = dict()
834        # dict space-name -> set('request', 'reply')
835        self.pure_nested_structs = dict()
836
837        self._mark_notify()
838        self._mock_up_events()
839
840        self._load_root_sets()
841        self._load_nested_sets()
842        self._load_hooks()
843
844        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
845        if self.kernel_policy == 'global':
846            self._load_global_policy()
847
848    def new_enum(self, elem):
849        return EnumSet(self, elem)
850
851    def new_attr_set(self, elem):
852        return AttrSet(self, elem)
853
854    def new_operation(self, elem, req_value, rsp_value):
855        return Operation(self, elem, req_value, rsp_value)
856
857    def _mark_notify(self):
858        for op in self.msgs.values():
859            if 'notify' in op:
860                self.ops[op['notify']].mark_has_ntf()
861
862    # Fake a 'do' equivalent of all events, so that we can render their response parsing
863    def _mock_up_events(self):
864        for op in self.yaml['operations']['list']:
865            if 'event' in op:
866                op['do'] = {
867                    'reply': {
868                        'attributes': op['event']['attributes']
869                    }
870                }
871
872    def _load_root_sets(self):
873        for op_name, op in self.msgs.items():
874            if 'attribute-set' not in op:
875                continue
876
877            req_attrs = set()
878            rsp_attrs = set()
879            for op_mode in ['do', 'dump']:
880                if op_mode in op and 'request' in op[op_mode]:
881                    req_attrs.update(set(op[op_mode]['request']['attributes']))
882                if op_mode in op and 'reply' in op[op_mode]:
883                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
884            if 'event' in op:
885                rsp_attrs.update(set(op['event']['attributes']))
886
887            if op['attribute-set'] not in self.root_sets:
888                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
889            else:
890                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
891                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
892
    def _load_nested_sets(self):
        """Discover attribute sets that are only used nested and populate self.pure_nested_structs.

        Works in four passes:
          1. walk outward from the root sets following 'nested-attributes',
             creating a Struct for every nested set and recording the members
             it inherits from its parent;
          2. flag nested structs referenced directly by a root set as used in
             requests and/or replies;
          3. reorder self.pure_nested_structs so that a struct comes after the
             structs it nests;
          4. push the request / reply flags down from parents to children.

        Raises:
            Exception: if an attribute set is used both as a root set and as
                a nested set.
        """
        # Breadth-first walk: the queue holds sets still to be scanned, the
        # 'seen' set prevents scanning (and queueing) the same set twice
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # Members the nested struct receives from its parent rather
                # than from its own attributes
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): this check looks unreachable, the branch
                    # above has already raised for sets present in root_sets
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nested structs referenced straight from a root set take their
        # request / reply usage from how the referencing attribute is used
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        # The number of attempts is bounded, so the loop also ends if the
        # dependencies can never all be satisfied
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order, so removing and
                        # re-inserting the entry moves this struct to the end
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # A dependency has not been placed yet, retry this one later
                pns_key_list.append(name)
        # Propagate the request / reply
        # Walk in reverse insertion order: the reordering above moved
        # structs behind the ones they nest, so reversed iteration reaches a
        # parent before its children
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
961
962    def _load_global_policy(self):
963        global_set = set()
964        attr_set_name = None
965        for op_name, op in self.ops.items():
966            if not op:
967                continue
968            if 'attribute-set' not in op:
969                continue
970
971            if attr_set_name is None:
972                attr_set_name = op['attribute-set']
973            if attr_set_name != op['attribute-set']:
974                raise Exception('For a global policy all ops must use the same set')
975
976            for op_mode in ['do', 'dump']:
977                if op_mode in op:
978                    global_set.update(op[op_mode].get('request', []))
979
980        self.global_policy = []
981        self.global_policy_set = attr_set_name
982        for attr in self.attr_sets[attr_set_name]:
983            if attr in global_set:
984                self.global_policy.append(attr)
985
986    def _load_hooks(self):
987        for op in self.ops.values():
988            for op_mode in ['do', 'dump']:
989                if op_mode not in op:
990                    continue
991                for when in ['pre', 'post']:
992                    if when not in op[op_mode]:
993                        continue
994                    name = op[op_mode][when]
995                    if name in self.hooks[when][op_mode]['set']:
996                        continue
997                    self.hooks[when][op_mode]['set'].add(name)
998                    self.hooks[when][op_mode]['list'].append(name)
999
1000
class RenderInfo:
    """Bundle of everything needed to render code for one operation or one nested set.

    Attributes set here:
      family, nl, cw, ku_space, op, op_mode - the constructor inputs
      type_consistent    - False when 'do' and 'dump' replies differ
      attr_set           - explicit attr_set, or the op's 'attribute-set'
      type_name          - C-friendly name derived from the op or attr set
      type_name_conflict - True when the attr set name is also in family.consts
      struct             - dict direction ('request'/'reply') -> Struct
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.cw = cw
        self.nl = cw.nlib
        self.family = family
        self.ku_space = ku_space
        self.op = op
        self.op_mode = op_mode

        # 'do' and 'dump' usually parse the same response; note when they don't
        consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_spec = op['do']
            dump_spec = op["dump"]
            do_has_reply = 'reply' in do_spec
            if do_has_reply != ('reply' in dump_spec):
                consistent = False
            elif do_has_reply and do_spec["reply"] != dump_spec["reply"]:
                consistent = False
        self.type_consistent = consistent

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        # Notifications reuse the layout of the 'do' they are attached to
        lookup_mode = 'do' if op_mode == 'notify' else op_mode

        self.struct = dict()
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[lookup_mode]:
                    members = op[lookup_mode][op_dir]['attributes']
                    self.struct[op_dir] = Struct(family, self.attr_set,
                                                 type_list=members)
        if lookup_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          type_list=op['event']['attributes'])
1040
1041
class CodeWriter:
    """Line-oriented writer of tab-indented C code.

    Lines are emitted with p(); the writer tracks the brace depth opened with
    block_start() / block_end() and a few pieces of pending state:

      - a blank line requested with nl() is only written before the next
        line, so trailing blank lines never appear at the end of a block;
      - a closing brace from an argument-less block_end() is delayed so that
        a following line starting with 'else' is joined to it ('} else');
      - the line after a brace-less 'if (...)' / 'while (...)' / 'for (...)'
        gets one extra level of indentation.
    """

    def __init__(self, nlib, out_file):
        """Store the netlink library helper and the writable output stream."""
        self.nlib = nlib

        self._nl = False            # a blank line is pending
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line was a brace-less conditional
        self._ind = 0               # current indentation depth, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if `line` starts like an if / while / for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation (plus `add_ind` tabs).

        Any pending closing brace and blank line are flushed first.  Lines
        ending in ':' (labels) are written one level shallower.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() rather than line[-1] so that an empty line is accepted
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write `line` followed by an opening brace and indent what follows."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block, optionally with trailing text after the brace.

        Without `line` the brace is held back until the next output, see the
        class description.  A pending blank line is dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write `doc` as ' * ...' comment lines wrapped below 80 columns.

        With `indent`, continuation lines get two extra spaces.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a C function prototype, wrapping the arguments when too long.

        qual_ret: qualifiers and return type, e.g. 'int' or 'const char *'
        name:     function name
        args:     list of argument declarations; None or empty means 'void'
        doc:      optional one-line comment written above the prototype
        suffix:   text appended after the closing parenthesis, e.g. ';'
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Too long for one line: a return type longer than 3 characters goes
        # on its own line, the arguments are then packed line by line
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned under the first argument using tabs
        # and spaces; delta_ind is the width the tabs hide from len()
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        `local_vars` may be a single string or a list of strings; the
        caller's list is left untouched.  Nothing is written when it is empty.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() instead of list.sort() so the argument is not reordered
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a prototype, the local variables and a braced body of lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME value' lines with the values tab-aligned.

        `defines` is a sequence of (name, value) pairs; int values are
        written as is, str values are wrapped in double quotes, values of
        any other type are left out.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.name = value,' designated initializers with '=' tab-aligned.

        `members` is a sequence of (name, value) string pairs; an empty
        sequence writes nothing.
        """
        if not members:
            return

        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1199
1200
# Attribute types handled as plain integer scalars
scalars = {'s32', 's64', 'u8', 'u16', 'u32', 'u64'}

# Name suffix for each message direction ('' = no direction)
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Name suffix of the wrapper type for each operation mode
op_mode_to_wrapper = dict([
    ('do', ''),
    ('dump', '_list'),
    ('notify', '_ntf'),
    ('event', ''),
])
1215
1216_C_KW = {
1217    'auto',
1218    'bool',
1219    'break',
1220    'case',
1221    'char',
1222    'const',
1223    'continue',
1224    'default',
1225    'do',
1226    'double',
1227    'else',
1228    'enum',
1229    'extern',
1230    'float',
1231    'for',
1232    'goto',
1233    'if',
1234    'inline',
1235    'int',
1236    'long',
1237    'register',
1238    'return',
1239    'short',
1240    'signed',
1241    'sizeof',
1242    'static',
1243    'struct',
1244    'switch',
1245    'typedef',
1246    'union',
1247    'unsigned',
1248    'void',
1249    'volatile',
1250    'while'
1251}
1252
1253
def rdir(direction):
    """Return the opposite of 'request' / 'reply'; other values come back unchanged."""
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1260
1261
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix '<family>_<type>...' for the given direction.

    With no op mode, or for 'do', the plain direction suffix is appended.
    Otherwise requests get '_req_dump'; replies get the wrapper suffix of
    the op mode (or the direction suffix when `deref` is set), unless the
    reply types are inconsistent, in which case '_rsp_list' / '_rsp_dump'
    is used.
    """
    stem = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}{stem}{tail}"
1281
1282
def type_name(ri, direction, deref=False):
    """Return the full C struct type, 'struct ' followed by op_prefix()."""
    prefix = op_prefix(ri, direction, deref=deref)
    return f"struct {prefix}"
1285
1286
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the function performing the operation.

    The function takes the socket plus, if the op mode has a request, a
    pointer to the request struct; it returns a pointer to the reply struct
    if the op mode has a reply and 'int' otherwise.  `terminate` adds the
    trailing ';', `doc` is passed on as the comment above the prototype.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname = fname + '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = f"{direction_to_suffix[direction][1:]}"
        args.append(f"{type_name(ri, direction)} *" + arg_name)

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1303
1304
def print_req_prototype(ri):
    """Write the request function prototype, with the op's 'doc' as comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1307
1308
def print_dump_prototype(ri):
    """Write the request function prototype without a comment above it."""
    direction = "request"
    print_prototype(ri, direction)
1311
1312
def put_typol(cw, struct):
    """Write the policy table of `struct` and the nest descriptor pointing at it.

    The table is sized by the attribute set's max name and filled by each
    member's attr_typol(); the nest records that max and the table.
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for initializer in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(initializer)
    cw.block_end(line=';')
    cw.nl()
1328
1329
1330def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1331    args = [f'int {arg_name}']
1332    if enum and not ('enum-name' in enum and not enum['enum-name']):
1333        args = [f'enum {render_name} {arg_name}']
1334    cw.write_func_prot('const char *', f'{render_name}_str', args)
1335    cw.block_start()
1336    if enum and enum.type == 'flags':
1337        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1338    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
1339    cw.p('return NULL;')
1340    cw.p(f'return {map_name}[{arg_name}];')
1341    cw.block_end()
1342    cw.nl()
1343
1344
def put_op_name_fwd(family, cw):
    """Write the forward declaration of '<family>_op_str()'."""
    func_name = f'{family.name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1347
1348
def put_op_name(family, cw):
    """Write the op-number-to-name string table and its lookup function.

    Only messages with a response value get an entry; it is indexed by the
    enum name when request and response values match, by the response value
    otherwise.
    """
    map_name = f'{family.name}_op_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if op.req_value == op.rsp_value:
            index = op.enum_name
        else:
            index = op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1362
1363
def put_enum_to_str_fwd(family, cw, enum):
    """Write the forward declaration of the enum's '<render_name>_str()'.

    The argument is a plain 'int' when the enum has an empty 'enum-name'.
    """
    unnamed = 'enum-name' in enum and not enum['enum-name']
    arg = 'int value' if unnamed else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1369
1370
def put_enum_to_str(family, cw, enum):
    """Write the value-to-name string table of `enum` and its lookup function."""
    map_name = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = (f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values())
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1380
1381
def put_req_nested(ri, struct):
    """Write the '<struct>_put()' function that emits `struct` as a nested attribute.

    The generated function opens a nest of the given attribute type, lets
    every member put itself from 'obj', closes the nest and returns 0.
    """
    cw = ri.cw
    func_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    cw.write_func_prot('int', f'{struct.render_name}_put', func_args)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1402
1403
1404def _multi_parse(ri, struct, init_lines, local_vars):
1405    if struct.nested:
1406        iter_line = "mnl_attr_for_each_nested(attr, nested)"
1407    else:
1408        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"
1409
1410    array_nests = set()
1411    multi_attrs = set()
1412    needs_parg = False
1413    for arg, aspec in struct.member_list():
1414        if aspec['type'] == 'array-nest':
1415            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1416            array_nests.add(arg)
1417        if 'multi-attr' in aspec:
1418            multi_attrs.add(arg)
1419        needs_parg |= 'nested-attributes' in aspec
1420    if array_nests or multi_attrs:
1421        local_vars.append('int i;')
1422    if needs_parg:
1423        local_vars.append('struct ynl_parse_arg parg;')
1424        init_lines.append('parg.ys = yarg->ys;')
1425
1426    all_multi = array_nests | multi_attrs
1427
1428    for anest in sorted(all_multi):
1429        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1430
1431    ri.cw.block_start()
1432    ri.cw.write_func_lvar(local_vars)
1433
1434    for line in init_lines:
1435        ri.cw.p(line)
1436    ri.cw.nl()
1437
1438    for arg in struct.inherited:
1439        ri.cw.p(f'dst->{arg} = {arg};')
1440
1441    for anest in sorted(all_multi):
1442        aspec = struct[anest]
1443        ri.cw.p(f"if (dst->{aspec.c_name})")
1444        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1445
1446    ri.cw.nl()
1447    ri.cw.block_start(line=iter_line)
1448    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
1449    ri.cw.nl()
1450
1451    first = True
1452    for _, arg in struct.member_list():
1453        good = arg.attr_get(ri, 'dst', first=first)
1454        # First may be 'unused' or 'pad', ignore those
1455        first &= not good
1456
1457    ri.cw.block_end()
1458    ri.cw.nl()
1459
1460    for anest in sorted(array_nests):
1461        aspec = struct[anest]
1462
1463        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1464        ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1465        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1466        ri.cw.p('i = 0;')
1467        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1468        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1469        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1470        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1471        ri.cw.p('return MNL_CB_ERROR;')
1472        ri.cw.p('i++;')
1473        ri.cw.block_end()
1474        ri.cw.block_end()
1475    ri.cw.nl()
1476
1477    for anest in sorted(multi_attrs):
1478        aspec = struct[anest]
1479        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1480        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1481        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1482        ri.cw.p('i = 0;')
1483        if 'nested-attributes' in aspec:
1484            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1485        ri.cw.block_start(line=iter_line)
1486        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1487        if 'nested-attributes' in aspec:
1488            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1489            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1490            ri.cw.p('return MNL_CB_ERROR;')
1491        elif aspec['type'] in scalars:
1492            t = aspec['type']
1493            if t[0] == 's':
1494                t = 'u' + t[1:]
1495            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
1496        else:
1497            raise Exception('Nest parsing type not supported yet')
1498        ri.cw.p('i++;')
1499        ri.cw.block_end()
1500        ri.cw.block_end()
1501        ri.cw.block_end()
1502    ri.cw.nl()
1503
1504    if struct.nested:
1505        ri.cw.p('return 0;')
1506    else:
1507        ri.cw.p('return MNL_CB_OK;')
1508    ri.cw.block_end()
1509    ri.cw.nl()
1510
1511
def parse_rsp_nested(ri, struct):
    """Write the '<struct>_parse()' function for a nested attribute set.

    Each inherited member becomes an extra '__u32' argument of the
    generated function.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    local_vars = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [], local_vars)
1525
1526
def parse_rsp_msg(ri, deref=False):
    """Write the reply parsing callback of the operation.

    Nothing is written when the op mode has no reply, unless it is an event.
    """
    is_event = ri.op_mode == 'event'
    if 'reply' not in ri.op[ri.op_mode] and not is_event:
        return

    reply_type = type_name(ri, "reply", deref=deref)
    local_vars = [
        f'{reply_type} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]
    init_lines = ['dst = yarg->data;']
    func_args = ['const struct nlmsghdr *nlh', 'void *data']

    func_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', func_name, func_args)

    _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1542
1543
def print_req(ri):
    """Write the function that performs a 'do' request.

    The generated function builds the message from the request struct and
    executes it.  When the op has a reply it allocates the response, wires
    up the parsing callback and returns the response (NULL on error, after
    freeing it); otherwise it returns 0 or -1.
    """
    cw = ri.cw
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
        local_vars.append('struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    parse_arg = '&yrs' if has_reply else 'NULL'
    cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    ret_ok = 'rsp' if has_reply else '0'
    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        # Error path: release the partially filled response
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p("return NULL;")

    cw.block_end()
1598    ri.cw.block_end()
1599
1600
def print_dump(ri):
    """Write the function that performs a 'dump' request.

    The generated function prepares the dump state (element size, parsing
    callback, expected command, reply policy), builds the message, adding
    the request attributes if the dump takes any, and returns the first
    element of the resulting list, or NULL after freeing it on error.
    """
    cw = ri.cw
    direction = "request"

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    # Written in this fixed order, not sorted by length
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(f'{decl}')
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        request = ri.struct["request"]
        cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        cw.nl()
        for _, attr in request.member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1642
1643
def call_free(ri, direction, var):
    """Return the C statement calling the '<prefix>_free()' helper on `var`."""
    free_func = f"{op_prefix(ri, direction)}_free"
    return f"{free_func}({var});"
1646
1647
def free_arg_name(direction):
    """Return the argument name used by free functions.

    This is the direction suffix without its leading underscore, or 'obj'
    when no direction is given.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1652
1653
def print_alloc_wrapper(ri, direction):
    """Write the inline '<prefix>_alloc()' helper returning a zeroed struct."""
    cw = ri.cw
    name = op_prefix(ri, direction)
    cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
1660
1661
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of the '<prefix>_free()' function.

    The struct name gets a trailing underscore when the type name conflicts
    with a constant of the family.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    arg_decl = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [arg_decl], suffix=suffix)
1669
1670
def _print_type(ri, direction, struct):
    """Write the C struct definition for `struct`.

    The struct starts with an anonymous '_present' sub-struct holding the
    presence members ('len' and 'bit' kinds) that the attributes report,
    followed by the inherited '__u32' members and then each attribute's own
    member(s).
    """
    cw = ri.cw

    tag = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        tag = tag + '_'
    if ri.op_mode == 'dump':
        tag = tag + '_dump'

    cw.block_start(line=f"struct {ri.family['name']}{tag}")

    # The presence sub-struct is only opened once a member needs it
    in_presence = False
    for _, attr in struct.member_list():
        for presence_kind in ('len', 'bit'):
            member = attr.presence_member(ri.ku_space, presence_kind)
            if not member:
                continue
            if not in_presence:
                cw.block_start(line="struct")
                in_presence = True
            cw.p(member)
    if in_presence:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1702
1703
def print_type(ri, direction):
    """Write the struct definition for the given direction of the operation."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1706
1707
def print_type_full(ri, struct):
    """Write the struct definition of `struct` without any direction suffix."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
1710
1711
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, the member setters."""
    cw = ri.cw

    print_free_prototype(ri, direction)
    cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
1720
1721
def print_req_type_helpers(ri):
    """Write the alloc wrapper followed by the other helpers of the request type."""
    direction = "request"
    print_alloc_wrapper(ri, direction)
    print_type_helpers(ri, direction)
1725
1726
def print_rsp_type_helpers(ri):
    """Write the helpers of the reply type, if the op mode has a reply at all."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1731
1732
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of '<op>_rsp_parse()' or '<op>_req_parse()'.

    The '_rsp' flavour is used for the 'reply' direction, '_req' for any
    other; `terminate` adds the trailing ';'.
    """
    suffix = "_rsp" if direction == "reply" else "_req"
    stem = f"{ri.op.render_name}{suffix}"
    args = ['const struct nlattr **tb', f"struct {stem} *req"]

    ri.cw.write_func_prot('void', f"{stem}_parse", args,
                          suffix=';' if terminate else '')
1741
1742
def print_req_type(ri):
    """Write the struct definition of the request type."""
    direction = "request"
    print_type(ri, direction)
1745
1746
def print_req_free(ri):
    """Write the free function of the request type, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1751
1752
def print_rsp_type(ri):
    """Write the struct definition of the reply type.

    Done for 'do' / 'dump' modes that have a reply and for events; any
    other case writes nothing.
    """
    mode = ri.op_mode
    if mode == 'do' or mode == 'dump':
        if 'reply' not in ri.op[mode]:
            return
    elif mode != 'event':
        return
    print_type(ri, 'reply')
1761
1762
def print_wrapped_type(ri):
    """Write the wrapper struct around a reply object and its free prototype.

    Dump wrappers carry a typed 'next' pointer; notification and event
    wrappers carry family, command, an untyped 'next' and a free callback.
    The wrapped object itself is the 8-byte aligned 'obj' member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1777
1778
def _free_type_members_iter(ri, struct):
    """Declare the loop counter `i` if any member needs it when freed.

    The declaration (followed by a blank line) is written at most once,
    as soon as the first member reporting free_needs_iter() is found.
    """
    if any(member.free_needs_iter() for _name, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1785
1786
def _free_type_members(ri, var, struct, ref=''):
    """Ask every member of `struct` to emit its own free code.

    `var` and `ref` are handed to each member's free() untouched.
    """
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
1790
1791
def _free_type(ri, direction, struct):
    """Emit the complete free function for `struct`.

    The body frees every member and, when `direction` is non-empty,
    finishes with a free() of the argument itself.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    # With an empty direction only the members are released.
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1803
1804
def free_rsp_nested(ri, struct):
    """Emit the free function for a nested struct (empty direction)."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
1807
1808
def print_rsp_free(ri):
    """Emit the reply free function if the current op mode has a reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
1813
1814
def print_dump_type_free(ri):
    """Emit the free function for a dump result list.

    The generated C walks the list until YNL_LIST_END, freeing the
    members of each entry's `obj` and then the entry itself.
    """
    cw = ri.cw
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{list_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    for statement in ('rsp = next;', 'next = rsp->next;'):
        cw.p(statement)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()  # closes the while loop

    cw.block_end()  # closes the function
    cw.nl()
1833
1834
def print_ntf_type_free(ri):
    """Emit the free function for a notification / event wrapper.

    Members are freed through the wrapper's `obj` field, then the
    wrapper (`rsp`) itself is freed.
    """
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1843
1844
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the declaration line of a netlink policy array.

    With `terminate` set an `extern` forward declaration ending in ';'
    is written, unless the policy belongs to an op (`ri` given) of a
    family whose family struct is generated, in which case nothing is
    written.  Without `terminate` the opening line of the definition
    (ending in ' = {') is written, `static` for such op policies.
    """
    if terminate:
        if ri and kernel_can_gen_family_struct(struct.family):
            return
        prefix, suffix = 'extern ', ';'
    else:
        is_private = kernel_can_gen_family_struct(struct.family) and ri
        prefix = 'static ' if is_private else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        # Ops with distinct do/dump policies get the mode in the name.
        if ri.op.dual_policy:
            name = name + '_' + ri.op_mode

    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1867
1868
def print_req_policy(cw, struct, ri=None):
    """Write a complete policy array: opening line, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1874
1875
def kernel_can_gen_family_struct(family):
    """Return True when the family's protocol is exactly 'genetlink'."""
    protocol = family.proto
    return protocol == 'genetlink'
1878
1879
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops table declaration and, for headers, handler prototypes.

    With `terminate` false only the opening line of the table definition
    is written.  With `terminate` true the table is declared `extern`
    (only when the family struct is not generated, i.e. the table is
    exported), followed by prototypes for the pre/post hooks and for
    every non-async op's doit / dumpit handlers.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {
            'global': 'genl_small_ops',
            'per-op': 'genl_ops',
            'split': 'genl_split_ops',
        }[family.kernel_policy]

        if not exported:
            # Non-exported (static) tables are declared without a size.
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have one table entry per do and per dump.
            cnt = sum(('do' in op) + ('dump' in op)
                      for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    do_hook_args = ('const struct genl_split_ops *ops',
                    'struct sk_buff *skb', 'struct genl_info *info')
    dump_hook_args = ('struct netlink_callback *cb',)
    # (hook stage, op mode, C return type, C parameters)
    hook_protos = (
        ('pre', 'do', 'int', do_hook_args),
        ('post', 'do', 'void', do_hook_args),
        ('pre', 'dump', 'int', dump_hook_args),
        ('post', 'dump', 'int', dump_hook_args),
    )
    for stage, mode, ret_type, params in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(params), suffix=';')

    cw.nl()

    handler_args = (
        ('do', ('struct sk_buff *skb', 'struct genl_info *info')),
        ('dump', ('struct sk_buff *skb', 'struct netlink_callback *cb')),
    )
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, params in handler_args:
            if mode in op:
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, list(params), suffix=';')
    cw.nl()
1945
1946
def print_kernel_op_table_hdr(family, cw):
    """Write the header-file flavour (terminated declarations) of the ops table."""
    as_declaration = True
    print_kernel_op_table_fwd(family, cw, terminate=as_declaration)
1949
1950
def print_kernel_op_table(family, cw):
    """Write the definition of the kernel ops table for `family`.

    The opening line comes from print_kernel_op_table_fwd(); this function
    then writes one struct initializer per table entry and closes the
    table.  For the 'global' and 'per-op' policies there is one entry per
    non-async op; for the 'split' policy there is one entry per non-async
    op and per mode ('do' / 'dump') the op defines.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            # (field, value) pairs, in the order they are written out
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is always derived from the 'do'
                # request, so this looks like it requires every op to have
                # one - confirm for dump-only ops.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Field names used for the pre / post hooks of each op mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    members.append(('validate',
                                    ' | '.join([c_upper('genl-dont-validate-' + x)
                                                for x in op['dont-validate']])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # The pre hook goes before the handler and the post hook after it
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policies have the mode in the policy name
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is additionally flagged with its own mode
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    # Close the table opened by print_kernel_op_table_fwd()
    cw.block_end(line=';')
    cw.nl()
2017
2018
def print_kernel_mcgrp_hdr(family, cw):
    """Write an anonymous enum with one `<FAMILY>_NLGRP_<NAME>` entry per group.

    Nothing is written when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2029
2030
def print_kernel_mcgrp_src(family, cw):
    """Write the `<family>_nl_mcgrps[]` array of multicast group definitions.

    Each entry is indexed by the group's `<FAMILY>_NLGRP_<NAME>` id and
    carries the group name as a string.  Nothing is written when the
    family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for group in groups:
        group_name = group['name']
        index = c_upper(f"{family.name}-nlgrp-{group_name}")
        cw.p(f'[{index}] = {{ "{group_name}", }},')
    cw.block_end(';')
    cw.nl()
2042
2043
def print_kernel_family_struct_hdr(family, cw):
    """Write the extern declaration of the genl_family struct, if generated."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2050
2051
def print_kernel_family_struct_src(family, cw):
    """Write the definition of the family's `struct genl_family`.

    Nothing is written unless kernel_can_gen_family_struct() allows it.
    The ops fields depend on the kernel policy ('per-op' or 'split') and
    the mcgrps fields are only present when multicast groups exist.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")

    fields = [
        f'.name\t\t= {family.fam_key},',
        f'.version\t= {family.ver_key},',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    policy = family.kernel_policy
    if policy == 'per-op':
        fields.append(f'.ops\t\t= {fam}_nl_ops,')
        fields.append(f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),')
    elif policy == 'split':
        fields.append(f'.split_ops\t= {fam}_nl_ops,')
        fields.append(f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),')
    if family.mcgrps['list']:
        fields.append(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        fields.append(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')

    for field in fields:
        cw.p(field)
    cw.block_end(';')
2072
2073
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open an enum block, named according to `obj`.

    If `obj` contains the `enum_name` key, its value names the enum (an
    empty value yields an anonymous enum).  Otherwise, if `ckey` is given
    and present in `obj`, the enum is named `<family>_<obj[ckey]>`.  In
    all other cases the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.name + '_' + c_lower(obj[ckey])

    if tag is None:
        cw.block_start(line='enum')
    else:
        cw.block_start(line='enum ' + tag)
2082
2083
def render_uapi(family, cw):
    """Write the complete uAPI header for `family` to `cw`.

    In order, the header contains: the include guard, the family name and
    version defines, the spec's definitions (enums / flags as C enums with
    optional kdoc, consts as defines), one enum per attribute set, the
    command enum(s), the multicast group defines and the closing #endif.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are collected in `defines` and written as one group
    # when a non-const definition (or the end of the list) is reached.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Only spell out the value where the entry reports a change
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                if const['type'] == 'flags':
                    # Flags get a <PREFIX>MASK entry holding the full mask
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    # Enums get __<PREFIX>MAX and <PREFIX>MAX = (__<PREFIX>MAX - 1)
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not on
            # the const entry - confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # When set, the max names are emitted as #defines after the enum instead
    # of as enum entries.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        # Subsets do not get an enum of their own
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # `val` tracks the value the next entry would get implicitly, so an
        # explicit value is only written where the numbering jumps.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With an 'async-prefix' notifications and events go to a separate enum
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        # Second enum holding only the notifications and events
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2227
2228
def _render_user_ntf_entry(ri, op):
    """Write one `[<op enum>] = { ... },` entry of the notification table.

    The entry records the allocation size, the parse callback, the policy
    nest of the reply struct and the free callback.
    """
    cw = ri.cw

    cw.block_start(line=f"[{op.enum_name}] = ")
    event_type = type_name(ri, 'event')
    cw.p(f".alloc_sz\t= sizeof({event_type}),")
    parse_prefix = op_prefix(ri, 'reply', deref=True)
    cw.p(f".cb\t\t= {parse_prefix}_parse,")
    policy_nest = ri.struct['reply'].render_name
    cw.p(f".policy\t\t= &{policy_nest}_nest,")
    free_prefix = op_prefix(ri, 'notify')
    cw.p(f".free\t\t= (void *){free_prefix}_free,")
    cw.block_end(line=',')
2236
2237
def render_user_family(family, cw, prototype):
    """Write the user-space `struct ynl_family` object for `family`.

    With `prototype` set only the extern declaration is written.
    Otherwise the notification info table is written first (when the
    family has notifications), followed by the family definition that
    references it.

    Raises:
        Exception: if a notification is neither a 'notify' nor an 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications are rendered using the op they refer to
                source_op, mode = family.ops[ntf_op['notify']], "notify"
            elif 'event' in ntf_op:
                source_op, mode = ntf_op, "event"
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(RenderInfo(cw, family, "user", source_op, mode),
                                   ntf_op)
        for _op_name, op in family.ops.items():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"),
                                       op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        ntf_table = f"{family['name']}_ntf_info"
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2269
2270
def find_kernel_root(full_path):
    """Locate the kernel tree containing `full_path`.

    Walks up from `full_path` one directory at a time until a directory
    containing a MAINTAINERS file is found.

    Returns:
        A `(root, sub_path)` tuple: the directory holding MAINTAINERS and
        the path of `full_path` relative to that directory.

    Raises:
        FileNotFoundError: if no parent directory contains MAINTAINERS.
    """
    origin = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # The very first join with '' left a trailing separator behind.
            return parent, sub_path[:-1]
        # dirname() is a fixed point at the top ('/' or '' for relative
        # paths); without this check the loop would never terminate.
        if parent == full_path:
            raise FileNotFoundError(
                f"Could not find the kernel root (MAINTAINERS) above '{origin}'")
        full_path = parent
2279
2280
def main():
    """Parse the command line, load the spec and write the generated code.

    `--mode` selects the flavour of output ('uapi', 'kernel' or 'user')
    and `--header` / `--source` whether a header or a source file is
    produced.  Output goes to the file given with `-o`, or to stdout.
    All validation happens before the output file is opened, so a failed
    run does not truncate a previously generated file.  The process exits
    with status 1 on a wrong spec license, a YAML error or an unsupported
    message enum-model.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # args.header stays None only if neither --header nor --source was given
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)

    # The 'directional' model is only supported for user space code
    supported_models = ['unified']
    if args.mode == 'user':
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    _, spec_kernel = find_kernel_root(args.spec)

    # Only open (and truncate) the output once nothing above can fail any more
    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout

    cw = CodeWriter(BaseNlLib(), out_file)

    # Headers and uAPI use C block comments for the SPDX line, sources '//'
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments used for the generation in the output
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Includes
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The source includes the header named after the output file
                own_hdr = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{own_hdr}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # The comment is only written if at least one nested struct is
            # used in a request.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The parser is rendered in "do" mode, the free
                    # function in "event" mode.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2565
2566
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
2569