xref: /openbmc/linux/tools/net/ynl/ynl-gen-c.py (revision a9ca9f9c)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import yaml
9
10from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
11
12
def c_upper(name):
    """Return *name* as an upper-case C identifier ('-' becomes '_')."""
    underscored = name.replace('-', '_')
    return underscored.upper()
15
16
def c_lower(name):
    """Return *name* as a lower-case C identifier ('-' becomes '_')."""
    underscored = name.replace('-', '_')
    return underscored.lower()
19
20
class BaseNlLib:
    """C expressions for calling into the netlink helper library (libmnl/ynl).

    Generators ask this object for the exact C text to emit, so the
    library-specific spelling is kept in one place.
    """

    def get_family_id(self):
        """Return the C expression holding the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C ``mnl_cb_run2()`` call expression.

        Args:
            cb: name of the C callback invoked per message.
            data: C expression passed to the callback as its data pointer.
            is_dump: pass 0/0 as sequence number and port id (which libmnl
                treats as "don't check") instead of ys->seq / ys->portid.
            indent: extra tab depth applied to the continuation lines.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        if is_dump:
            seq_and_port = '0, 0, '
        else:
            seq_and_port = f"ys->seq, ys->portid,{wrap}"
        return (f"mnl_cb_run2(ys->rx_buf, len, {seq_and_port}"
                f"{cb}, {data},{wrap}ynl_cb_array, NLMSG_MIN_TYPE)")
32
33
class Type(SpecAttr):
    """Base C code generator for a single netlink attribute.

    Subclasses implement the per-type hooks (``_attr_get``, ``_setter_lines``,
    ``_attr_typol``, ``_attr_policy``, ``_complex_member_type``, ...); the
    public methods below combine what those hooks return and write it out
    through the code writer at ``ri.cw``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # NOTE(review): the trailing '_' when the nested set's name is also
            # a key of family.consts presumably avoids clashing with a C type
            # of the same name generated from the definitions -- confirm.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Names that collide with C keywords (per _C_KW) get a trailing '_'.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(); declared then deleted so that reading it before
        # resolve() raises AttributeError instead of silently yielding None.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Set ``enum_name``, the C enum constant naming this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Truthy when the C member stores an array of values."""
        return None

    def is_scalar(self):
        """True for the fixed-width integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked: 'bit', 'len', 'count' or '' (untracked)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration for this attribute.

        Returns None unless presence_type() equals *type_filter* and is
        'bit' or 'len'.  *space* == 'user' selects ``__u32`` over ``u32``.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """C type of the struct member for compound types; None otherwise."""
        return None

    def free_needs_iter(self):
        """True when the generated free code needs a loop counter ``i``."""
        return False

    def free(self, ri, var, ref):
        """Emit C that frees heap storage owned by this member of *var*."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attribute.

        Raises:
            Exception: if the type has no compound member type.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel netlink policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the YNL (user space) policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by this attribute's presence field if any."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute and return True.

        *first* selects ``if`` vs ``else if`` so consecutive attributes form
        a single if / else-if chain.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may return a single line as a bare string.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` C setter for this attribute.

        Args:
            ri: render info carrying the code writer (``ri.cw``).
            space: unused here; kept for interface compatibility.
            direction: request/reply direction used to name the function.
            deref: forwarded to the naming helpers.
            ref: member names leading from the request struct to this
                attribute's parent nest (None / empty for top level).

        The presence bit of every enclosing nest along *ref* is set, as is
        this attribute's own bit when it uses bit presence.  The setter name
        is prefixed with '__' when the body frees memory without allocating.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"
        last = len(ref) - 1

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every level above the last is a nest, and nests use bit
            # presence; only the leaf (this attribute) may track presence
            # differently (len / count), which _setter_lines handles.
            if i == last and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        frees = any('free(' in line for line in code)
        allocs = any('alloc(' in line for line in code)
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
210
211
class TypeUnused(Type):
    """Placeholder for an 'unused' attribute slot.

    No presence tracking and no struct/argument members; receiving it makes
    the generated parser return MNL_CB_ERROR, and the user-space policy
    marks it YNL_PT_REJECT.  Nothing is emitted into the kernel policy.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        return None
227
228
class TypePad(Type):
    """Padding attribute: ignored when parsing, never put or set.

    It has no presence tracking, no members, no kernel policy entry and no
    setter; the user-space policy marks it YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The emitters below intentionally produce no output for padding.

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
250
251
class TypeScalar(Type):
    """Fixed-width integer attribute (u8/u16/u32/u64/s32/s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Byte order, if given, is only recorded as a C comment on the member.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve(); deleted so use before resolve() fails loudly.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Compute ``is_bitfield`` and the C ``type_name`` of the member."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        maybe_enum = not self.is_bitfield and 'enum' in self.attr
        if maybe_enum and self.family.consts[self.attr['enum']].enum_name:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, applying the spec checks.

        Flag-valued attributes (a flags enum / ``enum-as-flags``, or a
        ``flags-mask`` check) get NLA_POLICY_MASK computed from the flag
        set's real values via ``get_mask(as_flags=True)``, so sets with
        explicit bit positions or gaps yield the correct mask rather than
        assuming bits 0..n-1.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                flags = self.family.consts[self.attr['enum']]
            else:
                flags = self.family.consts[self.checks['flags-mask']]
            mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned user-space policy type.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
323
324
class TypeFlag(Type):
    """Zero-length attribute: its presence bit is the whole value.

    There is no struct/argument member; putting it writes an empty payload,
    parsing only records presence, and the setter only sets the bit.
    """

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
340
341
class TypeString(Type):
    """NUL-terminated string attribute stored as a heap copy.

    Presence is tracked as a length in ``_present.<name>_len``.
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        fields = [f'.type = {policy}']
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy to NLA_STRING.
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        length = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{length} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, mnl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        # strnlen bounded by the payload so a missing terminator is safe.
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
389
390
class TypeBinary(Type):
    """Opaque byte-array attribute stored as a heap copy.

    The length is tracked in ``_present.<name>_len``; the C setter takes the
    data pointer plus a ``size_t len`` argument.
    """

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE: unlike sibling types there is no trailing space; kept as-is so
        # regenerating does not churn already-generated files.
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the kernel policy: NLA_BINARY, or a min length only.

        Raises:
            Exception: for any check other than a lone ``min-len``.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Store the caller-supplied length first: both the allocation size and
        # the copy size are read from _present.<name>_len, so without this the
        # setter's 'len' argument would be ignored.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
431
432
class TypeNest(Type):
    """Attribute carrying one nested attribute set, rendered as a sub-struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return MNL_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own: emit one per member of the nested
        # set, with this nest appended to the member path.
        path = (ref if ref else []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
462
463
class TypeMultiAttr(Type):
    """Attribute that may repeat; occurrences are collected into an array.

    *base_type* is the type object for a single occurrence; the kernel and
    user-space policy entries are delegated to it.  The array length is
    kept in ``n_<name>``.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _mnl_type(self):
        # mnl does not have a helper for signed types; use the unsigned one.
        t = self.type
        return 'u' + t[1:] if t[0] == 's' else t

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            ri.cw.p(f"free({array});")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            # Each nested element owns memory of its own; free those first.
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
            ri.cw.p(f"free({array});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # First pass only counts occurrences into a local n_<name>.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(loop)
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            put_call = (f"{self.nested_render_name}_put(nlh, "
                        f"{self.enum_name}, &{var}->{self.c_name}[i])")
            self._attr_put_line(ri, var, put_call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attr keeps a count rather than a presence bit: replace the
        # trailing "_present.<name>" of the presence path with "n_<name>".
        stem = presence[:-(len('_present.') + len(self.c_name))]
        count = stem + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
532
533
class TypeArrayNest(Type):
    """``array-nest`` attribute: a nest whose children are array entries.

    The generated parse code only counts the entries and remembers the nest
    in a local ``attr_<name>`` (presumably consumed by a later pass emitted
    elsewhere -- confirm).
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        local_vars = ['const struct nlattr *attr2;']
        return get_lines, None, local_vars
559
560
class TypeNestTypeValue(Type):
    """Nest whose enclosing attribute types carry values (``type-value``).

    For each level listed in ``type-value`` the generated code descends one
    attribute and records that attribute's type as a ``__u32`` local; the
    collected values are passed as extra arguments to the nested parser.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        parent = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({parent});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                parent = 'attr_' + level
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {parent}{extra_args});")
        return get_lines, init_lines, local_vars
589
590
class Struct:
    """A set of attributes rendered as one C struct.

    Args:
        family: the spec family.
        space_name: name of the attribute set the struct is built from.
        type_list: attribute names to include, in order.  ``None`` means a
            nested struct holding every attribute of the set; an explicit
            (possibly empty) list selects exactly those attributes.
        inherited: attribute names inherited from a parent; see
            set_inherited().
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = inherited if inherited is not None else []
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{c_lower(space_name)}"
        self.struct_name = 'struct ' + self.render_name
        # NOTE(review): nested structs whose set name is also a definition in
        # family.consts get a '_' suffix, presumably to avoid a C name clash.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # Compare against None rather than truthiness (matching self.nested):
        # an explicitly empty type_list must yield a struct with no members,
        # not every attribute of the set.
        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict()

        # Track the member with the highest attribute value (last one wins
        # on ties); stays None for an empty struct.
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ordered list of (name, attr) pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members as sorted C names.

        Raises:
            Exception: if *new_inherited* differs from the set given at
                construction time.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
643
644
class EnumEntry(SpecEnumEntry):
    """Enum entry with C rendering details."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change: the value does not follow implicitly from the previous
        # entry (or from 0 for the first entry); always set for flag sets.
        implied = prev.value + 1 if prev else 0
        self.value_change = self.value != implied or self.enum_set['type'] == 'flags'

        # Added by resolve(); deleted so use before resolve() fails loudly.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set ``c_name``, the upper-case C constant for this entry."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
663
664
class EnumSet(SpecEnumSet):
    """Enum or flags definition with C naming."""

    def __init__(self, family, yaml):
        # NOTE(review): these are set before super().__init__(), presumably so
        # they already exist while the base class builds entries -- keep order.
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # 'enum-name' present but empty means the enum is left anonymous.
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)``, the inclusive bounds of the enum's values.

        Contiguity is judged on the *distinct* values, so duplicate values
        neither hide a gap nor cause a spurious failure.

        Raises:
            ValueError: if the enum has no entries.
            Exception: if the distinct values do not form a contiguous range.
        """
        values = {x.value for x in self.entries.values()}
        if not values:
            raise ValueError("Can't get value range for an empty enum")

        low = min(values)
        high = max(values)

        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
692
693
class AttrSet(SpecAttrSet):
    """Attribute set with C naming: enum prefix, max-attr name and c_name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the enum prefix and max name of the parent set.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve(); deleted so use before resolve() fails loudly.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set ``c_name``; empty when the set is named after the family."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the Type subclass matching ``elem['type']``.

        Multi-attr elements are wrapped in TypeMultiAttr around the
        single-occurrence type.

        Raises:
            Exception: for an unknown attribute type.
        """
        kind = elem['type']
        if kind in scalars:
            type_cls = TypeScalar
        else:
            type_cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if type_cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = type_cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
748
749
class Operation(SpecOperation):
    """Operation (message) with C naming."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True when both the 'do' and the 'dump' variants define a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Set by mark_has_ntf() when another operation notifies via this one.
        self.has_ntf = False

        # Added by resolve(); deleted so use before resolve() fails loudly.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Set ``enum_name`` using the sync or async command prefix."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
775
776
777class Family(SpecFamily):
778    def __init__(self, file_name, exclude_ops):
779        # Added by resolve:
780        self.c_name = None
781        delattr(self, "c_name")
782        self.op_prefix = None
783        delattr(self, "op_prefix")
784        self.async_op_prefix = None
785        delattr(self, "async_op_prefix")
786        self.mcgrps = None
787        delattr(self, "mcgrps")
788        self.consts = None
789        delattr(self, "consts")
790        self.hooks = None
791        delattr(self, "hooks")
792
793        super().__init__(file_name, exclude_ops=exclude_ops)
794
795        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
796        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
797
798        if 'definitions' not in self.yaml:
799            self.yaml['definitions'] = []
800
801        if 'uapi-header' in self.yaml:
802            self.uapi_header = self.yaml['uapi-header']
803        else:
804            self.uapi_header = f"linux/{self.name}.h"
805
806    def resolve(self):
807        self.resolve_up(super())
808
809        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
810            raise Exception("Codegen only supported for genetlink")
811
812        self.c_name = c_lower(self.name)
813        if 'name-prefix' in self.yaml['operations']:
814            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
815        else:
816            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
817        if 'async-prefix' in self.yaml['operations']:
818            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
819        else:
820            self.async_op_prefix = self.op_prefix
821
822        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
823
824        self.hooks = dict()
825        for when in ['pre', 'post']:
826            self.hooks[when] = dict()
827            for op_mode in ['do', 'dump']:
828                self.hooks[when][op_mode] = dict()
829                self.hooks[when][op_mode]['set'] = set()
830                self.hooks[when][op_mode]['list'] = []
831
832        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
833        self.root_sets = dict()
834        # dict space-name -> set('request', 'reply')
835        self.pure_nested_structs = dict()
836
837        self._mark_notify()
838        self._mock_up_events()
839
840        self._load_root_sets()
841        self._load_nested_sets()
842        self._load_hooks()
843
844        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
845        if self.kernel_policy == 'global':
846            self._load_global_policy()
847
848    def new_enum(self, elem):
849        return EnumSet(self, elem)
850
851    def new_attr_set(self, elem):
852        return AttrSet(self, elem)
853
854    def new_operation(self, elem, req_value, rsp_value):
855        return Operation(self, elem, req_value, rsp_value)
856
857    def _mark_notify(self):
858        for op in self.msgs.values():
859            if 'notify' in op:
860                self.ops[op['notify']].mark_has_ntf()
861
862    # Fake a 'do' equivalent of all events, so that we can render their response parsing
863    def _mock_up_events(self):
864        for op in self.yaml['operations']['list']:
865            if 'event' in op:
866                op['do'] = {
867                    'reply': {
868                        'attributes': op['event']['attributes']
869                    }
870                }
871
872    def _load_root_sets(self):
873        for op_name, op in self.msgs.items():
874            if 'attribute-set' not in op:
875                continue
876
877            req_attrs = set()
878            rsp_attrs = set()
879            for op_mode in ['do', 'dump']:
880                if op_mode in op and 'request' in op[op_mode]:
881                    req_attrs.update(set(op[op_mode]['request']['attributes']))
882                if op_mode in op and 'reply' in op[op_mode]:
883                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
884            if 'event' in op:
885                rsp_attrs.update(set(op['event']['attributes']))
886
887            if op['attribute-set'] not in self.root_sets:
888                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
889            else:
890                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
891                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
892
    def _load_nested_sets(self):
        """Discover attribute sets nested (directly or transitively) under root sets.

        Fills self.pure_nested_structs with a Struct for every nested set
        that is not itself a root set. It then:
          * records per-struct "inherited" members ('type-value' names, or
            'idx' for array-nest);
          * marks each struct as used in requests and/or replies;
          * reorders the dict so that a struct's nested dependencies are
            inserted before it.

        Raises:
            Exception: if an attribute set is used both as a root set and
                as a nested set.
        """
        # Breadth-first walk starting from every root attribute set
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    # First reference creates the Struct; later references reuse it
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # NOTE: the root-set check below cannot fire, because the branch
                # above has already raised for any nested set that is a root set.
                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    # Array elements are parsed with their index passed in as 'idx'
                    inherit.add('idx')
                # NOTE(review): called on every reference to the set; whether a
                # later reference replaces or merges the inherited members
                # depends on Struct.set_inherited (not visible here).
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Sets nested directly under a root set inherit that root attribute's
        # direction(s).
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        # A struct is "finished" once every set it nests has been finished.
        # Unfinished ones are moved to the end of the dict and retried. The
        # number of rounds is bounded, so unresolvable dependencies just stop
        # the loop without raising.
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order; re-inserting makes this struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # Iterate dependents before their dependencies (reverse of the order
        # above), so each parent's flags reach its children before those
        # children pass them on.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
961
962    def _load_global_policy(self):
963        global_set = set()
964        attr_set_name = None
965        for op_name, op in self.ops.items():
966            if not op:
967                continue
968            if 'attribute-set' not in op:
969                continue
970
971            if attr_set_name is None:
972                attr_set_name = op['attribute-set']
973            if attr_set_name != op['attribute-set']:
974                raise Exception('For a global policy all ops must use the same set')
975
976            for op_mode in ['do', 'dump']:
977                if op_mode in op:
978                    global_set.update(op[op_mode].get('request', []))
979
980        self.global_policy = []
981        self.global_policy_set = attr_set_name
982        for attr in self.attr_sets[attr_set_name]:
983            if attr in global_set:
984                self.global_policy.append(attr)
985
986    def _load_hooks(self):
987        for op in self.ops.values():
988            for op_mode in ['do', 'dump']:
989                if op_mode not in op:
990                    continue
991                for when in ['pre', 'post']:
992                    if when not in op[op_mode]:
993                        continue
994                    name = op[op_mode][when]
995                    if name in self.hooks[when][op_mode]['set']:
996                        continue
997                    self.hooks[when][op_mode]['set'].add(name)
998                    self.hooks[when][op_mode]['list'].append(name)
999
1000
class RenderInfo:
    """Everything the printers need to render one op in one mode, or one attr set.

    Args:
        cw: CodeWriter used for output; its nlib is kept as self.nl.
        family: the parsed family.
        ku_space: which side the code is rendered for (the type helpers
            compare it against 'user').
        op: the operation being rendered, or a falsy value when rendering
            a bare attribute set.
        op_mode: 'do', 'dump', 'notify' or 'event'. 'notify' reuses the
            'do' message layout.
        attr_set: attribute set name; defaults to the op's 'attribute-set'.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op
        self.cw = cw

        # 'do' and 'dump' response parsing is identical unless the two modes
        # disagree on having a reply, or on the reply itself
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_spec, dump_spec = op['do'], op['dump']
            do_replies = 'reply' in do_spec
            if do_replies != ('reply' in dump_spec):
                self.type_consistent = False
            elif do_replies and do_spec['reply'] != dump_spec['reply']:
                self.type_consistent = False

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            # A bare attr set whose name is also a definition name needs
            # '_'-suffixed type names
            self.type_name_conflict = attr_set in family.consts

        # Notifications are rendered from the op's 'do' section
        spec_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = dict()
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[spec_mode]:
                    self.struct[op_dir] = Struct(family, self.attr_set,
                                                 type_list=op[spec_mode][op_dir]['attributes'])
        if spec_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1040
1041
class CodeWriter:
    """Writes tab-indented C source to an output file, one line at a time.

    Besides the indentation level it tracks three bits of pending state:

    * ``_nl``: a blank line was requested with nl(). It is written only
      right before the next line, so a trailing request produces no output.
    * ``_block_end``: a bare closing '}' whose output is delayed, so that a
      following 'else ...' line can be joined onto it as '} else ...'.
    * ``_silent_block``: the previous line was a brace-less
      ``if/while/for (...)``, so the next line gets one extra tab.
    """

    def __init__(self, nlib, out_file):
        # Netlink-library helper object (RenderInfo exposes it as .nl)
        self.nlib = nlib

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if line starts like an if/while/for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation, plus add_ind extra tabs.

        A pending '}' is flushed first (merged with the line if it starts
        with 'else'), then any pending blank line. Lines ending in ':'
        (labels such as 'err_free:') are outdented by one level. line must
        be non-empty.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        # A brace-less conditional indents only the single line following it
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write '<line> {' (or just '{') and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With a suffix (e.g. ';' or '_present;') the '}' is written at once.
        Without one it is delayed so that a following 'else' can join it.
        A pending blank line is dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype or definition header.

        The prototype goes on one line if it fits in 80 columns. Otherwise a
        return type longer than 3 characters gets its own line, and the
        arguments are wrapped and aligned under the opening parenthesis.
        No args means '(void)'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned under '(' using tabs + spaces
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then request a blank line.

        local_vars is a single declaration string or a list of them. The
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable, so equally long declarations keep their order
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and a braced body."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME value' lines with values aligned on a common tab stop.

        defines is a sequence of (name, value) pairs. int values are written
        bare, str values are quoted, and other values are omitted.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers ('.name = value,') with aligned '='.

        members is an iterable of (name, value-string) pairs. Nothing is
        written when it is empty.
        """
        members = list(members)
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1199
1200
# Integer attribute types that can be read directly as scalars.
scalars = set(['u8', 'u16', 'u32', 'u64', 's32', 's64'])

# Struct/function name suffix for each message direction ('' = no direction).
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Name suffix of the wrapper type used for each op mode.
op_mode_to_wrapper = dict([
    ('do', ''),
    ('dump', '_list'),
    ('notify', '_ntf'),
    ('event', ''),
])

# C keywords: attribute names that collide with these need renaming.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1252
1253
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (e.g. '' for direction-less types) is returned unchanged.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1260
1261
def op_prefix(ri, direction, deref=False):
    """Return the C name stem '<family>_<type>_<suffix>' for ri's op.

    In 'do' mode (or with no mode) the suffix is the plain direction suffix.
    Otherwise requests use '_req_dump'. Replies use either the op mode's
    wrapper suffix or, with deref, the inner per-object type. When do and
    dump replies differ (not type_consistent) the reply names get
    '_rsp_dump' / '_rsp_list' instead.
    """
    stem = f"_{ri.type_name}"
    mode = ri.op_mode

    if not mode or mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family['name']}{stem}{tail}"
1281
1282
def type_name(ri, direction, deref=False):
    """Return the C type 'struct <op_prefix>' for ri's op in the given direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1285
1286
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing call for ri's op/mode.

    The function takes the socket plus, when the mode has a request, a
    pointer to the request type. It returns a pointer to the reply type,
    or int when there is no reply. Dump calls get a '_dump' suffix, and
    terminate appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1303
1304
def print_req_prototype(ri):
    """Emit the ';'-terminated request prototype, preceded by the op's 'doc' comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", terminate=True, doc=op_doc)
1307
1308
def print_dump_prototype(ri):
    """Emit the ';'-terminated dump prototype (no doc comment)."""
    print_prototype(ri, direction="request", terminate=True)
1311
1312
def put_typol(cw, struct):
    """Emit struct's ynl_policy_attr table and the ynl_policy_nest that wraps it."""
    type_max = struct.attr_set.max_name
    name = struct.render_name

    # Table indexed by attribute type, sized max + 1
    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {type_max},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1328
1329
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', which maps a value to its name via map_name.

    Out-of-range values return NULL. The argument is a plain int when there
    is no enum, or when the enum has an explicitly empty 'enum-name'.
    Otherwise it is typed 'enum <render_name>'. For 'flags' enums the value
    is first converted to the index of its lowest set bit.
    """
    anonymous = not enum or ('enum-name' in enum and not enum['enum-name'])
    arg_type = 'int' if anonymous else f'enum {render_name}'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1343
1344
def put_op_name_fwd(family, cw):
    """Emit the forward declaration of '<family>_op_str(int op)'."""
    fn_name = f'{family.name}_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1347
1348
def put_op_name(family, cw):
    """Emit the op-id to op-name string table and its '<family>_op_str()' lookup.

    Only messages with a truthy rsp_value are listed. They are indexed by
    the op's enum name when request and reply ids match, and by the raw
    rsp_value otherwise.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1362
1363
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the forward declaration of '<enum>_str()'.

    The argument is a plain int when the enum's 'enum-name' is explicitly empty.
    """
    anonymous = 'enum-name' in enum and not enum['enum-name']
    arg = 'int value' if anonymous else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1369
1370
def put_enum_to_str(family, cw, enum):
    """Emit the value to name string table for enum and its '<enum>_str()' lookup."""
    map_name = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = (f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values())
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1380
1381
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes obj into nlh as a nest of type attr_type."""
    cw = ri.cw
    params = ['struct nlmsghdr *nlh',
              'unsigned int attr_type',
              f'{struct.ptr_name}obj']
    cw.write_func_prot('int', f'{struct.render_name}_put', params)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')
    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")
    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1402
1403
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for struct (its prototype is already written).

    Shared by the top-level message parser and the nested-attribute parsers.
    The generated code walks the attributes once and lets every member emit
    its own handling. array-nest and multi-attr members are then allocated
    from their n_<name> counters and filled by a second walk per member.
    (The counters are presumably incremented by the members' attr_get code,
    which is not visible here.)

    Args:
        ri: RenderInfo of the op being rendered.
        struct: Struct whose members are parsed.
        init_lines: C statements emitted right after the local declarations.
            Extended in place.
        local_vars: C local declarations. Extended in place.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Array members are allocated below; refuse to parse into a struct
    # that already has one
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Size the array by the element count, like the multi-attr branch below
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The element's attribute type is passed as the inherited 'idx'
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            # Signed values are read with the unsigned getter of the same width
            t = aspec['type']
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1510
1511
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()', the parser for one nested attribute set.

    Inherited members (e.g. 'idx' for array-nest elements) arrive as extra
    __u32 arguments.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params += ['__u32 ' + name for name in struct.inherited]

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params)
    _multi_parse(ri, struct, [], decls)
1525
1526
def parse_rsp_msg(ri, deref=False):
    """Emit the mnl callback that parses a reply message into the reply type.

    Nothing is emitted when the mode has no reply, unless ri is an event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    decls = [f'{reply_type} *dst;',
             'struct ynl_parse_arg *yarg = data;',
             'const struct nlattr *attr;']

    fn_name = op_prefix(ri, "reply", deref=deref) + '_parse'
    ri.cw.write_func_prot('int', fn_name, ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], decls)
1542
1543
def print_req(ri):
    """Emit the user-facing 'do' request call for ri's op.

    The generated function starts a request message, puts the request
    attributes and runs it with ynl_exec(). With a reply, it allocates the
    reply object, registers the reply parser, policy and expected reply
    cmd, and returns the reply (freeing it and returning NULL on error).
    Without a reply it returns 0, or -1 on error.
    """
    ret_ok = '0'
    # Only emitted on the reply path; the no-reply path writes 'return -1;' directly
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): assumes the 'do' mode has a request (ri.struct['request'])
    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    parse_arg = "NULL"
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply cmd: the op's own enum when it has a value,
        # otherwise its separate reply value
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
        parse_arg = '&yrs'
    ri.cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1599
1600
def print_dump(ri):
    """Emit the user-facing dump call for ri's op.

    The generated function fills a ynl_dump_state (reply object size,
    per-message parser, expected reply cmd, reply policy), starts a dump
    message and puts the request attributes if the dump has a request.
    It runs ynl_exec_dump() and returns yds.first. On error it frees
    whatever was collected and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    # Declare locals through the shared helper, as print_req() does
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1642
1643
def call_free(ri, direction, var):
    """Return the C statement '<type>_free(<var>);' for ri's op in direction."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
1646
1647
def free_arg_name(direction):
    """Return the parameter name for a free function: 'req'/'rsp', or 'obj' without a direction."""
    return direction_to_suffix[direction][1:] if direction else 'obj'
1652
1653
def print_alloc_wrapper(ri, direction):
    """Emit 'static inline <type> *<type>_alloc(void)', which returns a calloc'ed object."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_type}));')
    ri.cw.block_end()
1660
1661
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the '<type>_free()' prototype (suffix ';' for a declaration).

    On a type-name conflict the struct tag gets a trailing '_', but the
    function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1669
1670
def _print_type(ri, direction, struct):
    """Emit the C struct definition for struct.

    The name is '<family>_<type><dir suffix>'. It gets an extra '_' for
    direction-less types on a name conflict, and '_dump' in dump mode.
    Presence bookkeeping goes first in an anonymous struct named
    '_present', then inherited __u32 members, then the regular members.
    """
    name = f"{ri.family['name']}_{ri.type_name}{direction_to_suffix[direction]}"
    if ri.type_name_conflict and not direction:
        name += '_'
    if ri.op_mode == 'dump':
        name += '_dump'

    ri.cw.block_start(line=f"struct {name}")

    in_presence = False
    for _, attr in struct.member_list():
        for kind in ('len', 'bit'):
            entry = attr.presence_member(ri.ku_space, kind)
            if not entry:
                continue
            if not in_presence:
                ri.cw.block_start(line="struct")
                in_presence = True
            ri.cw.p(entry)
    if in_presence:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1702
1703
def print_type(ri, direction):
    """Emit the struct for ri's request or reply (direction)."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1706
1707
def print_type_full(ri, struct):
    """Emit a direction-less struct (e.g. a nested attribute set)."""
    _print_type(ri, direction="", struct=struct)
1710
1711
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, per-member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1720
1721
def print_req_type_helpers(ri):
    """Emit the request type's alloc wrapper, then its free prototype and setters."""
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
1725
1726
def print_rsp_type_helpers(ri):
    """Emit the reply type's helpers, if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1731
1732
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of '<op>_rsp_parse()' or '<op>_req_parse()' (from a tb array).

    terminate appends ';'.
    """
    base = f"{ri.op.render_name}{'_rsp' if direction == 'reply' else '_req'}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
1741
1742
def print_req_type(ri):
    """Emit the request struct for ri's op."""
    print_type(ri, direction="request")
1745
1746
def print_req_free(ri):
    """Emit the request type's free function, if the mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1751
1752
def print_rsp_type(ri):
    """Emit the reply struct for events, and for do/dump modes that have a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
1761
1762
def print_wrapped_type(ri):
    """Emit the wrapper struct that embeds the reply object as 'obj'.

    Dump wrappers get a 'next' list pointer. Notify/event wrappers get
    family, cmd, a ynl_ntf_base_type 'next' pointer and a free callback.
    The free prototype follows the struct.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"):
            ri.cw.p(member)
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
1777
1778
1779def _free_type_members_iter(ri, struct):
1780    for _, attr in struct.member_list():
1781        if attr.free_needs_iter():
1782            ri.cw.p('unsigned int i;')
1783            ri.cw.nl()
1784            break
1785
1786
1787def _free_type_members(ri, var, struct, ref=''):
1788    for _, attr in struct.member_list():
1789        attr.free(ri, var, ref)
1790
1791
def _free_type(ri, direction, struct):
    """Emit a complete free() helper for *struct*.

    Members are released first.  The object itself is only ``free()``d when
    *direction* is non-empty; callers pass "" for nested types.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1803
1804
def free_rsp_nested(ri, struct):
    """Emit the free helper of a nested struct (members only, no free() of it)."""
    _free_type(ri, direction="", struct=struct)
1807
1808
def print_rsp_free(ri):
    """Emit the reply free() helper; no-op if the op mode has no reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
1813
1814
def print_dump_type_free(ri):
    """Emit the free() helper for a dump reply list.

    The generated C walks the list until YNL_LIST_END, releasing each
    element's members (accessed through ``obj.``) and then the element.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1833
1834
def print_ntf_type_free(ri):
    """Emit the free() helper for a notification/event wrapper object."""
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1843
1844
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration (``terminate``) or definition header of a policy.

    For an op-specific policy (*ri* given) the name is derived from the op
    (plus the op mode for dual-policy ops); otherwise from *struct*.  Policies
    that are file-local (static) get no header declaration at all.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)
    if terminate and is_static:
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    else:
        prefix, suffix = ('static ' if is_static else ''), ' = {'

    if ri:
        mode_part = '_' + ri.op_mode if ri.op.dual_policy else ''
        name = ri.op.render_name + mode_part
    else:
        name = struct.render_name
    max_attr = struct.attr_max_val
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1867
1868
def print_req_policy(cw, struct, ri=None):
    """Emit a full ``nla_policy`` array: header, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    members = (member for _name, member in struct.member_list())
    for member in members:
        member.attr_policy(cw)
    cw.p("};")
1874
1875
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct can be generated (genetlink only)."""
    proto = family.proto
    return proto == 'genetlink'
1878
1879
def policy_should_be_static(family):
    """Return True when op policies can be file-local (``static``)."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
1882
1883
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the start of the kernel ops table, or its header-side declarations.

    terminate=False (source file): print the ops array definition header and
    return; the caller then emits the entries and closes the block.
    terminate=True (header file): print an ``extern`` declaration of the ops
    array when it is exported, then prototypes for the family's pre/post
    hooks and for every op's doit/dumpit handler.
    """
    # The ops array is exported (non-static) only when this generator does
    # not also emit the genl_family struct.
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Static arrays are left unsized; exported ones get an explicit
        # element count (presumably so the extern declaration is complete).
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have one entry per mode, so count do and dump separately.
            # NOTE(review): unlike print_kernel_op_table() this count does not
            # skip op.is_async entries - confirm family.ops never holds them.
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Prototypes for the hook functions listed in the family's 'hooks'.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes for the per-op handlers implemented by the kernel module.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1949
1950
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler prototypes."""
    print_kernel_op_table_fwd(family=family, cw=cw, terminate=True)
1953
1954
def print_kernel_op_table(family, cw):
    """Emit the kernel genl ops array definition for *family*.

    'global' and 'per-op' policies produce one entry per op, carrying both
    doit/dumpit callbacks.  'split' produces a separate entry for each of the
    op's do and dump modes, each with its own hooks, policy and
    GENL_CMD_CAP_* flag.  The array header comes from
    print_kernel_op_table_fwd(); this function closes the block.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            # Skip an empty list rather than emit an empty '.validate' value.
            if 'dont-validate' in op and op['dont-validate']:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])))
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op has a 'do' request; an op
                # without one raises KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the split-op callback fields for the pre/post hooks.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop flags that only apply to the other mode:
                    # 'dump'/'dump-strict' for doit, 'strict' for dumpit.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    # Filtering can leave nothing for this mode; joining an
                    # empty list would yield an empty '.validate' initializer,
                    # which is not valid C.
                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])))
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2029
2030
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of ``<FAMILY>_NLGRP_<NAME>`` group IDs.

    Does nothing when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2041
2042
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``<family>_nl_mcgrps[]`` genl_multicast_group table.

    Entries are indexed by their ``<FAMILY>_NLGRP_<NAME>`` enum value.  Does
    nothing when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f"static const struct genl_multicast_group {family.name}_nl_mcgrps[] =")
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2054
2055
def print_kernel_family_struct_hdr(family, cw):
    """Declare ``<family>_nl_family`` when the generator also defines it."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2062
2063
def print_kernel_family_struct_src(family, cw):
    """Define ``<family>_nl_family`` (genetlink families only).

    The ops array is wired in as ``.ops`` for 'per-op' and as ``.split_ops``
    for 'split' kernel policies; multicast groups are added when present.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    for fixed in ('.netnsok\t= true,',
                  '.parallel_ops\t= true,',
                  '.module\t\t= THIS_MODULE,'):
        cw.p(fixed)

    ops_fields = {
        'per-op': ('.ops\t\t', '.n_ops\t\t'),
        'split': ('.split_ops\t', '.n_split_ops\t'),
    }.get(family.kernel_policy)
    if ops_fields:
        ptr_field, cnt_field = ops_fields
        cw.p(f'{ptr_field}= {fam}_nl_ops,')
        cw.p(f'{cnt_field}= ARRAY_SIZE({fam}_nl_ops),')

    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')
    cw.block_end(';')
2084
2085
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block, choosing its tag.

    If *obj* has the *enum_name* key, its value is the tag; an empty/false
    value forces an anonymous enum.  Otherwise the *ckey* value, prefixed
    with the family name, is used.  Failing both the enum is anonymous.
    """
    if enum_name in obj:
        tag = c_lower(obj[enum_name]) if obj[enum_name] else None
    elif ckey and ckey in obj:
        tag = family.name + '_' + c_lower(obj[ckey])
    else:
        tag = None
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2094
2095
def render_uapi(family, cw):
    """Emit the uAPI header for *family*.

    Output order: include guard, family name/version defines, the spec's
    ``definitions`` (enums/flags with kdoc, plain consts as #defines), one
    enum per attribute set that is not a subset, the command enum (plus a
    separate notification enum when 'async-prefix' is set) and finally the
    multicast group name defines.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # Flush the pending run of consts before any other definition.
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The C name override belongs to the const entry itself (like the
            # per-group 'c-define-name' below); a family-wide key would give
            # every const the same #define name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        # NOTE(review): 'attr-cnt-name' is read from the family, so the same
        # counter name is used for every attribute set - confirm intended.
        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Only spell out values that break the implicit +1 sequence.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        # With 'async-prefix', notifications go to their own enum below.
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2240
2241
def _render_user_ntf_entry(ri, op):
    """Emit one designated ynl_ntf_info initializer, indexed by the op's enum."""
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    alloc_type = type_name(ri, 'event')
    cw.p(f".alloc_sz\t= sizeof({alloc_type}),")
    parse_prefix = op_prefix(ri, 'reply', deref=True)
    cw.p(f".cb\t\t= {parse_prefix}_parse,")
    nest_name = ri.struct['reply'].render_name
    cw.p(f".policy\t\t= &{nest_name}_nest,")
    free_prefix = op_prefix(ri, 'notify')
    cw.p(f".free\t\t= (void *){free_prefix}_free,")
    cw.block_end(line=',')
2249
2250
def render_user_family(family, cw, prototype):
    """Emit the ``ynl_<family>_family`` descriptor, or only its declaration.

    With *prototype* true just an ``extern`` declaration is written.
    Otherwise, if the family has notifications, a static
    ``<family>_ntf_info[]`` table is emitted first and referenced from the
    descriptor.

    Raises:
        Exception: for a notification with neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        table = f"{family['name']}_ntf_info"
        cw.block_start(line=f"static const struct ynl_ntf_info {table}[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                # Rendered against the op named by the 'notify' key.
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        # Ops carrying an 'event' key get an entry as well.
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        table = f"{family['name']}_ntf_info"
        cw.p(f".ntf_info\t= {table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({table}),")
    cw.block_end(line=';')
2282
2283
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walks up the directory hierarchy until a directory holding a
    ``MAINTAINERS`` file is found.

    Returns:
        ``(root, rel_path)``: the directory containing MAINTAINERS and
        *full_path* relative to it.

    Raises:
        FileNotFoundError: if no ancestor of *full_path* has a MAINTAINERS file.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path ends with a separator from the first join; drop it.
            return parent, sub_path[:-1]
        # dirname() reaches a fixed point at the root ('/') and for the empty
        # relative path (''); walking further would loop forever.
        if parent == full_path:
            raise FileNotFoundError(
                f"cannot find kernel root (no MAINTAINERS file above {start!r})")
        full_path = parent
2292
2293
def _main_print_includes(args, parsed, cw):
    """Emit the #include lines (and user-header forward declarations)."""
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # Source files include their own header (same name, ".h").
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()


def _main_render_kernel(args, parsed, cw):
    """Emit kernel policies, ops table, mcast groups and the family struct."""
    has_nested_req = any(struct.request
                         for struct in parsed.pure_nested_structs.values())

    if args.header:
        if has_nested_req:
            cw.p('/* Common nested types */')
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy_fwd(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy_fwd(cw, struct)
            cw.nl()

        if parsed.kernel_policy in {'per-op', 'split'}:
            for op in parsed.ops.values():
                if 'do' in op and 'event' not in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                    cw.nl()

        print_kernel_op_table_hdr(parsed, cw)
        print_kernel_mcgrp_hdr(parsed, cw)
        print_kernel_family_struct_hdr(parsed, cw)
    else:
        if has_nested_req:
            cw.p('/* Common nested types */')
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy(cw, struct)
            cw.nl()

        if parsed.kernel_policy in {'per-op', 'split'}:
            for op in parsed.ops.values():
                for op_mode in ['do', 'dump']:
                    if op_mode in op and 'request' in op[op_mode]:
                        cw.p(f"/* {op.enum_name} - {op_mode} */")
                        ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                        print_req_policy(cw, ri.struct['request'], ri=ri)
                        cw.nl()

        print_kernel_op_table(parsed, cw)
        print_kernel_mcgrp_src(parsed, cw)
        print_kernel_family_struct_src(parsed, cw)


def _main_render_user(args, parsed, cw):
    """Emit the user-space (YNL library) header or source body."""
    if args.header:
        cw.p('/* Enums */')
        put_op_name_fwd(parsed, cw)

        for const in parsed.consts.values():
            if isinstance(const, EnumSet):
                put_enum_to_str_fwd(parsed, cw, const)
        cw.nl()

        cw.p('/* Common nested types */')
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
            print_type_full(ri, struct)

        for op in parsed.ops.values():
            cw.p(f"/* ============== {op.enum_name} ============== */")

            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, "do")
                print_req_type(ri)
                print_req_type_helpers(ri)
                cw.nl()
                print_rsp_type(ri)
                print_rsp_type_helpers(ri)
                cw.nl()
                print_req_prototype(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                if 'request' in op['dump']:
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                if not ri.type_consistent:
                    print_rsp_type(ri)
                print_wrapped_type(ri)
                print_dump_prototype(ri)
                cw.nl()

            if op.has_ntf:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                if not ri.type_consistent:
                    raise Exception(f'Only notifications with consistent types supported ({op.name})')
                print_wrapped_type(ri)

        for op in parsed.ntfs.values():
            if 'event' in op:
                ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                cw.p(f"/* {op.enum_name} - event */")
                print_rsp_type(ri)
                cw.nl()
                print_wrapped_type(ri)
        cw.nl()
    else:
        cw.p('/* Enums */')
        put_op_name(parsed, cw)

        for const in parsed.consts.values():
            if isinstance(const, EnumSet):
                put_enum_to_str(parsed, cw, const)
        cw.nl()

        cw.p('/* Policies */')
        for name in parsed.pure_nested_structs:
            struct = Struct(parsed, name)
            put_typol(cw, struct)
        for name in parsed.root_sets:
            struct = Struct(parsed, name)
            put_typol(cw, struct)

        cw.p('/* Common nested types */')
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

            free_rsp_nested(ri, struct)
            if struct.request:
                put_req_nested(ri, struct)
            if struct.reply:
                parse_rsp_nested(ri, struct)

        for op in parsed.ops.values():
            cw.p(f"/* ============== {op.enum_name} ============== */")
            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, "do")
                print_req_free(ri)
                print_rsp_free(ri)
                parse_rsp_msg(ri)
                print_req(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                if not ri.type_consistent:
                    parse_rsp_msg(ri, deref=True)
                print_dump_type_free(ri)
                print_dump(ri)
                cw.nl()

            if op.has_ntf:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                if not ri.type_consistent:
                    raise Exception(f'Only notifications with consistent types supported ({op.name})')
                print_ntf_type_free(ri)

        for op in parsed.ntfs.values():
            if 'event' in op:
                cw.p(f"/* {op.enum_name} - event */")

                ri = RenderInfo(cw, parsed, args.mode, op, "do")
                parse_rsp_msg(ri)

                ri = RenderInfo(cw, parsed, args.mode, op, "event")
                print_ntf_type_free(ri)
        cw.nl()
        render_user_family(parsed, cw, False)


def _main_generate(args, parsed, spec_kernel, cw):
    """Write the complete generated file for *parsed* through *cw*."""
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _main_print_includes(args, parsed, cw)

    if args.mode == "kernel":
        _main_render_kernel(args, parsed, cw)
    if args.mode == "user":
        _main_render_user(args, parsed, cw)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def main():
    """Command line entry point: parse arguments, load the spec, generate code.

    The output file (``-o``) is opened only after the arguments, the spec and
    its license have been validated, so a failing run leaves an existing
    output file untouched.  The file is closed when generation finishes.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    _, spec_kernel = find_kernel_root(args.spec)

    if args.out_file:
        with open(args.out_file, 'w+') as out_file:
            _main_generate(args, parsed, spec_kernel, CodeWriter(BaseNlLib(), out_file))
    else:
        _main_generate(args, parsed, spec_kernel, CodeWriter(BaseNlLib(), sys.stdout))


if __name__ == "__main__":
    main()
2582