# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause from collections import namedtuple import functools import os import random import socket import struct from struct import Struct import yaml from .nlspec import SpecFamily # # Generic Netlink code which should really be in some library, but I can't quickly find one. # class Netlink: # Netlink socket SOL_NETLINK = 270 NETLINK_ADD_MEMBERSHIP = 1 NETLINK_CAP_ACK = 10 NETLINK_EXT_ACK = 11 # Netlink message NLMSG_ERROR = 2 NLMSG_DONE = 3 NLM_F_REQUEST = 1 NLM_F_ACK = 4 NLM_F_ROOT = 0x100 NLM_F_MATCH = 0x200 NLM_F_APPEND = 0x800 NLM_F_CAPPED = 0x100 NLM_F_ACK_TLVS = 0x200 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH NLA_F_NESTED = 0x8000 NLA_F_NET_BYTEORDER = 0x4000 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER # Genetlink defines NETLINK_GENERIC = 16 GENL_ID_CTRL = 0x10 # nlctrl CTRL_CMD_GETFAMILY = 3 CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_MCAST_GRP_NAME = 1 CTRL_ATTR_MCAST_GRP_ID = 2 # Extack types NLMSGERR_ATTR_MSG = 1 NLMSGERR_ATTR_OFFS = 2 NLMSGERR_ATTR_COOKIE = 3 NLMSGERR_ATTR_POLICY = 4 NLMSGERR_ATTR_MISS_TYPE = 5 NLMSGERR_ATTR_MISS_NEST = 6 class NlError(Exception): def __init__(self, nl_msg): self.nl_msg = nl_msg def __str__(self): return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}" class NlAttr: ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) type_formats = { 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("h"), Struct("I"), Struct("i"), Struct("Q"), Struct("q"), Struct(">= 1 i += 1 else: value = enum.entries_by_val[raw - i].name rsp[attr_spec['name']] = value def _decode_binary(self, attr, attr_spec): if attr_spec.struct_name: members = self.consts[attr_spec.struct_name] decoded = attr.as_struct(members) for m in members: if m.enum: self._decode_enum(decoded, m) elif attr_spec.sub_type: decoded = attr.as_c_array(attr_spec.sub_type) else: decoded = attr.as_bin() return decoded def _decode(self, attrs, space): attr_space = self.attr_sets[space] rsp = dict() for attr in attrs: attr_spec = attr_space.attrs_by_val[attr.type] if attr_spec["type"] == 'nest': subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) decoded = subdict elif attr_spec["type"] == 'string': decoded = attr.as_strz() elif attr_spec["type"] == 'binary': decoded = self._decode_binary(attr, attr_spec) elif attr_spec["type"] == 'flag': decoded = True elif attr_spec["type"] in NlAttr.type_formats: decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) else: raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') if not attr_spec.is_multi: rsp[attr_spec['name']] = decoded elif attr_spec.name in rsp: rsp[attr_spec.name].append(decoded) else: rsp[attr_spec.name] = [decoded] if 'enum' in attr_spec: self._decode_enum(rsp, attr_spec) return rsp def _decode_extack_path(self, attrs, attr_set, offset, target): for attr in attrs: attr_spec = attr_set.attrs_by_val[attr.type] if offset > target: break if offset == target: return '.' + attr_spec.name if offset + attr.full_len <= target: offset += attr.full_len continue if attr_spec['type'] != 'nest': raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") offset += 4 subpath = self._decode_extack_path(NlAttrs(attr.raw), self.attr_sets[attr_spec['nested-attributes']], offset, target) if subpath is None: return None return '.' + attr_spec.name + subpath return None def _decode_extack(self, request, attr_space, extack): if 'bad-attr-offs' not in extack: return genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space)) path = self._decode_extack_path(genl_req.raw_attrs, attr_space, 20, extack['bad-attr-offs']) if path: del extack['bad-attr-offs'] extack['bad-attr'] = path def handle_ntf(self, nl_msg, genl_msg): msg = dict() if self.include_raw: msg['nlmsg'] = nl_msg msg['genlmsg'] = genl_msg op = self.rsp_by_value[genl_msg.genl_cmd] msg['name'] = op['name'] msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name) self.async_msg_queue.append(msg) def check_ntf(self): while True: try: reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) except BlockingIOError: return nms = NlMsgs(reply) for nl_msg in nms: if nl_msg.error: print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) print(nl_msg) continue if nl_msg.done: print("Netlink done while checking for ntf!?") continue gm = GenlMsg(nl_msg) if gm.genl_cmd not in self.async_msg_ids: print("Unexpected msg id done while checking for ntf", gm) continue self.handle_ntf(nl_msg, gm) def operation_do_attributes(self, name): """ For a given operation name, find and return a supported set of attributes (as a dict). """ op = self.find_operation(name) if not op: return None return op['do']['request']['attributes'].copy() def _op(self, method, vals, dump=False): op = self.ops[method] nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK if dump: nl_flags |= Netlink.NLM_F_DUMP req_seq = random.randint(1024, 65535) msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq) fixed_header_members = [] if op.fixed_header: fixed_header_members = self.consts[op.fixed_header].members for m in fixed_header_members: value = vals.pop(m.name) if m.name in vals else 0 format = NlAttr.get_format(m.type, m.byte_order) msg += format.pack(value) for name, value in vals.items(): msg += self._add_attr(op.attr_set.name, name, value) msg = _genl_msg_finalize(msg) self.sock.send(msg, 0) done = False rsp = [] while not done: reply = self.sock.recv(128 * 1024) nms = NlMsgs(reply, attr_space=op.attr_set) for nl_msg in nms: if nl_msg.extack: self._decode_extack(msg, op.attr_set, nl_msg.extack) if nl_msg.error: raise NlError(nl_msg) if nl_msg.done: if nl_msg.extack: print("Netlink warning:") print(nl_msg) done = True break gm = GenlMsg(nl_msg, fixed_header_members) # Check if this is a reply to our request if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value: if gm.genl_cmd in self.async_msg_ids: self.handle_ntf(nl_msg, gm) continue else: print('Unexpected message: ' + repr(gm)) continue rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name) rsp_msg.update(gm.fixed_header_attrs) rsp.append(rsp_msg) if not rsp: return None if not dump and len(rsp) == 1: return rsp[0] return rsp def do(self, method, vals): return self._op(method, vals) def dump(self, method, vals): return self._op(method, vals, dump=True)