# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink / Generic Netlink protocol constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    ``nl_msg`` is the NlMsg carrying the (negative) error and any extack.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """One netlink attribute (4-byte length/type header plus payload).

    Parsed from ``raw`` starting at ``offset``; ``self.raw`` holds only
    the payload bytes.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # Strip the NLA_F_NESTED / NLA_F_NET_BYTEORDER flag bits.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        # The length field covers the 4-byte header plus the unpadded payload.
        self.payload_len = self._len
        # Attributes are laid out on 4-byte boundaries.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar type ``attr_type``.

        ``byte_order`` of "big-endian" selects big endian, any other truthy
        value little endian, and None/empty the native byte order.
        Raises KeyError for unknown types.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render ``raw`` per ``display_hint``.

        Supported hints are 'mac', 'hex', 'ipv4'/'ipv6' and 'uuid'. Any other
        hint returns ``raw`` unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the final (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order scalars."""
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a C struct described by spec ``members``.

        Members are read in order. 'binary' members take ``m['len']`` bytes;
        scalar members use their Struct format. A member with a display hint
        is rendered through formatted_string().

        Raises:
            Exception: a member has a type whose size is unknown here.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset+m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                format = self.get_format(m.type, m.byte_order)
                [ decoded ] = format.unpack_from(self.raw, offset)
                offset += format.size
            else:
                # Without a known size every later offset would be wrong and
                # 'decoded' would be left over from the previous member.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All consecutive netlink attributes contained in the buffer ``msg``."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr.full_len == 0:
                # A zero length field would never advance the offset.
                raise Exception(f"Netlink attribute with zero length at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        msg = ''
        for a in self.attrs:
            if msg:
                msg += '\n'
            msg += repr(a)
        return msg


class NlMsg:
    """One netlink message (16-byte header plus payload) at ``offset`` in ``msg``.

    NLMSG_ERROR and NLMSG_DONE messages set ``done``. NLMSG_ERROR also sets
    ``error``. When NLM_F_ACK_TLVS is set, the extended-ACK attributes are
    decoded into the ``extack`` dict. If ``attr_space`` is given, a top-level
    missing attribute type is translated to its spec name.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code + echoed request header (16 bytes). This
            # assumes the request payload is not echoed (NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            # Skip the 4-byte int that starts the NLMSG_DONE payload.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All consecutive netlink messages contained in the buffer ``data``."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len == 0:
                # A zero length field would never advance the offset.
                raise Exception(f"Netlink message with zero length at offset {offset}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Family name -> dict with at least 'id' and 'name' (plus 'maxattr' / 'mcast'
# when reported); filled lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header without the leading length field.

    A random sequence number is used when ``seq`` is None.
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including the 4-byte length field)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump every Generic Netlink family via nlctrl into genl_family_name_to_id.

    On a netlink error the error is printed and loading stops early.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic Netlink view of an NlMsg.

    Decodes the 4-byte genetlink header, then the optional fixed header
    (``fixed_header_members`` spec members, stored by name in
    ``fixed_header_attrs``), then the attributes (``raw_attrs``).
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members or []:
            format = NlAttr.get_format(m.type, m.byte_order)
            decoded = format.unpack_from(nl_msg.raw, offset)
            offset += format.size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """Kernel info for a Generic Netlink family, looked up by name.

    The family table is loaded on first use. Raises KeyError if the kernel
    did not report ``family_name``.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Spec-driven Generic Netlink client.

    Every operation in the spec is exposed as a method named
    ``op.ident_name`` that forwards to _op().
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # Don't leak the socket opened above when the family is unusable.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` of this family on self.sock."""
        # Families without multicast groups have no 'mcast' entry at all.
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of set ``space`` as padded TLV bytes."""
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats:
            format = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = format.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Length excludes padding; padding aligns the next attribute to 4 bytes.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map ``raw`` to its enum entry name.

        For 'enum-as-flags' specs, returns the set of names for the set bits.
        """
        enum = self.consts[attr_spec['enum']]
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute.

        The result is a struct dict, a C array, or raw bytes; a display hint,
        if present, is applied last.
        """
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` against attribute set ``space`` into a dict.

        The dict is keyed by attribute name. Multi-attributes are collected
        into lists.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted path ('.a.b') of the attribute at byte ``target``.

        ``offset`` is the message offset of the first attribute in ``attrs``.
        Returns None when no attribute starts exactly at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this nest's 4-byte header to its first child.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Resolve extack['bad-attr-offs'] against ``request``.

        When the offset can be resolved, it is replaced in place by a
        symbolic extack['bad-attr'] path.

        ``fixed_header_members`` must match the fixed header the request was
        built with, so that attribute offsets are computed correctly.
        """
        if 'bad-attr-offs' not in extack:
            return

        fixed_header_members = fixed_header_members or []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space),
                           fixed_header_members)
        # Attributes start after the netlink header (16), the genetlink
        # header (4) and the optional fixed header.
        attrs_offset = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size
                                for m in fixed_header_members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_offset, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to self.async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send op ``method`` with attribute values ``vals`` and collect replies.

        ``vals`` is not modified.

        Returns:
            None when there are no replies; the single reply dict for a
            non-dump request with exactly one reply; otherwise a list of
            reply dicts.

        Raises:
            NlError: the kernel answered with an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header members are popped below and the
        # caller's dictionary must not be modified.
        vals = dict(vals or {})

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name) if m.name in vals else 0
                format = NlAttr.get_format(m.type, m.byte_order)
                msg += format.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Issue a non-dump request for op ``method``."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Issue a dump request (NLM_F_DUMP) for op ``method``."""
        return self._op(method, vals, dump=True)