xref: /openbmc/linux/tools/net/ynl/lib/ynl.py (revision f98e51585f2ca00efbd16e27ed1d94a5e5520703)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric constants mirrored from the Linux netlink / genetlink UAPI headers."""

    # --- Socket level and socket options (setsockopt on SOL_NETLINK) ---
    SOL_NETLINK = 270
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # --- Control message types (nlmsg_type) ---
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- Request flags (nlmsg_flags) ---
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- Flags seen on ACK/error messages; they reuse the request-flag bits ---
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- Attribute type flags carried in the high bits of nla_type ---
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink protocol number and the controller's family ID ---
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # --- nlctrl: command, top-level attributes, multicast group attributes ---
    CTRL_CMD_GETFAMILY = 3
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types (NLMSGERR_ATTR_*) ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
72
73
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    ``nl_msg`` is the offending message; its ``error`` field holds a negative
    errno value.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        # errno values arrive negated on the wire, strerror wants them positive.
        reason = os.strerror(-self.nl_msg.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
80
81
class NlAttr:
    """A single netlink attribute (TLV) parsed out of a raw buffer.

    The 4-byte header is unpacked in native byte order as (length, type).
    ``type`` has the NLA_F_NESTED / NLA_F_NET_BYTEORDER flag bits masked off.
    ``raw`` is the payload that follows the header, and ``full_len`` is the
    attribute's size rounded up to the 4-byte boundary.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for each scalar type, per byte order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute whose header starts at ``offset`` in ``raw``.

        Raises:
            Exception: if the header's length field is smaller than the
                4-byte header itself. A zero length would otherwise give a
                zero ``full_len`` and leave any buffer walk stuck in place.
        """
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: "
                            f"length {self._len} is shorter than the 4-byte header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        # The next attribute starts at the 4-byte aligned end of this one.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the ``Struct`` for scalar ``attr_type``.

        ``byte_order`` "big-endian" selects big endian. Any other truthy value
        selects little endian, and a falsy value selects native order.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render ``raw`` bytes according to a spec display hint.

        Supported hints: 'mac', 'hex', 'ipv4', 'ipv6' and 'uuid'. Any other
        hint returns ``raw`` unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar of ``attr_type`` (e.g. 'u32')."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping its final byte.

        The final byte is assumed to be the NUL terminator.
        """
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed native-order array of scalar ``type``."""
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a C struct described by spec ``members``.

        Each member provides ``name``, ``type``, ``byte_order`` and
        ``display_hint``. 'binary' members also provide ``['len']``.

        Returns:
            dict mapping member name to its decoded (and optionally
            display-formatted) value.

        Raises:
            Exception: for a member type that is neither 'binary' nor a known
                scalar. Previously such a member silently reused the prior
                member's value or failed with UnboundLocalError.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                raise Exception(f"Unsupported struct member type '{m.type}' "
                                f"for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
157
158
class NlAttrs:
    """Sequence of NlAttr objects walked out of a packed attribute buffer.

    Iterating yields the attributes in buffer order. The repr lists one
    attribute per line.
    """

    def __init__(self, msg):
        self.attrs = []
        pos = 0
        while pos < len(msg):
            attr = NlAttr(msg, pos)
            self.attrs.append(attr)
            # full_len includes the padding up to the next 4-byte boundary.
            pos += attr.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
179
180
class NlMsg:
    """One netlink message (nlmsghdr + payload) parsed out of a receive buffer.

    Exposes:
    - the unpacked header fields (``nl_len``, ``nl_type``, ``nl_flags``,
      ``nl_seq``, ``nl_portid``) and the payload ``raw``;
    - ``error``: the int error code of an NLMSG_ERROR or NLMSG_DONE message,
      0 otherwise;
    - ``done``: 1 for NLMSG_ERROR and NLMSG_DONE;
    - ``extack``: a dict of decoded extended-ACK TLVs, or None.
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message whose nlmsghdr starts at ``offset`` in ``msg``.

        Args:
            msg: buffer holding one or more netlink messages.
            offset: byte offset of this message's header.
            attr_space: optional attribute set used to replace an extack
                missing-attribute type number with that attribute's name.

        Raises:
            Exception: if the header's length field is smaller than the
                16-byte header. A zero length, for example, would otherwise
                leave NlMsgs stuck at the same offset forever.
        """
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message at offset {offset}: "
                            f"length {self.nl_len} is shorter than the 16-byte header")

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            # Payload layout: int error, then the request's 16-byte nlmsghdr,
            # then any extack TLVs. The fixed offset assumes the request body
            # is not echoed back (NETLINK_CAP_ACK), which is how this file
            # configures its sockets.
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # The end-of-dump message also begins with an int error code. A
            # dump that fails part-way reports the failure only here.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
238
239
class NlMsgs:
    """Split a receive buffer into consecutive NlMsg objects.

    ``attr_space`` is handed to every NlMsg so extack missing-attribute types
    can be resolved to names. Iterating yields the messages in buffer order.
    """

    def __init__(self, data, attr_space=None):
        parsed = []
        pos = 0
        end = len(data)
        while pos < end:
            nl_msg = NlMsg(data, pos, attr_space=attr_space)
            parsed.append(nl_msg)
            # nl_len covers the header plus payload of this message.
            pos += nl_msg.nl_len
        self.msgs = parsed

    def __iter__(self):
        return iter(self.msgs)
252
253
# Lazily built cache of the kernel's generic netlink families. It maps each
# family name to a dict of that family's details ('id', 'name', and, when the
# kernel reports them, 'maxattr' and 'mcast'). Despite the name, the values
# are dicts, not bare IDs. It stays None until _genl_load_families() fills it,
# which GenlFamily triggers on first use.
genl_family_name_to_id = None
255
256
257def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
258    # we prepend length in _genl_msg_finalize()
259    if seq is None:
260        seq = random.randint(1, 1024)
261    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
262    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
263    return nlmsg + genlmsg
264
265
266def _genl_msg_finalize(msg):
267    return struct.pack("I", len(msg) + 4) + msg
268
269
def _genl_load_families():
    """Fill the module-level family cache by dumping the nlctrl controller.

    Sends a CTRL_CMD_GETFAMILY dump request on a short-lived NETLINK_GENERIC
    socket. For every reply that carries both a name and an ID, it stores a
    dict in ``genl_family_name_to_id`` under the family name. The dict holds
    'id', 'name' and, when present, 'maxattr' and 'mcast' (multicast group
    name -> group ID).

    Reading stops at the end-of-dump message. On a netlink error the error
    code is printed and the cache keeps whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = {}

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in GenlMsg(nl_msg).raw_attrs:
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        # Each entry is a nest holding one group's name and ID.
                        for group in NlAttrs(attr.raw):
                            grp_name = grp_id = None
                            for field in NlAttrs(group.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = field.as_scalar('u32')
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id
                        family['mcast'] = groups

                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
317
318
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr, fixed header, attributes.

    Args:
        nl_msg: the NlMsg whose ``raw`` payload starts with the 4-byte
            genlmsghdr (cmd, version, reserved).
        fixed_header_members: optional sequence of spec struct members laid
            out right after the genlmsghdr. Each is unpacked into
            ``fixed_header_attrs`` under its name. Defaults to no members;
            ``None`` replaces the former shared mutable ``[]`` default.
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members or ():
            fmt = NlAttr.get_format(m.type, m.byte_order)
            value = fmt.unpack_from(nl_msg.raw, offset)[0]
            offset += fmt.size
            self.fixed_header_attrs[m.name] = value

        # Attributes begin right after the genlmsghdr and the fixed header.
        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
344
345
class GenlFamily:
    """Kernel-side identity of a generic netlink family, looked up by name.

    On first use the module-level family cache is filled by dumping the
    nlctrl controller. ``genl_family`` is the cached details dict, and
    ``family_id`` is its 'id'. Raises KeyError when ``family_name`` is not
    in the cache.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        # Only read here; _genl_load_families() rebinds the global.
        if genl_family_name_to_id is None:
            _genl_load_families()

        details = genl_family_name_to_id[family_name]
        self.genl_family = details
        self.family_id = details['id']
356
357
358#
359# YNL implementation details.
360#
361
362
class YnlFamily(SpecFamily):
    """Client for one generic netlink family described by a YNL spec.

    Loads the spec via SpecFamily and opens a NETLINK_GENERIC socket with
    capped ACKs and extended ACKs enabled. It then resolves the family from
    the kernel and exposes every spec operation as a method named
    ``op.ident_name``, bound to ``_op``.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg/GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        # NlMsg's fixed extack offset of 20 assumes the request body is not
        # echoed in ACKs, which is what NETLINK_CAP_ACK requests.
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # genl command values of messages the spec marks as async
        # (notifications), and the decoded notifications received so far.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # The object is unusable without a family; don't leak the socket.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises:
            Exception: if the kernel reported no such group for this family,
                including families that have no multicast groups at all.
        """
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as a padded TLV.

        Nests take a dict of sub-attribute values. Binary attributes take
        ``bytes``/``bytearray`` or a hex string. Scalars take anything
        ``int()`` accepts.

        Raises:
            Exception: for an unknown attribute name or unsupported type.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, (bytes, bytearray)):
                attr_payload = bytes(value)
            else:
                attr_payload = bytes.fromhex(value)
        elif attr['type'] in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = fmt.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Pad the payload to the 4-byte attribute alignment.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map integer ``raw`` to names from the enum ``attr_spec['enum']``.

        With 'enum-as-flags' set, returns the set of names whose bit index is
        set in ``raw``. Otherwise returns the single entry name for ``raw``.
        """
        enum = self.consts[attr_spec['enum']]
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute.

        A struct becomes a dict (with enum members mapped to names), a
        sub-type becomes a list of scalars, and anything else becomes raw
        bytes or a display-hint string.
        """
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` from attribute set ``space`` into a dict.

        Keys are attribute names. Multi-attributes collect into lists, and
        nests are decoded recursively.

        Raises:
            Exception: for an attribute value or type the spec doesn't know.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted path ('.a.b') of the attribute at byte ``target``.

        ``attrs`` are walked starting at absolute message offset ``offset``,
        descending into the nest that contains ``target``. Returns None when
        no attribute starts exactly at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the nest's own 4-byte header before walking its children.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path.

        Args:
            request: the raw request bytes that the extack refers to.
            attr_space: attribute set of the request's attributes.
            extack: dict from NlMsg.extack, updated in place.
            fixed_header_members: the operation's fixed-header struct members,
                if any. They sit between the genlmsghdr and the first
                attribute, so they must be skipped to locate attributes.
        """
        if 'bad-attr-offs' not in extack:
            return

        if fixed_header_members is None:
            fixed_header_members = []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), fixed_header_members)
        # Attributes start after nlmsghdr (16) + genlmsghdr (4) + fixed header.
        attrs_start = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size
                               for m in fixed_header_members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_start, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send operation ``method`` and collect its replies.

        Args:
            method: operation name as it appears in ``self.ops``.
            vals: mapping of attribute names, plus fixed-header member names
                for operations that have a fixed header, to values. It is
                copied first, so the caller's mapping is never modified.
                None or an empty mapping sends no attributes. Fixed-header
                members not given are sent as 0.
            dump: request a dump (NLM_F_DUMP) instead of a plain do.

        Returns:
            None if no data reply arrived. For a non-dump request with exactly
            one reply, that decoded reply. Otherwise a list of decoded replies.
            Notifications that arrive meanwhile are queued via handle_ntf().

        Raises:
            NlError: if the kernel answers with an error.
        """
        op = self.ops[method]
        # Fixed-header values are popped from this private copy below.
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        req_msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name, 0)
                fmt = NlAttr.get_format(m.type, m.byte_order)
                req_msg += fmt.pack(value)
        for name, value in vals.items():
            req_msg += self._add_attr(op.attr_set.name, name, value)
        req_msg = _genl_msg_finalize(req_msg)

        self.sock.send(req_msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(req_msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Anything that is not the reply to this request is either a
                # notification (queued) or reported and skipped.
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run ``method`` as a plain (non-dump) request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run ``method`` as a dump request."""
        return self._op(method, vals, dump=True)
633