xref: /openbmc/linux/tools/net/ynl/lib/ynl.py (revision c1e0230e)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric constants for netlink sockets, messages, genetlink and nlctrl."""

    # --- socket options, used with setsockopt(SOL_NETLINK, ...) ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # --- control message types (nlmsg_type) ---
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # --- request flags (nlmsg_flags) ---
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- ack flags (nlmsg_flags); same bit positions as NLM_F_ROOT / NLM_F_MATCH ---
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- flag bits carried in nla_type ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # Flag bits stripped from nla_type (via "& ~NLA_TYPE_MASK") to get the type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # --- nlctrl commands and attributes ---
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes nested inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- extended ack TLV types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
72
73
class NlError(Exception):
    """Raised when a request is answered with a non-zero netlink error.

    The offending NlMsg is kept on ``nl_msg``. Its ``error`` field is negated
    before being handed to os.strerror(), i.e. it holds a negative errno.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        reason = os.strerror(-self.nl_msg.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
80
81
class NlAttr:
    """One netlink attribute (struct nlattr) located at ``offset`` in ``raw``.

    Attributes:
        type: nla_type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER bits masked off.
        payload_len: nla_len as sent (header plus payload, without padding).
        full_len: payload_len rounded up to a multiple of 4, i.e. the distance
            to the next attribute in the buffer.
        raw: the attribute payload (header and padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        if self._len < 4:
            # nla_len includes the 4-byte header itself. A smaller value is
            # malformed; a length of 0 would give full_len == 0 and any caller
            # stepping through the buffer by full_len would never advance.
            raise Exception(f"Malformed netlink attribute at offset {offset}: nla_len {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct used to (un)pack scalar type ``attr_type``.

        ``byte_order`` "big-endian" selects big endian, any other truthy value
        little endian, and None/empty the host's native byte order.
        Raises KeyError for a type not in ``type_formats``.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render bytes ``raw`` according to a spec ``display-hint``.

        Unknown hints return ``raw`` unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping one trailing NUL.

        Payloads without a terminating NUL are returned whole instead of
        losing their last character.
        """
        raw = self.raw
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('ascii')

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-endian ``type`` scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a packed struct described by ``members``.

        Each member needs ``name``, ``type``, ``byte_order`` and
        ``display_hint``; 'binary' members also need a ``len`` entry.
        Members are laid out back to back with no alignment padding.

        Raises:
            Exception: for a member type that is neither 'binary' nor a
                known scalar type.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset+m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                # Previously 'decoded' silently kept the prior member's value
                # (or was unbound) and the offset did not move.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
157
158
class NlAttrs:
    """The NlAttr objects found back to back in byte buffer ``msg``."""

    def __init__(self, msg):
        self.attrs = []

        pos = 0
        end = len(msg)
        while pos < end:
            attr = NlAttr(msg, pos)
            self.attrs.append(attr)
            # full_len already includes the alignment padding.
            pos += attr.full_len

    def __iter__(self):
        for attr in self.attrs:
            yield attr

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(attr) for attr in self.attrs)
179
180
class NlMsg:
    """One netlink message (nlmsghdr + payload) parsed out of a receive buffer.

    For NLMSG_ERROR ``error`` holds the kernel's error code and ``done`` is set.
    For NLMSG_DONE only ``done`` is set. When the kernel flags NLM_F_ACK_TLVS,
    the extended-ack TLVs are decoded into the ``extack`` dict.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: where this message starts inside ``msg``.
        attr_space: optional attribute set. When given, a top-level
            'miss-type' in extack is replaced by that attribute's name (and doc).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        # struct nlmsghdr: u32 len, u16 type, u16 flags, u32 seq, u32 portid
        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            # NLMSG_DONE carries a single int ahead of any TLVs.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._parse_extack(self.raw[extack_off:])
            if attr_space:
                self._describe_miss_type(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start inside an NLMSG_ERROR payload.

        The payload is an int error followed by the request's nlmsghdr
        (20 bytes total) when the ack is NLM_F_CAPPED. Otherwise the whole
        request is echoed, so its length (read from the echoed header and
        rounded up to 4 bytes) has to be skipped too.
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        echoed_len = struct.unpack("I", self.raw[4:8])[0]
        return max(20, 4 + ((echoed_len + 3) & ~3))

    @staticmethod
    def _parse_extack(raw):
        """Decode extended-ack TLVs in ``raw`` into a dict with descriptive keys."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _describe_miss_type(self, attr_space):
        """Replace a numeric top-level 'miss-type' with the attribute's name/doc."""
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type in attr_space.attrs_by_val:
            spec = attr_space.attrs_by_val[miss_type]
            desc = spec['name']
            if 'doc' in spec:
                desc += f" ({spec['doc']})"
            self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
238
239
class NlMsgs:
    """All netlink messages packed into one receive buffer.

    Args:
        data: bytes returned by a single recv() on a netlink socket.
        attr_space: optional attribute set forwarded to each NlMsg.

    Raises:
        Exception: if a message header reports a length smaller than the
            16-byte nlmsghdr (which would otherwise loop forever).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: nlmsg_len {msg.nl_len}")
            # Messages follow each other at NLMSG_ALIGN (4-byte) boundaries.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
252
253
# Lazily filled cache of the kernel's generic netlink families, keyed by
# family name. It stays None until _genl_load_families() populates it, which
# happens on the first GenlFamily construction. Despite the name, each value
# is the whole family dict: 'name' and 'id' always, plus 'maxattr' and
# 'mcast' (group name -> group id) when the kernel reported them.
genl_family_name_to_id = None
255
256
257def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
258    # we prepend length in _genl_msg_finalize()
259    if seq is None:
260        seq = random.randint(1, 1024)
261    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
262    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
263    return nlmsg + genlmsg
264
265
266def _genl_msg_finalize(msg):
267    return struct.pack("I", len(msg) + 4) + msg
268
269
def _genl_parse_mcast_groups(groups_attr):
    """Map multicast group name to group id from a CTRL_ATTR_MCAST_GROUPS attribute."""
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        grp_name, grp_id = None, None
        for field in NlAttrs(entry.raw):
            if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = field.as_strz()
            elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = field.as_scalar('u32')
        # Only keep groups for which both name and id were reported.
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(gm):
    """Collect the nlctrl attributes of one GETFAMILY reply into a dict."""
    fam = dict()
    for attr in gm.raw_attrs:
        kind = attr.type
        if kind == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif kind == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump every generic netlink family via nlctrl into genl_family_name_to_id.

    A netlink error is printed and stops the load early, leaving whatever
    families were collected up to that point in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
317
318
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr, optional fixed header, attributes.

    Args:
        nl_msg: NlMsg whose payload begins with a genlmsghdr.
        fixed_header_members: optional sequence of struct member specs (each
            with ``name``, ``type`` and ``byte_order``) for a family-specific
            fixed header between the genlmsghdr and the attributes. The
            decoded values are stored in ``fixed_header_attrs``. Defaults to
            no fixed header.
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        # genlmsghdr: u8 cmd, u8 version, u16 reserved
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        # None (rather than a shared mutable [] default) means "no fixed header".
        for m in fixed_header_members or ():
            fmt = NlAttr.get_format(m.type, m.byte_order)
            [value] = fmt.unpack_from(nl_msg.raw, offset)
            offset += fmt.size
            self.fixed_header_attrs[m.name] = value

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
344
345
class GenlFamily:
    """Kernel-side record of one generic netlink family, looked up by name.

    The module-level family cache is loaded on first use. Raises KeyError
    when ``family_name`` is not in that cache.
    """

    def __init__(self, family_name):
        global genl_family_name_to_id

        self.family_name = family_name

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
356
357
358#
359# YNL implementation details.
360#
361
362
class YnlFamily(SpecFamily):
    """Client for a generic netlink family described by a YNL spec file.

    Each operation of the spec is bound as a method named after the
    operation's ``ident_name``. Calling it is equivalent to
    ``self._op(op_name, vals, dump=...)``.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg/GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # genl_cmd values of asynchronous messages, and decoded notifications.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # The object is unusable; release the socket rather than leak it.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name`` to receive notifications.

        Raises:
            Exception: if the family has no multicast group of that name.
        """
        # Families without any multicast groups have no 'mcast' entry at all.
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as padded nlattr bytes.

        Nests recurse over ``value.items()``. Binary values may be bytes or a
        hex string. Scalars are converted with int().
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = fmt.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # nla_len excludes the padding up to the next 4-byte boundary.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map integer ``raw`` to enum entry name(s) of ``attr_spec['enum']``.

        With 'enum-as-flags', each set bit i is looked up as entry value i and
        a set of names is returned.
        """
        enum = self.consts[attr_spec['enum']]
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or (hinted) bytes."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode NlAttrs ``attrs`` of attribute set ``space`` into a dict.

        Multi-attributes are collected into lists; nests are decoded recursively.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Find the attribute starting at byte ``target`` of the request.

        ``offset`` is the position of the first of ``attrs`` within the
        request. Returns a dotted path such as '.outer.inner', or None.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this nest's own 4-byte header into its children.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path, in place.

        ``request`` is the finalized request that was sent. Any fixed header
        members must be passed so the attributes are located correctly.
        """
        if 'bad-attr-offs' not in extack:
            return

        members = fixed_header_members or []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), members)
        # The reported offset counts from the start of the request:
        # nlmsghdr (16) + genlmsghdr (4) + the family's fixed header, if any.
        attrs_start = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size
                               for m in members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_start, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, return a copy of the attributes listed
        under its 'do' request, or None if the operation is not found.
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send operation ``method`` with attribute values ``vals`` and decode replies.

        Fixed-header fields are taken from ``vals`` (0 when absent); the
        caller's dict is left unmodified. Returns None for an empty reply,
        a single dict for a one-message 'do', or a list of dicts otherwise.

        Raises:
            NlError: if the kernel answers with an error.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        fixed_header_names = set()
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                # Read (don't pop) so the caller's vals dict is not mutated.
                value = vals.get(m.name, 0)
                fmt = NlAttr.get_format(m.type, m.byte_order)
                msg += fmt.pack(value)
                fixed_header_names.add(m.name)
        for name, value in vals.items():
            if name in fixed_header_names:
                continue
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Anything that is not a reply to our request is either a
                # notification or unexpected.
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run operation ``method`` as a 'do' request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run operation ``method`` as a dump request."""
        return self._op(method, vals, dump=True)
638