xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision dc7b81a8)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import yaml
9
10from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
11
12
def c_upper(name):
    """Return *name* upper-cased with every '-' replaced by '_'."""
    shouted = name.upper()
    return shouted.replace('-', '_')
15
16
def c_lower(name):
    """Return *name* lower-cased with every '-' replaced by '_'."""
    lowered = name.lower()
    return lowered.replace('-', '_')
19
20
class BaseNlLib:
    """Produces the C expressions used to talk to the netlink library."""

    def get_family_id(self):
        """Return the C expression that reads the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return the C ``mnl_cb_run2()`` call text for callback *cb* and *data*.

        For dumps the sequence number and port id arguments are rendered
        as 0; otherwise ``ys->seq`` and ``ys->portid`` are passed.  Wrapped
        argument lines start with a newline, two tabs plus *indent* more
        tabs, and one space.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        tail = 'ynl_cb_array, NLMSG_MIN_TYPE)'
        if not is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}{cb}, {data},{wrap}{tail}"
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{wrap}{tail}"
32
33
class Type(SpecAttr):
    """Base class describing how one attribute is rendered into C.

    The public methods (``attr_put``, ``attr_get``, ``setter``, ...) emit
    C text through ``ri.cw`` / ``cw``; the underscore-prefixed hooks are
    what the per-type sub-classes override.  Hooks that the base class
    cannot implement raise ``Exception``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
            self.nested_attrs = nested
            # A nested set named like the family is rendered without a suffix.
            if nested == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(nested)}"

            # The struct name gets a trailing '_' when the nested set's name
            # is also a key of family.consts.
            suffix = '_' if nested in self.family.consts else ''
            self.nested_struct_type = 'struct ' + self.nested_render_name + suffix

        c_name = c_lower(self.name)
        # Names listed in _C_KW get an underscore appended.
        self.c_name = c_name + '_' if c_name in _C_KW else c_name

        # Added by resolve(): the attribute is created and removed again
        # right away, so it does not exist until resolve() assigns it.
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        """Compute ``enum_name`` from the attribute or attribute-set prefix."""
        if 'name-prefix' in self.attr:
            prefix = self.attr['name-prefix']
        else:
            prefix = self.attr_set.name_prefix
        self.enum_name = c_upper(f"{prefix}{self.name}")

    def is_multi_val(self):
        """Return a false value; multi-value sub-classes override this."""
        return None

    def is_scalar(self):
        """Return True when the attribute type is one of the integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked; the base class uses 'bit'."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration of the presence member, or None.

        None is returned when this attribute's presence type differs from
        *type_filter* or is neither 'bit' nor 'len'.  For *space* equal to
        'user' the integer type is spelled ``__u32``, otherwise ``u32``.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member; None in the base class."""
        return None

    def free_needs_iter(self):
        """Return whether freeing needs a loop; False in the base class."""
        return False

    def free(self, ri, var, ref):
        """Emit a ``free()`` call for multi-value or length-tracked members."""
        owns_memory = self.is_multi_val() or self.presence_type() == 'len'
        if not owns_memory:
            return
        ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the list of C argument declarations for this attribute.

        Raises:
            Exception: when no complex member type is available.
        """
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")

        args = [member + ' *' + self.c_name]
        if self.presence_type() == 'count':
            args.append('unsigned int n_' + self.c_name)
        return args

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        multi = self.is_multi_val()
        if multi:
            ri.cw.p(f"unsigned int n_{self.c_name};")

        member = self._complex_member_type(ri)
        if not member:
            # Fall back to the argument declarations, one per line.
            for decl in self.arg_member(ri):
                ri.cw.p(decl + ';')
            return

        star = '*' if multi else ''
        ri.cw.p(f"{member} {star}{self.c_name};")

    def _attr_policy(self, policy):
        """Return the C initializer for a policy entry of type *policy*."""
        return f"{{ .type = {policy}, }}"

    def attr_policy(self, cw):
        """Emit this attribute's policy table entry through *cw*."""
        spec = self._attr_policy(c_upper('nla-' + self.attr['type']))
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type-policy initializer body; must be overridden."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's type-policy table entry through *cw*."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line* as a statement, guarded by the presence check if any."""
        guards = {'bit': self.c_name, 'len': f"{self.c_name}_len"}
        field = guards.get(self.presence_type())
        if field is not None:
            ri.cw.p(f"if ({var}->_present.{field})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` call for the member."""
        call = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def attr_put(self, ri, var):
        """Emit code putting the attribute into a message; must be overridden."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)``; must be overridden."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if`` branch that parses this attribute.

        *first* selects ``if`` over ``else if``.  Always returns True.
        """
        def as_list(value):
            # A lone string stands for a one-element list.
            return [value] if type(value) is str else value

        lines, init_lines, local_vars = self._attr_get(ri, var)
        lines = as_list(lines)
        init_lines = as_list(init_lines)

        cw = ri.cw
        branch = 'if' if first else 'else if'
        cw.block_start(line=f"{branch} (type == {self.enum_name})")
        if local_vars:
            for decl in local_vars:
                cw.p(decl)
            cw.nl()

        if not self.is_multi_val():
            cw.p("if (ynl_attr_validate(yarg, attr))")
            cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            cw.nl()
            for stmt in init_lines:
                cw.p(stmt)

        for stmt in lines:
            cw.p(stmt)
        cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the value; must be overridden."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter function for this attribute.

        *ref* lists the enclosing member names leading to this attribute.
        A setter whose body contains ``free(`` but no ``alloc(`` gets a
        ``__`` prefix on its name.
        """
        path = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(path)}"

        code = []
        presence = ''
        # Mark presence at every nesting level along the path.
        for depth, name in enumerate(path):
            prefix = '.'.join(path[:depth] + [''])
            presence = f"{var}->{prefix}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(path)}"
        frees = any('free(' in stmt for stmt in code)
        allocs = any('alloc(' in stmt for stmt in code)
        if frees and not allocs:
            func_name = '__' + func_name

        args = [f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri)
        ri.cw.write_func('static inline void', func_name, body=code, args=args)
210
211
class TypeUnused(Type):
    """Attribute of type 'unused': parsing it is rendered as an error."""

    def presence_type(self):
        """Return '' since no presence information is kept."""
        return ''

    def arg_member(self, ri):
        """Return an empty argument list."""
        return []

    def _attr_get(self, ri, var):
        """Return a parse body that fails with ``MNL_CB_ERROR``."""
        reject = ['return MNL_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        """Return the ``YNL_PT_REJECT`` type-policy initializer body."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing."""
        return None
227
228
class TypePad(Type):
    """Attribute of type 'pad': every emitter is a no-op."""

    def presence_type(self):
        """Return '' since no presence information is kept."""
        return ''

    def arg_member(self, ri):
        """Return an empty argument list."""
        return []

    def _attr_typol(self):
        """Return the ``YNL_PT_IGNORE`` type-policy initializer body."""
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Emit nothing."""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing and return None."""
        return None

    def attr_policy(self, cw):
        """Emit nothing."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
        return None
250
251
class TypeScalar(Type):
    """Integer attribute, optionally typed by an enum or treated as flags."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Rendered after the member name as a C comment when a byte order is given.
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Added by resolve(): created and removed again right away, so the
        # attributes do not exist until resolve() assigns them.
        self.is_bitfield = None
        del self.is_bitfield
        self.type_name = None
        del self.type_name

    def resolve(self):
        """Compute ``is_bitfield`` and the C ``type_name`` of the member."""
        self.resolve_up(super())

        spec = self.attr
        has_enum = 'enum' in spec
        if 'enum-as-flags' in spec and spec['enum-as-flags']:
            bitfield = True
        elif has_enum:
            bitfield = self.family.consts[spec['enum']]['type'] == 'flags'
        else:
            bitfield = False
        self.is_bitfield = bitfield

        # A non-flags enum that has an enum_name is rendered as a C enum type;
        # everything else uses the '__'-prefixed integer type.
        if has_enum and not bitfield and self.family.consts[spec['enum']].enum_name:
            self.type_name = f"enum {self.family.name}_{c_lower(spec['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the type suffix used in ``mnl_attr_*`` helper names."""
        kind = self.type
        # mnl does not have a helper for signed types
        return 'u' + kind[1:] if kind[0] == 's' else kind

    def _attr_policy(self, policy):
        """Return the policy initializer, honouring masks, 'min' and enums."""
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            # One mask bit per entry of the referenced definition.
            flags = self.family.consts[self.checks['flags-mask']]
            mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """Return the ``YNL_PT_U<bits>`` type-policy initializer body."""
        width = self.type[1:]
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Return the single C declaration for the scalar member."""
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        """Emit the guarded ``mnl_attr_put_<type>()`` call."""
        put_type = self._mnl_type()
        self._attr_put_simple(ri, var, put_type)

    def _attr_get(self, ri, var):
        """Return the statement reading the value with ``mnl_attr_get_<type>()``."""
        stmt = f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);"
        return stmt, None, None

    def _setter_lines(self, ri, member, presence):
        """Return the plain assignment storing the argument into *member*."""
        assignment = f"{member} = {self.c_name};"
        return [assignment]
323
324
class TypeFlag(Type):
    """Flag attribute: carries no payload, only its presence."""

    def arg_member(self, ri):
        """Return an empty argument list."""
        return []

    def _attr_typol(self):
        """Return the ``YNL_PT_FLAG`` type-policy initializer body."""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a guarded zero-length ``mnl_attr_put()`` call."""
        call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return no parse statements."""
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        """Return no extra setter statements."""
        return []
340
341
class TypeString(Type):
    """String attribute, kept as a malloc'ed ``char *`` plus a length."""

    def arg_member(self, ri):
        """Return the ``const char *`` argument declaration."""
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        """Return 'len': presence is tracked through a length member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the ``char *`` struct member."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """Return the ``YNL_PT_NUL_STR`` type-policy initializer body."""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the policy initializer, adding ``.len`` for 'max-len'."""
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the policy entry; 'unterminated-ok' selects ``NLA_STRING``."""
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Emit the guarded ``mnl_attr_put_strz()`` call."""
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        """Return statements copying the string into a fresh NUL-terminated buffer."""
        target = f"{var}->{self.c_name}"
        len_mem = var + '->_present.' + self.c_name + '_len'

        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [
            f"{len_mem} = len;",
            f"{target} = malloc(len + 1);",
            f"memcpy({target}, mnl_attr_get_str(attr), len);",
            f"{target}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return statements replacing the stored string with a copy of the argument."""
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
389
390
class TypeBinary(Type):
    """Binary attribute, kept as a malloc'ed ``void *`` plus a length.

    The length lives in ``_present.<name>_len``; ``attr_put`` sends that
    many bytes and ``_attr_get`` fills it from the received payload.
    """

    def arg_member(self, ri):
        """Return the setter arguments: the data pointer and ``size_t len``."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Return 'len': presence is tracked through a length member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the ``void *`` struct member."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """Return the ``YNL_PT_BINARY`` type-policy initializer body."""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the policy initializer for the binary attribute.

        Only an empty check set or a lone 'min-len' check is supported.

        Raises:
            Exception: when any other combination of checks is present.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        """Emit a guarded ``mnl_attr_put()`` sending ``_present.<name>_len`` bytes."""
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        """Return statements copying the received payload into a fresh buffer."""
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        """Return statements replacing the stored blob with a copy of the argument.

        The ``len`` argument declared by ``arg_member`` is stored into the
        presence length first: the ``malloc``/``memcpy`` below and the
        ``mnl_attr_put()`` emitted by ``attr_put`` all read the size back
        from that field, so without the store they would use a stale length.
        """
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
431
432
class TypeNest(Type):
    """Nested attribute rendered as an embedded struct."""

    def _complex_member_type(self, ri):
        """Return the C struct type of the nested attribute set."""
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """Emit the call to the nested struct's ``_free()`` helper."""
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        """Return the ``YNL_PT_NEST`` type-policy initializer body."""
        nest = f"&{self.nested_render_name}_nest"
        return f'.type = YNL_PT_NEST, .nest = {nest}, '

    def _attr_policy(self, policy):
        """Return the ``NLA_POLICY_NESTED()`` initializer; *policy* is unused."""
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        """Emit the guarded call to the nested struct's ``_put()`` helper."""
        call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return statements that set up ``parg`` and call the nested ``_parse()``."""
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return MNL_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every member of the nested struct.

        This attribute's own name is appended to *ref* so the members'
        setters address fields inside the nested struct.
        """
        path = (ref if ref else []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _name, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
462
463
class TypeMultiAttr(Type):
    """Attribute that may appear several times; stored as an array plus a count."""

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # The single-value type object; policy rendering is delegated to it.
        self.base_type = base_type

    def is_multi_val(self):
        """Return True."""
        return True

    def presence_type(self):
        """Return 'count': presence is tracked through an element count."""
        return 'count'

    def _mnl_type(self):
        """Return the type suffix used in ``mnl_attr_*`` helper names."""
        kind = self.type
        # mnl does not have a helper for signed types
        return 'u' + kind[1:] if kind[0] == 's' else kind

    def _complex_member_type(self, ri):
        """Return the C element type of the array.

        Raises:
            Exception: for element types other than nests and scalars.
        """
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type

        elem_type = self.attr['type']
        if elem_type not in scalars:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + elem_type

    def free_needs_iter(self):
        """Return True when elements are nests and must be freed one by one."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        """Emit code releasing the array (and each nested element first).

        Raises:
            Exception: for element types other than nests and scalars.
        """
        elem_type = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        if elem_type in scalars:
            ri.cw.p(f"free({array});")
            return
        if elem_type != 'nest':
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

        ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
        ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
        ri.cw.p(f"free({array});")

    def _attr_policy(self, policy):
        """Return the policy initializer of the underlying single-value type."""
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        """Return the type-policy initializer of the underlying single-value type."""
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Return the statement incrementing the ``n_<name>`` counter."""
        bump = f'n_{self.c_name}++;'
        return bump, None, None

    def attr_put(self, ri, var):
        """Emit a loop putting every array element into the message.

        Raises:
            Exception: for element types other than nests and scalars.
        """
        elem_type = self.attr['type']
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if elem_type in scalars:
            put_type = self._mnl_type()
            ri.cw.p(loop)
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif elem_type == 'nest':
            ri.cw.p(loop)
            call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i])"
            self._attr_put_line(ri, var, call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """Return statements that free the old array and store the new pointer and count."""
        # For multi-attr we have a count, not presence, hack up the presence
        strip = len('_present.') + len(self.c_name)
        counter = presence[:-strip] + "n_" + self.c_name
        return [
            f"free({member});",
            f"{member} = {self.c_name};",
            f"{counter} = n_{self.c_name};",
        ]
532
533
class TypeArrayNest(Type):
    """Attribute of type 'array-nest'; stored as an array plus a count."""

    def is_multi_val(self):
        """Return True."""
        return True

    def presence_type(self):
        """Return 'count': presence is tracked through an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type, chosen from the 'sub-type' key.

        Raises:
            Exception: for sub-types other than nests and scalars.
        """
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type

        sub_type = self.attr['sub-type']
        if sub_type not in scalars:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + sub_type

    def _attr_typol(self):
        """Return the ``YNL_PT_NEST`` type-policy initializer body."""
        nest = f"&{self.nested_render_name}_nest"
        return f'.type = YNL_PT_NEST, .nest = {nest}, '

    def _attr_get(self, ri, var):
        """Return statements that save the attribute pointer and count its entries."""
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return get_lines, None, local_vars
559
560
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value', parsed through optional 'type-value' levels."""

    def _complex_member_type(self, ri):
        """Return the C struct type of the nested attribute set."""
        return self.nested_struct_type

    def _attr_typol(self):
        """Return the ``YNL_PT_NEST`` type-policy initializer body."""
        nest = f"&{self.nested_render_name}_nest"
        return f'.type = YNL_PT_NEST, .nest = {nest}, '

    def _attr_get(self, ri, var):
        """Return the parse statements for the nested payload.

        For every name in the 'type-value' list one extra attribute level
        is unwrapped; its attribute type is read into a ``__u32`` local and
        passed on to the nested ``_parse()`` call.
        """
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(tv_names)};')
            local_vars.append(f'__u32 {", ".join(tv_names)};')
            # Each level's payload is the attribute of the next level.
            for level in tv_names:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = 'attr_' + level

            extra_args = f", {', '.join(tv_names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
589
590
class Struct:
    """A C struct built from (a subset of) one attribute set.

    Iterating yields member names and indexing by name returns the
    attribute object.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        # No explicit member list marks the struct as nested.
        self.nested = type_list is None

        lowered = c_lower(space_name)
        if family.name == lowered:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lowered}"

        struct_name = 'struct ' + self.render_name
        # Nested structs whose set name is also a key of family.consts get a '_' suffix.
        if self.nested and space_name in family.consts:
            struct_name += '_'
        self.struct_name = struct_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # An empty or missing type_list selects every attribute of the set.
        member_names = type_list if type_list else self.attr_set
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Track the member with the highest value; later members win ties.
        highest = 0
        self.attr_max_val = None
        for _name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the list of ``(name, attribute)`` pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members as sorted C names.

        Raises:
            Exception: when *new_inherited* differs from the value given
                at construction time.
        """
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(member) for member in sorted(self._inherited)]
643
644
class EnumEntry(SpecEnumEntry):
    """One enum entry, with the extra state needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" when it is not previous + 1 (or not 0 for the
        # first entry); entries of a 'flags' enum always count as changed.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve: created and removed again right away, so the
        # attribute does not exist until resolve() assigns it.
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Compute the upper-case C name from the set's value prefix."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
663
664
class EnumSet(SpecEnumSet):
    """An enum definition together with its C naming."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # An explicit but empty 'enum-name' leaves enum_name as None; a
        # missing key falls back to the rendered name.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Return an EnumEntry for *entry* belonging to this set."""
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)`` of the entry values.

        Raises:
            Exception: when the values do not form a contiguous range.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        span = high - low + 1
        if span != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
692
693
class AttrSet(SpecAttrSet):
    """An attribute set together with its C naming and attribute factory."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are a subset of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve: created and removed again right away, so the
        # attribute does not exist until resolve() assigns it.
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Compute ``c_name``; it is '' when it equals the family's C name."""
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Return the Type object for attribute spec *elem*.

        Attributes with a true 'multi-attr' key are wrapped in
        TypeMultiAttr around the single-value type.

        Raises:
            Exception: when the attribute type has no matching class.
        """
        by_name = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'nest': TypeNest,
            'array-nest': TypeArrayNest,
            'nest-type-value': TypeNestTypeValue,
        }

        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        elif kind in by_name:
            cls = by_name[kind]
        else:
            raise Exception(f"No typed class for type {elem['type']}")
        attr = cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)

        return attr
748
749
class Operation(SpecOperation):
    """An operation together with its C naming."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both 'do' and 'dump' carry a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve: created and removed again right away, so the
        # attribute does not exist until resolve() assigns it.
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        """Compute ``enum_name`` using the sync or async operation prefix."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Set ``has_ntf`` to True."""
        self.has_ntf = True
775
776
777class Family(SpecFamily):
778    def __init__(self, file_name, exclude_ops):
779        # Added by resolve:
780        self.c_name = None
781        delattr(self, "c_name")
782        self.op_prefix = None
783        delattr(self, "op_prefix")
784        self.async_op_prefix = None
785        delattr(self, "async_op_prefix")
786        self.mcgrps = None
787        delattr(self, "mcgrps")
788        self.consts = None
789        delattr(self, "consts")
790        self.hooks = None
791        delattr(self, "hooks")
792
793        super().__init__(file_name, exclude_ops=exclude_ops)
794
795        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
796        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
797
798        if 'definitions' not in self.yaml:
799            self.yaml['definitions'] = []
800
801        if 'uapi-header' in self.yaml:
802            self.uapi_header = self.yaml['uapi-header']
803        else:
804            self.uapi_header = f"linux/{self.name}.h"
805
806    def resolve(self):
807        self.resolve_up(super())
808
809        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
810            raise Exception("Codegen only supported for genetlink")
811
812        self.c_name = c_lower(self.name)
813        if 'name-prefix' in self.yaml['operations']:
814            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
815        else:
816            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
817        if 'async-prefix' in self.yaml['operations']:
818            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
819        else:
820            self.async_op_prefix = self.op_prefix
821
822        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
823
824        self.hooks = dict()
825        for when in ['pre', 'post']:
826            self.hooks[when] = dict()
827            for op_mode in ['do', 'dump']:
828                self.hooks[when][op_mode] = dict()
829                self.hooks[when][op_mode]['set'] = set()
830                self.hooks[when][op_mode]['list'] = []
831
832        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
833        self.root_sets = dict()
834        # dict space-name -> set('request', 'reply')
835        self.pure_nested_structs = dict()
836
837        self._mark_notify()
838        self._mock_up_events()
839
840        self._load_root_sets()
841        self._load_nested_sets()
842        self._load_hooks()
843
844        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
845        if self.kernel_policy == 'global':
846            self._load_global_policy()
847
848    def new_enum(self, elem):
849        return EnumSet(self, elem)
850
851    def new_attr_set(self, elem):
852        return AttrSet(self, elem)
853
854    def new_operation(self, elem, req_value, rsp_value):
855        return Operation(self, elem, req_value, rsp_value)
856
857    def _mark_notify(self):
858        for op in self.msgs.values():
859            if 'notify' in op:
860                self.ops[op['notify']].mark_has_ntf()
861
862    # Fake a 'do' equivalent of all events, so that we can render their response parsing
863    def _mock_up_events(self):
864        for op in self.yaml['operations']['list']:
865            if 'event' in op:
866                op['do'] = {
867                    'reply': {
868                        'attributes': op['event']['attributes']
869                    }
870                }
871
872    def _load_root_sets(self):
873        for op_name, op in self.msgs.items():
874            if 'attribute-set' not in op:
875                continue
876
877            req_attrs = set()
878            rsp_attrs = set()
879            for op_mode in ['do', 'dump']:
880                if op_mode in op and 'request' in op[op_mode]:
881                    req_attrs.update(set(op[op_mode]['request']['attributes']))
882                if op_mode in op and 'reply' in op[op_mode]:
883                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
884            if 'event' in op:
885                rsp_attrs.update(set(op['event']['attributes']))
886
887            if op['attribute-set'] not in self.root_sets:
888                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
889            else:
890                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
891                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
892
    def _load_nested_sets(self):
        """Discover attribute sets reachable through nesting and register them.

        Starting from the root sets, follows every 'nested-attributes'
        reference, creates a Struct in self.pure_nested_structs for each nested
        set, records the members each nest inherits, marks whether a nest is
        used in requests and / or replies, and finally tries to reorder
        self.pure_nested_structs so nested sets precede the sets that use them.

        Raises:
            Exception: if an attribute set is used both as a root and as a nest.
        """
        # FIFO work list of sets still to scan; the 'seen' set keeps each
        # attribute set from being queued more than once.
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # Members this particular reference makes the nest inherit
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): this check looks unreachable - a nested set
                    # that is also a root set has already raised just above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    # Array nest entries get an inherited 'idx' member
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Mark nests referenced directly from a root set with the direction(s)
        # in which the referencing attribute is used.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order, so popping and
                        # re-inserting moves this struct to the end
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Retry later, once the sets it nests have been placed
                pns_key_list.append(name)
        # Propagate the request / reply
        # Walk in reverse order and push each struct's flags down to the
        # structs it nests.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
961
962    def _load_global_policy(self):
963        global_set = set()
964        attr_set_name = None
965        for op_name, op in self.ops.items():
966            if not op:
967                continue
968            if 'attribute-set' not in op:
969                continue
970
971            if attr_set_name is None:
972                attr_set_name = op['attribute-set']
973            if attr_set_name != op['attribute-set']:
974                raise Exception('For a global policy all ops must use the same set')
975
976            for op_mode in ['do', 'dump']:
977                if op_mode in op:
978                    global_set.update(op[op_mode].get('request', []))
979
980        self.global_policy = []
981        self.global_policy_set = attr_set_name
982        for attr in self.attr_sets[attr_set_name]:
983            if attr in global_set:
984                self.global_policy.append(attr)
985
986    def _load_hooks(self):
987        for op in self.ops.values():
988            for op_mode in ['do', 'dump']:
989                if op_mode not in op:
990                    continue
991                for when in ['pre', 'post']:
992                    if when not in op[op_mode]:
993                        continue
994                    name = op[op_mode][when]
995                    if name in self.hooks[when][op_mode]['set']:
996                        continue
997                    self.hooks[when][op_mode]['set'].add(name)
998                    self.hooks[when][op_mode]['list'].append(name)
999
1000
class RenderInfo:
    """Context for rendering one operation (or one bare attribute set).

    Bundles the family, code writer, kernel/user space selector, operation
    and operation mode together with values derived from them: whether 'do'
    and 'dump' replies match, the C type name stem, and the request / reply
    Struct objects.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # 'do' and 'dump' response parsing is identical unless exactly one of
        # them has a reply, or both have one and the replies differ
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_has_reply = 'reply' in op['do']
            dump_has_reply = 'reply' in op['dump']
            if do_has_reply != dump_has_reply:
                self.type_consistent = False
            elif do_has_reply and op['do']['reply'] != op['dump']['reply']:
                self.type_consistent = False

        # Fall back to the operation's own attribute set when none was given
        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            # The attribute set shares its name with one of the family consts
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = dict()
        # Notifications use the structs of the 'do' section
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ('request', 'reply'):
            if op and op_dir in op[op_mode]:
                attrs = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          type_list=op['event']['attributes'])
1040
1041
class CodeWriter:
    """Line-oriented writer for generated C code.

    Lines are written to `out_file` indented with tabs. The writer tracks the
    current block depth, delays a closing brace so that a following 'else'
    can share its line, and indents the single line following a brace-less
    if / while / for by one extra level.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # a blank line is pending before the next line
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0               # current indentation depth, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if `line` starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Flushes a pending closing brace (merged as '} else...' when `line`
        starts with 'else') and a pending blank line first. Lines ending in
        ':' are written one level shallower. `add_ind` adds extra indentation
        levels for this line only.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() rather than line[-1] so an empty line does not raise
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write `line` followed by '{' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Decrease the indentation and close the current block.

        With no `line` the closing brace is delayed until the next write.
        Otherwise '}' plus `line` is written right away, separated by a space
        unless `line` starts with ';' or ','. Any pending blank line is
        dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write `doc` as ' * ' comment lines, wrapped before column 79.

        Continuation lines get two extra spaces when `indent` is true.
        """
        line = ' *'
        for word in doc.split():
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a C function prototype / definition header.

        `qual_ret` is the qualified return type, `args` the list of argument
        declarations (defaults to 'void'), `doc` an optional one-line comment
        and `suffix` text appended after the closing parenthesis. Headers of
        80 or more characters are split over several lines.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            # Longer return types go on a line of their own
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned just past the opening parenthesis
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                # Tabs count as one character in len(); compensate
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        `local_vars` may be a single string or a sequence of strings; it is
        not modified. Nothing is written when it is empty.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() keeps the caller's list untouched (list.sort() reordered it)
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a whole function: header, local variables and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME value' lines with values aligned by tabs.

        `defines` is a sequence of (name, value) pairs; int values are written
        as-is, str values in double quotes, other value types are left out.
        """
        longest = 0
        for define in defines:
            longest = max(longest, len(define[0]))
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            name, value = define[0], define[1]
            line = '#define ' + name
            line += '\t' * ((longest - len(name) + 7) // 8)
            if type(value) is int:
                line += str(value)
            elif type(value) is str:
                line += '"' + value + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.member = value,' initializer lines with '=' aligned by tabs.

        `members` is a non-empty sequence of (name, value) string pairs.
        """
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1199
1200
# Spec attribute type names that are treated as scalar values.
scalars = {
    'u8', 'u16', 'u32', 'u64',
    's32', 's64',
}

# Suffix added to generated C names for each message direction.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Suffix added to generated C names for the wrapper type of each op mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords, checked against generated identifiers.
_C_KW = {
    'auto', 'bool', 'break', 'case', 'char', 'const', 'continue',
    'default', 'do', 'double', 'else', 'enum', 'extern', 'float',
    'for', 'goto', 'if', 'inline', 'int', 'long', 'register',
    'return', 'short', 'signed', 'sizeof', 'static', 'struct',
    'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile',
    'while',
}
1252
1253
def rdir(direction):
    """Return the opposite of 'request' / 'reply'; other values pass through."""
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1260
1261
def op_prefix(ri, direction, deref=False):
    """Return the C name stem '<family>_<type>[suffix]' for an operation type.

    The suffix depends on the op mode, the direction, whether 'do' and 'dump'
    replies are consistent, and on `deref` (the element type rather than the
    wrapper type).
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1281
1282
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <name>') for the given direction."""
    stem = op_prefix(ri, direction, deref=deref)
    return f"struct {stem}"
1285
1286
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the function performing the operation.

    The function takes the socket plus, when the op mode has a request, a
    pointer to the request struct. It returns a pointer to the reply type
    when the op mode has a reply and 'int' otherwise. `terminate` appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1303
1304
def print_req_prototype(ri):
    """Write the terminated request prototype, with the op's 'doc' as comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1307
1308
def print_dump_prototype(ri):
    """Write the terminated prototype for the operation, without a doc comment."""
    direction = "request"
    print_prototype(ri, direction)
1311
1312
def put_typol(cw, struct):
    """Write the policy attribute table and the policy nest for `struct`.

    Emits '<name>_policy', filled by each member's attr_typol(), followed by
    '<name>_nest' pointing at that table.
    """
    name = struct.render_name
    type_max = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {type_max},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1328
1329
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Write the '<render_name>_str' function that looks a value up in `map_name`.

    The argument is typed 'enum <render_name>' when `enum` is given and does
    not carry an empty 'enum-name', and 'int' otherwise. For enums of type
    'flags' the value is first converted to a bit index with ffs().
    """
    named_enum = enum and ('enum-name' not in enum or enum['enum-name'])
    arg_type = f'enum {render_name}' if named_enum else 'int'

    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1343
1344
def put_op_name_fwd(family, cw):
    """Write the forward declaration of '<family>_op_str(int op)'."""
    func_name = f'{family.name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1347
1348
def put_op_name(family, cw):
    """Write the op-value-to-name string map and its lookup function.

    Only messages with a truthy rsp_value get an entry. The index is the op's
    enum name when request and response values match, and the raw response
    value otherwise.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1362
1363
def put_enum_to_str_fwd(family, cw, enum):
    """Write the forward declaration of '<enum>_str'.

    The argument is a plain 'int' when the enum has an empty 'enum-name',
    and 'enum <render_name>' otherwise.
    """
    if 'enum-name' in enum and not enum['enum-name']:
        args = ['int value']
    else:
        args = [f'enum {enum.render_name} value']
    cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
1369
1370
def put_enum_to_str(family, cw, enum):
    """Write the value-to-name string map for `enum` and its lookup function."""
    strmap = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
1380
1381
def put_req_nested(ri, struct):
    """Write the '<struct>_put' function that emits `struct` as a nested attribute.

    The generated function opens a nest of type 'attr_type', lets every
    member write its own put code for 'obj', closes the nest and returns 0.
    """
    cw = ri.cw
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    cw.write_func_prot('int', f'{struct.render_name}_put', params)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1402
1403
def _multi_parse(ri, struct, init_lines, local_vars):
    """Write the body of a parse function for `struct`.

    The generated code walks the attributes once, letting each member emit
    its own parsing code, then allocates and fills array-nest and multi-attr
    members in separate passes. The function header must already have been
    written by the caller.

    `init_lines` are statements emitted right after the local variables and
    `local_vars` the declarations to start from; both lists are extended in
    place.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per member that may hold several entries
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Allocate by the element counter n_<name>, as the multi-attr branch
        # below does; the bare member name is not a variable in the output.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # Signed types are read with the unsigned getter of the same width
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1510
1511
def parse_rsp_nested(ri, struct):
    """Write the '<struct>_parse' function for a nested attribute set.

    Every inherited member becomes an extra '__u32' parameter of the
    generated function.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend('__u32 ' + inherited for inherited in struct.inherited)

    local_vars = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [], local_vars)
1525
1526
def parse_rsp_msg(ri, deref=False):
    """Write the reply parse callback for the operation.

    Nothing is written when the op mode has no reply, unless the op mode is
    'event'.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    params = ['const struct nlmsghdr *nlh', 'void *data']
    local_vars = [
        f'{type_name(ri, "reply", deref=deref)} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]
    init_lines = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', params)

    _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1542
1543
def print_req(ri):
    """Write the function that sends a request and, if any, parses its reply.

    With a reply the generated function allocates the response, returns it on
    success and frees it and returns NULL on error; without one it returns 0
    or -1.
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
        local_vars.append('struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Use the enum name when the op has a single value, else the raw
        # response value
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
        cw.p("err = ynl_exec(ys, nlh, &yrs);")
    else:
        cw.p("err = ynl_exec(ys, nlh, NULL);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p('return rsp;' if has_reply else 'return 0;')
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p("return NULL;")

    cw.block_end()
1599
1600
def print_dump(ri):
    """Write the function that performs a dump and returns the reply list.

    The generated function sets up the dump state, writes the request
    attributes when the dump has a request, and on error frees the partial
    list and returns NULL.
    """
    direction = "request"
    cw = ri.cw
    op = ri.op

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(decl)
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Use the enum name when the op has a single value, else the raw
    # response value
    rsp_cmd = op.enum_name if op.value is not None else op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {op.enum_name}, 1);")

    if "request" in op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1642
1643
def call_free(ri, direction, var):
    """Return the C statement calling the matching '_free' function on `var`."""
    free_func = op_prefix(ri, direction) + '_free'
    return f"{free_func}({var});"
1646
1647
def free_arg_name(direction):
    """Return the argument name used by free functions for `direction`.

    This is the direction's suffix without its leading underscore, or 'obj'
    when no direction is given.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1652
1653
def print_alloc_wrapper(ri, direction):
    """Write the static inline '<name>_alloc' helper that calloc()s one struct."""
    name = op_prefix(ri, direction)
    cw = ri.cw

    cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ['void'])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
1660
1661
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of the '<name>_free' function.

    The struct name gets a trailing underscore when the type name conflicts
    with another name in the family.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    arg = free_arg_name(direction)

    ri.cw.write_func_prot('void', f"{name}_free",
                          [f"struct {struct_name} *{arg}"], suffix=suffix)
1669
1670
def _print_type(ri, direction, struct):
    """Write the C struct definition for `struct`.

    The struct starts with an anonymous '_present' sub-struct holding the
    presence members (if any), followed by inherited '__u32' members and
    then each member's own fields.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    # Gather the presence lines first so the wrapper is only opened if needed
    presence = []
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                presence.append(line)
    if presence:
        ri.cw.block_start(line="struct")
        for line in presence:
            ri.cw.p(line)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1702
1703
def print_type(ri, direction):
    """Write the C struct for the operation's request or reply."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1706
1707
def print_type_full(ri, struct):
    """Write the C struct for `struct` without any direction suffix."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
1710
1711
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, the setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1720
1721
def print_req_type_helpers(ri):
    """Write the alloc wrapper followed by the other request type helpers."""
    direction = "request"
    print_alloc_wrapper(ri, direction)
    print_type_helpers(ri, direction)
1725
1726
def print_rsp_type_helpers(ri):
    """Write the reply type helpers, if the op mode has a reply at all."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1731
1732
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of the op's '<name>_rsp_parse' / '<name>_req_parse'.

    'reply' selects the '_rsp' variant, any other direction the '_req' one.
    `terminate` appends ';'.
    """
    if direction == "reply":
        suffix = "_rsp"
    else:
        suffix = "_req"
    stem = f"{ri.op.render_name}{suffix}"
    params = ['const struct nlattr **tb', f"struct {stem} *req"]

    ri.cw.write_func_prot('void', f"{stem}_parse", params,
                          suffix=';' if terminate else '')
1741
1742
def print_req_type(ri):
    """Write the C struct for the operation's request."""
    direction = "request"
    print_type(ri, direction)
1745
1746
def print_req_free(ri):
    """Write the request free function, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1751
1752
def print_rsp_type(ri):
    """Write the C struct for the reply.

    Applies to events, and to 'do' / 'dump' modes that have a reply; nothing
    is written otherwise.
    """
    if ri.op_mode == 'event':
        print_type(ri, 'reply')
    elif ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]:
        print_type(ri, 'reply')
1761
1762
def print_wrapped_type(ri):
    """Write the wrapper struct around the reply object and its free prototype.

    Dump wrappers carry a 'next' pointer; notification and event wrappers
    carry the family, command, next pointer and a free callback. All of them
    embed the reply as an 8-byte aligned 'obj' member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1777
1778
1779def _free_type_members_iter(ri, struct):
1780    for _, attr in struct.member_list():
1781        if attr.free_needs_iter():
1782            ri.cw.p('unsigned int i;')
1783            ri.cw.nl()
1784            break
1785
1786
1787def _free_type_members(ri, var, struct, ref=''):
1788    for _, attr in struct.member_list():
1789        attr.free(ri, var, ref)
1790
1791
def _free_type(ri, direction, struct):
    """Emit the complete free function for @struct.

    Member release code comes first; the object itself is released with
    free() only when @direction is a non-empty string.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1803
1804
def free_rsp_nested(ri, struct):
    """Emit the free function for nested @struct, using an empty direction."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
1807
1808
def print_rsp_free(ri):
    """Emit the reply free function when the op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1813
1814
def print_dump_type_free(ri):
    """Emit the free function for a dump result list.

    The generated C walks the list until YNL_LIST_END, releasing the
    members of each element (reached through 'obj.') and then the element.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
1833
1834
def print_ntf_type_free(ri):
    """Emit the free function for a notification wrapper.

    Members of the embedded object are released through 'obj.', then the
    wrapper itself is freed.
    """
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1842    ri.cw.nl()
1843
1844
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the first line of a netlink policy array.

    With @terminate set this is an 'extern' declaration ending in ';',
    otherwise it is the opening of the definition (' = {').  Per-op
    policies (@ri given) of families for which the family struct can be
    generated are not declared at all and are defined 'static'.
    """
    if terminate:
        if ri and kernel_can_gen_family_struct(struct.family):
            return
        prefix = 'extern '
        suffix = ';'
    else:
        is_private = kernel_can_gen_family_struct(struct.family) and ri
        prefix = 'static ' if is_private else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val

    # Op policies are named after the op, nested ones after the struct
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode

    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1867
1868
def print_req_policy(cw, struct, ri=None):
    """Write a full policy array: opening line, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1874
1875
def kernel_can_gen_family_struct(family):
    """Return True when the family protocol is exactly 'genetlink'."""
    is_plain_genetlink = family.proto == 'genetlink'
    return is_plain_genetlink
1878
1879
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration or the opening of its definition.

    With @terminate set (header) the table is declared 'extern', but only
    for families whose family struct cannot be generated; the prototypes
    of all pre/post hooks and doit/dumpit handlers follow.  Without
    @terminate (source) only the opening line of the table definition is
    written and the caller is expected to fill in and close the block.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            # Private tables let the compiler size the array
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops carry one entry per do and one per dump
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        # Go through c_lower() like the handler names below do, a dash in
        # the family name must not end up in a C identifier
        line = f"{qual} struct {struct_type} {c_lower(family.name)}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1945
1946
def print_kernel_op_table_hdr(family, cw):
    """Emit the header flavour (terminated declarations) of the ops table."""
    as_declaration = True
    print_kernel_op_table_fwd(family, cw, terminate=as_declaration)
1949
1950
def print_kernel_op_table(family, cw):
    """Emit the body of the kernel ops table for @family.

    The opening line comes from print_kernel_op_table_fwd().  For 'global'
    and 'per-op' policies one entry is written per non-async op; for
    'split' policy one entry is written per (op, mode) pair.  The order in
    which members are appended is the order they appear in the output.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): reads op['do']['request'] unconditionally,
                # looks like per-op families are expected to always have a
                # do request - confirm
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the struct members holding the pre / post hooks
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop the dump-only values from do entries and
                    # 'strict' from dump entries
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    members.append(('validate',
                                    ' | '.join([c_upper('genl-dont-validate-' + x)
                                                for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policy have one policy per mode
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is flagged with the mode it serves
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2025
2026
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group ids; nothing if there are no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        # The trailing comma of the enumerator goes through c_upper() too
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2037
2038
def print_kernel_mcgrp_src(family, cw):
    """Emit the multicast group array; nothing if the family has no groups.

    Entries are indexed by the group enum ids and carry the group name as
    spelled in the spec.
    """
    if not family.mcgrps['list']:
        return

    # The array name is a C identifier, run the family name through
    # c_lower() so that dashes get replaced
    cw.block_start('static const struct genl_multicast_group ' +
                   c_lower(family.name) + '_nl_mcgrps[] =')
    for grp in family.mcgrps['list']:
        name = grp['name']
        grp_id = c_upper(f"{family.name}-nlgrp-{name}")
        cw.p('[' + grp_id + '] = { "' + name + '", },')
    cw.block_end(';')
    cw.nl()
2050
2051
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family struct, if one can be generated."""
    if not kernel_can_gen_family_struct(family):
        return

    # c_lower() keeps dashes of the family name out of the C identifier
    cw.p(f"extern struct genl_family {c_lower(family.name)}_nl_family;")
    cw.nl()
2058
2059
def print_kernel_family_struct_src(family, cw):
    """Define the genl_family struct, if one can be generated for @family.

    Ops are hooked up for the 'per-op' and 'split' kernel policies only;
    multicast groups are referenced when the family has any.
    """
    if not kernel_can_gen_family_struct(family):
        return

    # All generated symbols are C identifiers, dashes of the family name
    # have to be replaced
    sym = c_lower(family.name)

    cw.block_start(f"struct genl_family {sym}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {sym}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({sym}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {sym}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({sym}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {sym}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({sym}_nl_mcgrps),')
    cw.block_end(';')
2080
2081
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open an enum block, choosing its tag from @obj.

    If @obj has the @enum_name key its value is the tag; an empty value
    makes the enum anonymous.  Otherwise, if @ckey is given and present
    in @obj, the tag is <family>_<obj[ckey]>.  Failing both the enum is
    anonymous.
    """
    start_line = 'enum'
    if enum_name in obj:
        if obj[enum_name]:
            start_line = 'enum ' + c_lower(obj[enum_name])
    elif ckey and ckey in obj:
        # The family name is part of the tag, it needs c_lower() as well
        # or a dash in it would produce an invalid identifier
        start_line = 'enum ' + c_lower(family.name) + '_' + c_lower(obj[ckey])
    cw.block_start(line=start_line)
2090
2091
def render_uapi(family, cw):
    """Render the whole uAPI header of @family.

    Output order: include guard, family name / version defines,
    definitions (consts as defines, enums and flags as enums), one enum
    per attribute set which is not a subset, the command enum(s) and the
    multicast group defines.
    """
    # c_upper() rather than str.upper() so that a dash in the family name
    # does not end up in the include guard macro
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # Flush the run of consecutive consts collected so far
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # val tracks the value C would assign implicitly, explicit
        # initializers are only written where the spec deviates from it
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        # Notifications and events live in an enum of their own
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2236
2237
def _render_user_ntf_entry(ri, op):
    """Emit one designated initializer of the notification info table.

    The entry is indexed by the enum name of @op and lists the allocation
    size, parse callback, policy nest and free callback.
    """
    cw = ri.cw

    cw.block_start(line=f"[{op.enum_name}] = ")
    members = [
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    ]
    for member in members:
        cw.p(member)
    cw.block_end(line=',')
2245
2246
def render_user_family(family, cw, prototype):
    """Emit the user space ynl_family object of @family.

    With @prototype set only the 'extern' declaration is written.
    Otherwise the notification info table is rendered first (when the
    family has notifications), followed by the family definition.

    Raises an Exception for a notification which is neither 'notify' nor
    'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # Use the C-safe name, same as for the family symbol above; the raw
    # spec name may contain characters not valid in a C identifier
    ntf_info = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications reuse the types of the op they mirror
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2278
2279
def find_kernel_root(full_path):
    """Locate the kernel tree containing @full_path.

    Walks up from @full_path until a directory holding a MAINTAINERS file
    is found.

    Returns a (root, sub_path) tuple where root is that directory and
    sub_path is @full_path relative to it.

    Raises FileNotFoundError if the top of the path is reached without
    finding a MAINTAINERS file.
    """
    spec_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        # dirname() of '/' (or of '') returns its argument, once that
        # happens there is nothing left to walk up to
        at_top = parent == full_path
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path was built with a trailing separator, drop it
            return full_path, sub_path[:-1]
        if at_top:
            raise FileNotFoundError(
                f"No MAINTAINERS file found above '{spec_path}', "
                "is the spec inside a kernel tree?")
2288
2289
def main():
    """Parse the command line and render the requested C code.

    All validation (arguments, spec parsing, license, message id model,
    kernel tree lookup) is done before the output file is opened, so that
    a failing run does not truncate a previously generated file.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)

    supported_models = ['unified']
    if args.mode == 'user':
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    _, spec_kernel = find_kernel_root(args.spec)

    # Only touch the output once nothing above can fail any more
    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout

    cw = CodeWriter(BaseNlLib(), out_file)

    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # c_upper() so that a dash in the family name can't break the macro
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The source includes its own header, named like the
                # output file minus the last two characters
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2574
2575
# Run the generator only when the file is executed as a script
if __name__ == "__main__":
    main()
2578