xref: /openbmc/linux/tools/net/ynl/lib/ynl.py (revision f4356947f0297b0962fdd197672db7edf9f58be6)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11
12from .nlspec import SpecFamily
13
14#
15# Generic Netlink code which should really be in some library, but I can't quickly find one.
16#
17
18
class Netlink:
    """Named netlink / generic-netlink numeric constants.

    The values are meant to match the Linux UAPI headers; collecting them
    here lets the rest of the module refer to them symbolically.
    """

    # ---- socket level and socket options -----------------------------------
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # ---- control message types ---------------------------------------------
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # ---- request flags ------------------------------------------------------
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800
    # A dump is requested with ROOT and MATCH together.
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # ---- flags seen on ACK / error replies ---------------------------------
    # Same bit positions as NLM_F_ROOT / NLM_F_MATCH, different meaning.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # ---- attribute header flags --------------------------------------------
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    # Bits cleared from nla_type to obtain the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # ---- generic netlink ----------------------------------------------------
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # ---- extended ACK attribute types --------------------------------------
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
70
71
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    The offending message is kept in ``nl_msg`` so callers can inspect its
    ``error`` field and any extended-ACK details.
    """

    def __init__(self, nl_msg):
        # Hand the message to Exception so ``args`` is populated; without it
        # the exception cannot be re-created (e.g. by pickle), because
        # NlError() requires nl_msg.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        # nl_msg.error holds a negative errno, hence the negation.
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
78
79
class NlAttr:
    """A single netlink attribute (TLV) located at ``offset`` in ``raw``.

    Instance attributes:
        type:        nla_type with the NESTED / NET_BYTEORDER flag bits cleared.
        payload_len: the raw nla_len field (NOTE: header included, despite
                     the name; kept for compatibility).
        full_len:    nla_len rounded up to a multiple of 4, i.e. the distance
                     to the next attribute.
        raw:         the attribute payload (header and padding excluded).

    Raises:
        ValueError: if nla_len is smaller than the 4-byte attribute header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        if self._len < 4:
            # nla_len counts its own 4-byte header, so a smaller value is
            # corrupt. Accepting 0 would yield full_len == 0 and any loop
            # walking attributes by full_len would never advance.
            raise ValueError(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar ``attr_type`` (e.g. 'u32').

        ``byte_order`` "big-endian" selects big endian, any other truthy value
        selects little endian, and a falsy value selects native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode a NUL-terminated ASCII string payload.

        Exactly one trailing NUL is stripped; a payload that lacks the
        terminator is returned whole rather than losing its last character.
        """
        return self.raw.decode('ascii').removesuffix('\x00')

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order ``type`` scalars."""
        fmt = self.get_format(type)
        return [value for (value,) in fmt.iter_unpack(self.raw)]

    def as_struct(self, members):
        """Decode the payload as consecutive scalar ``members`` into a dict.

        Each member must provide ``name``, ``type`` and ``byte_order``.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            fmt = self.get_format(m.type, m.byte_order)
            value[m.name] = fmt.unpack_from(self.raw, offset)[0]
            offset += fmt.size
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
135
136
class NlAttrs:
    """All netlink attributes found back to back in the buffer ``msg``.

    Raises:
        ValueError: if an attribute reports a zero aligned length, which
            would otherwise stall the parse loop forever.
    """

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr.full_len <= 0:
                # Guard against corrupt input: offset would never advance.
                raise ValueError(f"Malformed netlink attribute at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(a) for a in self.attrs)
157
158
class NlMsg:
    """One netlink message parsed from a buffer.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: byte offset of this message's 16-byte nlmsghdr within ``msg``.
        attr_space: optional attribute-set spec of the request. When given,
            a numeric extack "miss-type" is replaced by the attribute's spec
            name (plus its doc, if any).

    Attributes set: nl_len, nl_type, nl_flags, nl_seq, nl_portid (header
    fields); raw (payload); error (errno read from an NLMSG_ERROR or
    NLMSG_DONE payload, else 0); done (1 for NLMSG_ERROR / NLMSG_DONE);
    extack (dict of decoded extended-ACK TLVs, or None).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            # The DONE payload starts with the dump's final status; a negative
            # value is a failed dump and must not be mistaken for success.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._parse_extack(self.raw[extack_off:])
            if attr_space:
                self._describe_miss_type(attr_space)

    def _error_extack_offset(self):
        """Return where the extack TLVs start inside an NLMSG_ERROR payload."""
        # Payload: 4-byte errno, then the request's nlmsghdr. With
        # NLM_F_CAPPED only that 16-byte header is echoed (TLVs at 20);
        # otherwise the whole request is echoed, padded to 4 bytes, and its
        # length is the first field of the echoed header.
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        orig_len = struct.unpack("I", self.raw[4:8])[0]
        return 4 + ((orig_len + 3) & ~3)

    @staticmethod
    def _parse_extack(data):
        """Decode extended-ACK TLVs from ``data`` into a dict."""
        extack = dict()
        for attr in NlAttrs(data):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _describe_miss_type(self, attr_space):
        """Replace a top-level numeric 'miss-type' with its spec name (+ doc)."""
        # We don't have the ability to parse nests yet, so only do global.
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
216
217
class NlMsgs:
    """All netlink messages contained in one receive buffer ``data``.

    ``attr_space`` is forwarded to each NlMsg for extack decoding.

    Raises:
        ValueError: if a message header reports a length shorter than the
            16-byte nlmsghdr (a zero length would otherwise loop forever).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise ValueError(f"Malformed netlink message length {msg.nl_len} at offset {offset}")
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
230
231
# Module-wide cache of the kernel's generic netlink families, keyed by family
# name. Each value is a dict describing the family (not just its id).
# Stays None until _genl_load_families() populates it on first use.
genl_family_name_to_id = None
233
234
235def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
236    # we prepend length in _genl_msg_finalize()
237    if seq is None:
238        seq = random.randint(1, 1024)
239    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
240    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
241    return nlmsg + genlmsg
242
243
244def _genl_msg_finalize(msg):
245    return struct.pack("I", len(msg) + 4) + msg
246
247
def _genl_load_families():
    """Dump all generic netlink families from nlctrl into the module cache.

    Populates ``genl_family_name_to_id`` with one dict per family, keyed by
    name; see _genl_parse_family() for the dict layout. On a netlink error
    the error is printed and loading stops (the cache keeps whatever was
    collected so far).
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


def _genl_parse_family(gm):
    """Convert one CTRL_CMD_GETFAMILY reply into a family dict.

    Keys: 'id' (u16), 'name', 'maxattr' (u32) when present in the reply, and
    'mcast' (group name -> group id), which is always present.
    """
    # Default 'mcast' so families without multicast groups can still be
    # queried for a group without a KeyError.
    fam = {'mcast': dict()}
    for attr in gm.raw_attrs:
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_parse_mcast_groups(attr):
    """Map group name -> group id from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        # Only keep complete entries.
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups
295
296
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr, fixed header, attributes.

    Args:
        nl_msg: parsed NlMsg whose payload begins with a 4-byte genlmsghdr.
        fixed_header_members: optional sequence of member specs (each with
            ``name``, ``type`` and ``byte_order``) describing a family fixed
            header located between the genlmsghdr and the attributes. Their
            decoded values end up in ``fixed_header_attrs``.
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        # None instead of a shared mutable [] default.
        if fixed_header_members is None:
            fixed_header_members = ()
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        # cmd, version, reserved
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            self.fixed_header_attrs[m.name] = fmt.unpack_from(nl_msg.raw, offset)[0]
            offset += fmt.size

        # Everything after the headers is a run of attributes.
        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        msg += ''.join('\t\t' + repr(a) + '\n' for a in self.raw_attrs)
        return msg
322
323
class GenlFamily:
    """Kernel-side record of one generic netlink family.

    The module-wide family cache is loaded from the kernel on first use.
    Raises KeyError if ``family_name`` is not in that cache.
    """

    def __init__(self, family_name):
        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.family_name = family_name
        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
334
335
336#
337# YNL implementation details.
338#
339
340
class YnlFamily(SpecFamily):
    """Generic netlink family driven by a YNL YAML spec.

    Construction opens a NETLINK_GENERIC socket and attaches, for every op in
    the spec, a method named after the op's ``ident_name`` that forwards to
    _op(). Notifications seen while waiting for replies (or via check_ntf())
    are decoded and appended to ``async_msg_queue``.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg/GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # Reply command values that identify notifications.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # The object is unusable without a family id: release the socket
            # now rather than leaking it until garbage collection.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises:
            Exception: if the family has no such multicast group.
        """
        # Families without multicast groups may have no 'mcast' entry.
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        # Linux's netlink_bind() returns EINVAL when asked for port id 0 on a
        # socket that already has a port id (from an earlier bind or the
        # autobind done on first send), so only bind while still unbound.
        if self.sock.getsockname()[0] == 0:
            self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` from attribute set ``space`` as a padded TLV.

        Nests take a dict of sub-attribute values and are encoded recursively.
        Raises Exception for attribute types this encoder does not handle.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''.join(
                self._add_attr(attr['nested-attributes'], subname, subvalue)
                for subname, subvalue in value.items())
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = value
        elif attr['type'] in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = fmt.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, rsp, attr_spec):
        """Replace ``rsp[attr_spec['name']]`` by its enum entry name.

        For 'enum-as-flags' specs the value becomes a set of entry names, one
        per set bit.
        """
        raw = rsp[attr_spec['name']]
        enum = self.consts[attr_spec['enum']]
        # NOTE(review): 'value-start' is read from the attribute spec here;
        # confirm this matches how the enum entries are numbered in nlspec.
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw - i].name
        rsp[attr_spec['name']] = value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or raw bytes."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    self._decode_enum(decoded, m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` of attribute set ``space`` into a dict by name.

        Multi-attributes are collected into lists; nests recurse.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                decoded = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

            if 'enum' in attr_spec:
                self._decode_enum(rsp, attr_spec)
        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path at byte ``target``, or None.

        ``offset`` is the byte position of the first attribute in ``attrs``.
        """
        for attr in attrs:
            # Check the position before looking up the spec, so an unknown
            # attribute past the target cannot raise a spurious KeyError.
            if offset > target:
                break
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4  # skip this nest's attribute header
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Translate extack 'bad-attr-offs' into a 'bad-attr' path, in place.

        ``request`` is the raw message that was sent; ``fixed_header_members``
        must describe its family fixed header (if any) so the attributes are
        located correctly.
        """
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space),
                           fixed_header_members or [])
        # Attributes start after the 16-byte nlmsghdr plus whatever GenlMsg
        # skipped (genlmsghdr and fixed header).
        attrs_start = 16 + len(genl_req.nl.raw) - len(genl_req.raw)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_start, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send op ``method`` with attribute values ``vals`` and collect replies.

        Returns None if no reply carried data, a single decoded dict for a
        non-dump op with exactly one reply, otherwise a list of dicts.

        Raises:
            NlError: if the kernel reports an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header members are popped below and the
        # caller's mapping must not be modified.
        vals = dict(vals or {})

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name, 0)
                fmt = NlAttr.get_format(m.type, m.byte_order)
                msg += fmt.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name)
                           | gm.fixed_header_attrs)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run op ``method`` as a single (non-dump) request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run op ``method`` as a dump request."""
        return self._op(method, vals, dump=True)
599