xref: /openbmc/linux/tools/net/ynl/lib/ynl.py (revision 2d99a7ec)
1# SPDX-License-Identifier: BSD-3-Clause
2
3import functools
4import os
5import random
6import socket
7import struct
8import yaml
9
10from .nlspec import SpecFamily
11
12#
13# Generic Netlink code which should really be in some library, but I can't quickly find one.
14#
15
16
class Netlink:
    """Numeric constants from the netlink / generic netlink UAPI used in this file."""

    # ---- SOL_NETLINK socket options ----
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # ---- Control message types ----
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # ---- Request flags ----
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # ---- Flags seen on ACK/error messages (same bit values as ROOT/MATCH) ----
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # ---- Attribute type flag bits ----
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # ---- Generic netlink ----
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command
    CTRL_CMD_GETFAMILY = 3

    # nlctrl family attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # ---- Extended ACK TLV types ----
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
68
69
class NlAttr:
    """One netlink attribute (nlattr header + payload) read from a buffer.

    Attributes:
        type: attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER masked off.
        payload_len: the raw nla_len header value, which INCLUDES the 4-byte
            header (the name is historical; it is not the payload size).
        full_len: payload_len rounded up to 4 bytes, i.e. the distance from
            this attribute to the next one in the buffer.
        raw: payload bytes, without the header and without trailing padding.

    Integer accessors unpack in native byte order; NLA_F_NET_BYTEORDER is
    masked off but not otherwise honoured.
    """

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # nla_len counts its own 4-byte header, so a smaller value is
        # malformed. A zero length would also give full_len == 0 and stall
        # any loop that advances by full_len (such as NlAttrs).
        if self._len < 4:
            raise ValueError(f"Malformed netlink attribute at offset {offset}: length {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    def as_u8(self):
        return struct.unpack("B", self.raw)[0]

    def as_u16(self):
        return struct.unpack("H", self.raw)[0]

    def as_u32(self):
        return struct.unpack("I", self.raw)[0]

    def as_u64(self):
        return struct.unpack("Q", self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping a trailing NUL."""
        raw = self.raw
        # Strip the terminator only when it is really there, so an
        # unterminated payload does not lose its last real character.
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('ascii')

    def as_bin(self):
        return self.raw

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
98
99
class NlAttrs:
    """All NlAttr attributes laid out back-to-back in a buffer."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # A declared length below the 4-byte header is malformed; a zero
            # length in particular would leave offset unchanged and spin
            # this loop forever.
            if attr._len < 4:
                raise ValueError(f"Malformed netlink attribute at offset {offset}: length {attr._len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line, no trailing newline.
        return '\n'.join(repr(a) for a in self.attrs)
120
121
class NlMsg:
    """One netlink message (nlmsghdr + payload) parsed from a receive buffer.

    For NLMSG_ERROR and NLMSG_DONE the int error code in the payload is
    stored in ``error`` and ``done`` is set to 1. If the kernel flagged
    NLM_F_ACK_TLVS, the extended-ACK TLVs are decoded into the ``extack``
    dict; otherwise ``extack`` is None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            # A dump that fails part-way reports its (negative) errno in the
            # int payload of NLMSG_DONE; surface it instead of dropping it.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._parse_extack(self.raw[extack_off:])
            if attr_space:
                self._annotate_miss_type(attr_space)

    def _error_extack_offset(self):
        """Return the payload offset where extack TLVs start in an NLMSG_ERROR."""
        # Payload layout: int error, then the original request's nlmsghdr
        # (4 + 16 = 20 bytes). Unless the kernel capped the echo
        # (NLM_F_CAPPED), the rest of the original request follows, and the
        # TLVs start after it, aligned to 4 bytes.
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        orig_len = struct.unpack("I", self.raw[4:8])[0]
        return max(20, (4 + orig_len + 3) & ~3)

    @staticmethod
    def _parse_extack(raw):
        """Decode extended-ACK TLVs in raw into a dict keyed by TLV meaning."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_u32()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_u32()
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_u32()
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _annotate_miss_type(self, attr_space):
        """Replace a top-level numeric 'miss-type' with the attribute name (and doc)."""
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
179
180
class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len includes the 16-byte header; anything shorter is
            # malformed, and a zero length would make this loop spin forever.
            if msg.nl_len < 16:
                raise ValueError(f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            # Consecutive messages start on NLMSG_ALIGN (4-byte) boundaries.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
193
194
# Module-level cache of generic netlink families, keyed by family name.
# Despite the name, values are the per-family dicts built by
# _genl_load_families() (with 'id', 'name', and optionally 'maxattr' and
# 'mcast'), not bare ids. It stays None until GenlFamily first triggers
# _genl_load_families().
genl_family_name_to_id = None
196
197
198def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
199    # we prepend length in _genl_msg_finalize()
200    if seq is None:
201        seq = random.randint(1, 1024)
202    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
203    genlmsg = struct.pack("bbH", genl_cmd, genl_version, 0)
204    return nlmsg + genlmsg
205
206
207def _genl_msg_finalize(msg):
208    return struct.pack("I", len(msg) + 4) + msg
209
210
def _genl_parse_mcast_groups(attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_u32()
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Return the family description dict carried by one nlctrl GETFAMILY reply."""
    # Families without multicast groups still get an empty 'mcast' map so
    # lookups like YnlFamily.ntf_subscribe()'s genl_family['mcast'] do not
    # hit a KeyError.
    fam = {'mcast': dict()}
    for attr in gm.raw_attrs:
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_u16()
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_u32()
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Fill genl_family_name_to_id by dumping all families from nlctrl.

    The cache maps family name to a dict with 'id', 'name', 'mcast'
    (group name -> group id) and, when reported, 'maxattr'. On a netlink
    error the error is printed and the cache keeps whatever was collected
    before the error.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Netlink errors are negative errno values.
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
258
259
class GenlMsg:
    """Generic netlink view of an NlMsg: the genlmsghdr plus parsed attributes."""

    def __init__(self, nl_msg):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        self.raw = nl_msg.raw[4:]

        # genlmsghdr is u8 cmd, u8 version, u16 reserved. Unpack unsigned so
        # command numbers >= 128 are not reported as negative values.
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        head = repr(self.nl) + f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        return head + ''.join(f"\t\t{a!r}\n" for a in self.raw_attrs)
277
278
class GenlFamily:
    """Generic netlink family resolved by name through the nlctrl family cache.

    Attributes:
        family_name: the name that was looked up.
        genl_family: the cached family dict built by _genl_load_families().
        family_id: the numeric generic netlink family id.

    Raises:
        KeyError: if the kernel did not report a family with this name.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        # Populate the module-level cache on first use.
        if genl_family_name_to_id is None:
            _genl_load_families()

        try:
            self.genl_family = genl_family_name_to_id[family_name]
        except KeyError:
            raise KeyError(f'Generic netlink family "{family_name}" not found') from None
        self.family_id = self.genl_family['id']
289
290
291#
292# YNL implementation details.
293#
294
295
class YnlFamily(SpecFamily):
    """Netlink family driven by a YAML spec, with one method per spec operation.

    Each operation in the spec is exposed as an attribute named after its
    ident_name, bound to _op(op_name, vals, dump=False). Replies are decoded
    into dicts keyed by attribute name. Notifications received while
    waiting for replies, or via check_ntf(), are appended to async_msg_queue.
    """

    # struct formats used to encode the fixed-width unsigned integer types
    _INT_FORMATS = {
        'u8': 'B',
        'u16': 'H',
        'u32': 'I',
        'u64': 'Q',
    }

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, handle_ntf() also stores the raw NlMsg/GenlMsg objects.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # 'definitions' from the spec, keyed by name (used for enum lookup).
        self._types = dict()

        for elem in self.yaml.get('definitions', []):
            self._types[elem['name']] = elem

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        self.family = GenlFamily(self.yaml['name'])

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name`` on self.sock.

        Raises:
            Exception: if the family has no multicast group of that name.
        """
        # A family without multicast groups may have no 'mcast' map at all.
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        # Bind (letting the kernel pick a port id) only if the socket has no
        # port id yet. After any request has been sent, or after an earlier
        # subscribe, the socket is already bound, and the kernel rejects
        # re-binding it with port id 0.
        if self.sock.getsockname()[0] == 0:
            self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as a padded TLV.

        Nests are encoded recursively from a dict of sub-attribute values.

        Raises:
            Exception: for attribute types this encoder does not support.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        attr_type = attr["type"]
        if attr_type == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''.join(self._add_attr(attr['nested-attributes'], subname, subvalue)
                                    for subname, subvalue in value.items())
        elif attr_type == 'flag':
            attr_payload = b''
        elif attr_type in self._INT_FORMATS:
            attr_payload = struct.pack(self._INT_FORMATS[attr_type], int(value))
        elif attr_type == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_type == 'binary':
            attr_payload = value
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # nla_len excludes the padding; padding aligns the next attribute to 4.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum_value(self, raw, attr_spec):
        """Map integer ``raw`` to its enum entry, or a set of entries for enum-as-flags."""
        enum = self._types[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum['entries'][i])
                raw >>= 1
                i += 1
        else:
            value = enum['entries'][raw - i]
        return value

    def _decode_enum(self, rsp, attr_spec):
        """Replace rsp[attr_spec['name']] with its decoded enum value in place."""
        name = attr_spec['name']
        rsp[name] = self._decode_enum_value(rsp[name], attr_spec)

    def _decode(self, attrs, space):
        """Decode an iterable of NlAttr from attribute set ``space`` into a dict.

        Multi-attributes are collected into lists.

        Raises:
            Exception: for attribute types this decoder does not support.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            attr_type = attr_spec["type"]
            if attr_type == 'nest':
                decoded = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
            elif attr_type == 'u8':
                decoded = attr.as_u8()
            elif attr_type == 'u16':
                decoded = attr.as_u16()
            elif attr_type == 'u32':
                decoded = attr.as_u32()
            elif attr_type == 'u64':
                decoded = attr.as_u64()
            elif attr_type == 'string':
                decoded = attr.as_strz()
            elif attr_type == 'binary':
                decoded = attr.as_bin()
            elif attr_type == 'flag':
                decoded = True
            else:
                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')

            # Decode enums per value, before storing, so that multi-attr
            # enums decode each element rather than the accumulated list.
            if 'enum' in attr_spec:
                decoded = self._decode_enum_value(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]
        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a '.a.b' path to the attribute at byte ``target``, or None.

        ``offset`` is the byte offset of the first attribute in ``attrs``
        within the original request.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this nest's 4-byte header into its children.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack):
        """Replace extack['bad-attr-offs'] with an attribute path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
        # Attributes start after nlmsghdr (16) + genlmsghdr (4) = 20 bytes.
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        20, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking and queue any notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def _op(self, method, vals, dump=False):
        """Send operation ``method`` with attribute dict ``vals`` and collect replies.

        Returns None when there were no reply objects, a single dict for a
        non-dump request with exactly one reply, and a list otherwise. On a
        netlink error the error is printed and None is returned.
        Notifications received meanwhile are queued via handle_ntf().
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack)

                if nl_msg.error:
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    return
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name))

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run ``method`` as a single (non-dump) request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run ``method`` as a dump request (NLM_F_DUMP)."""
        return self._op(method, vals, dump=True)
529