# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants from the netlink / generic netlink UAPI used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Flags carried on ack/error messages; they reuse the bit values of
    # NLM_F_ROOT / NLM_F_MATCH, so only interpret them on NLMSG_ERROR/DONE.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero NLMSG_ERROR.

    The offending NlMsg (including any decoded extack) is kept in ``nl_msg``.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """One netlink attribute (4-byte nlattr header + payload) parsed from a buffer.

    Attributes:
        type         attribute type with the NESTED / NET_BYTEORDER flags masked off
        payload_len  nla_len from the header (note: this INCLUDES the 4-byte header)
        full_len     nla_len rounded up to the 4-byte attribute alignment
        raw          payload bytes, without header and without trailing padding
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # nla_len counts the 4-byte header itself; anything shorter is
        # malformed and a zero length would make NlAttrs loop forever.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: nla_len {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the struct.Struct for a scalar type name ('u8' ... 's64').

        byte_order: None for host order, "big-endian" for big, anything
        else (truthy) for little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render a value per a spec display-hint; unknown hints return raw unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Payload is a NUL-terminated ASCII string; drop the terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def as_struct(self, members):
        """Decode the payload as a packed C struct described by spec members.

        Supports 'binary' members (with an explicit 'len') and scalar
        members. Raises for any other member type, since its size (and so the
        offset of every following member) is unknown.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [decoded] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                # Falling through would silently reuse the previous member's
                # value (or hit UnboundLocalError on the first member).
                raise Exception(f"Unsupported type '{m.type}' for struct member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from a buffer of attributes."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    For NLMSG_ERROR / NLMSG_DONE messages, sets ``error`` / ``done`` and, when
    the kernel flagged NLM_F_ACK_TLVS, decodes extended-ack TLVs into the
    ``extack`` dict. If ``attr_space`` is given, a top-level missing-attribute
    type is translated to its spec name (and doc, if any).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nlmsg_len counts the header itself; a shorter value is malformed and
        # a zero length would make NlMsgs loop forever.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message at offset {offset}: nlmsg_len {self.nl_len}")

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4) + echoed request header (16); assumes the echoed
            # request payload was capped (NETLINK_CAP_ACK is set on our sockets).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error) + '\n'
        if self.extack:
            msg += '\textack: ' + repr(self.extack) + '\n'
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Messages start on 4-byte boundaries (NLMSG_ALIGN / NLMSG_NEXT).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of kernel genl families: name -> {'id', 'name', 'maxattr', 'mcast'}.
# Filled lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus its leading length) + genlmsghdr."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len (which counts itself) to a built message."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all genl families from nlctrl into genl_family_name_to_id.

    On an NLMSG_ERROR reply the error is printed and loading stops early,
    leaving whatever families were collected so far.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic-netlink view of an NlMsg.

    Layout of nl_msg.raw: 4-byte genlmsghdr, then the optional fixed header
    described by ``fixed_header_members`` (decoded into fixed_header_attrs),
    then attributes (raw bytes in ``raw``, parsed in ``raw_attrs``).
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        if fixed_header_members is None:
            fixed_header_members = []
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            decoded = fmt.unpack_from(nl_msg.raw, offset)
            offset += fmt.size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """Kernel-side info for one genl family, looked up by name.

    Raises KeyError if the kernel does not know the family.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink client for one YAML-specified genl family.

    Every spec operation is also exposed as a method named after the op's
    ident_name, bound to _op() (e.g. ``ynl.<op>(vals)``).
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # Don't leave the netlink socket open on a half-built object.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") from None

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name`` on our socket."""
        if mcast_name not in self.family.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             self.family.genl_family['mcast'][mcast_name])

    def _add_attr(self, space, name, value):
        """Encode one attribute (recursively for nests), padded to 4 bytes."""
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = bytes.fromhex(value)
        elif attr['type'] in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = fmt.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map a raw value to its enum entry name, or a set of names for enum-as-flags."""
        enum = self.consts[attr_spec['enum']]
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            # Each set bit i maps to the entry whose value is i.
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or raw bytes."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode NlAttrs from attribute set ``space`` into a dict keyed by attr name.

        Multi-attrs are collected into lists; unknown attribute values raise.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a '.a.b.c' path to the attribute starting at byte ``target``.

        ``offset`` is the request byte offset of the first attribute in
        ``attrs``. Returns None if no attribute starts exactly at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path.

        ``request`` is the raw request we sent; ``fixed_header_members`` must
        describe its fixed header (if any) so attributes are located correctly.
        """
        if 'bad-attr-offs' not in extack:
            return

        if fixed_header_members is None:
            fixed_header_members = []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), fixed_header_members)
        # First attribute starts after nlmsghdr (16) + genlmsghdr (4) + fixed header.
        attrs_offset = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size
                                for m in fixed_header_members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_offset, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send request ``method`` built from ``vals`` and collect the replies.

        Fixed-header members are taken from ``vals`` (default 0); the remaining
        keys are encoded as attributes. ``vals`` itself is not modified.

        Returns None when there were no replies, a single dict for a non-dump
        with exactly one reply, otherwise a list of dicts.
        Raises NlError if the kernel returns an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header values are popped below and must not
        # disappear from the caller's dict.
        vals = dict(vals) if vals else dict()

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name) if m.name in vals else 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                msg += fmt.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack, fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Issue a non-dump request; see _op() for the return shape."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Issue a dump request; returns a list of dicts, or None if empty."""
        return self._op(method, vals, dump=True)