# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants for netlink / generic netlink used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Ack flags.  They reuse the NLM_F_ROOT / NLM_F_MATCH bit values, so they
    # are only interpreted on NLMSG_ERROR / NLMSG_DONE messages (see NlMsg).
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    The offending NlMsg is kept in ``nl_msg``; its ``error`` is a negative errno.
    """

    def __init__(self, nl_msg):
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """A single netlink attribute (TLV) located at ``offset`` in a buffer.

    Attributes:
        type: attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER masked off.
        payload_len: the length field of the attribute header.  Despite the
            name, it covers the 4-byte header plus the payload.
        full_len: payload_len rounded up to 4 bytes, i.e. the distance to the
            next attribute.
        raw: the attribute payload, without the header.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct used to (un)pack scalar type ``attr_type``.

        ``byte_order == "big-endian"`` selects big endian, any other truthy
        value selects little endian, and a falsy value selects native order.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the final (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order ``type`` scalars."""
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as consecutive scalar ``members`` into a dict by name."""
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            format = self.get_format(m.type, m.byte_order)
            decoded = format.unpack_from(self.raw, offset)
            offset += format.size
            value[m.name] = decoded[0]
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All attributes laid out back to back in ``msg``, parsed eagerly."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # A length smaller than the 4-byte header is malformed; a length
            # of 0 would never advance the offset and loop forever.
            if attr.payload_len < 4:
                raise ValueError(f"Malformed netlink attribute at offset {offset}: "
                                 f"length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        msg = ''
        for a in self.attrs:
            if msg:
                msg += '\n'
            msg += repr(a)
        return msg


class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) at ``offset`` in ``msg``.

    For NLMSG_ERROR / NLMSG_DONE, ``error`` holds the status (0 or -errno),
    ``done`` is set, and extended-ack TLVs (if flagged) are decoded into
    ``extack``.  When ``attr_space`` is given, a missing-attribute extack is
    translated to the attribute's name (top level only).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Offset of the extack TLVs inside self.raw; None for non-ack messages.
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # int error + echoed 16-byte request header.  Assumes
            # NETLINK_CAP_ACK is set (as on the sockets in this file), so the
            # request payload is not echoed.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # The first 4 bytes of NLMSG_DONE carry the dump's final status;
            # a dump failing part-way reports its error here, not in an
            # NLMSG_ERROR, so it must not be ignored.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nl_len must at least cover the 16-byte header; 0 would loop forever.
            if msg.nl_len < 16:
                raise ValueError(f"Malformed netlink message at offset {offset}: "
                                 f"length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Lazily filled cache: family name -> dict with 'id', 'name' and, when
# reported, 'maxattr' and 'mcast' (group name -> group id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus its leading length) followed by a genlmsghdr.

    ``seq`` defaults to a random value in [1, 1024].
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including the 4-byte length field)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Fill genl_family_name_to_id by dumping CTRL_CMD_GETFAMILY from nlctrl.

    On a netlink error the code is printed and loading stops; the cache keeps
    whatever families were parsed before the error.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        # Nest of nests: one inner nest per multicast group.
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg.

    Splits the payload into the 4-byte genlmsghdr, an optional fixed header
    (scalar ``fixed_header_members``, decoded into ``fixed_header_attrs``)
    and the remaining attributes (``raw_attrs``).
    """

    def __init__(self, nl_msg, fixed_header_members=()):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            format = NlAttr.get_format(m.type, m.byte_order)
            decoded = format.unpack_from(nl_msg.raw, offset)
            offset += format.size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """A kernel generic netlink family resolved by name.

    Loads the family cache on first use.  Raises KeyError if the kernel does
    not report a family called ``family_name``.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Generic netlink family driven by a YNL YAML spec.

    For every operation in the spec, the instance gets an attribute named
    after the operation's ``ident_name``, bound to ``functools.partial(
    self._op, op_name)``.  ``fam.<op>(vals)`` therefore issues a "do" request
    and ``fam.<op>(vals, dump=True)`` a dump.
    """

    def __init__(self, def_path, schema=None):
        """Load the spec, open a generic netlink socket and resolve the family.

        Raises:
            Exception: if the kernel does not know the family named in the spec.
        """
        super().__init__(def_path, schema)

        # When True, handle_ntf() also stores the raw NlMsg / GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # genl_cmd values of asynchronous messages (notifications), and the
        # decoded notifications received so far.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # The object is unusable without a family; release the socket now
            # rather than leaving it to the garbage collector.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") from None

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises:
            Exception: if the family has no such multicast group (including
                when it reports no multicast groups at all).
        """
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as a netlink TLV.

        Returns header + payload + zero padding to a 4-byte boundary.  For a
        nest, ``value`` is a dict of sub-attribute name -> value, encoded
        recursively.

        Raises:
            Exception: for attribute types this encoder does not handle.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = value
        elif attr['type'] in NlAttr.type_formats:
            format = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = format.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum_value(self, raw, attr_spec):
        """Map integer ``raw`` to enum entry name(s) of ``attr_spec['enum']``.

        With ``enum-as-flags`` the result is a set of names, one per set bit;
        otherwise it is a single name.
        """
        enum = self.consts[attr_spec['enum']]
        # NOTE(review): value-start is read from the attribute spec rather
        # than the enum definition; confirm this matches the spec schema.
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw - i].name
        return value

    def _decode_enum(self, rsp, attr_spec):
        """Replace ``rsp[attr_spec['name']]`` in place with its enum name(s)."""
        rsp[attr_spec['name']] = self._decode_enum_value(rsp[attr_spec['name']], attr_spec)

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array or plain bytes."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    self._decode_enum(decoded, m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` from attribute set ``space`` into a dict keyed by name.

        Multi-attributes are collected into lists; nests are decoded recursively.

        Raises:
            Exception: for attribute types this decoder does not handle.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            # Map enums per value, before storing: for multi-attributes the
            # stored entry is a list, which cannot be enum-decoded as a whole.
            if 'enum' in attr_spec:
                decoded = self._decode_enum_value(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Find the attribute starting at byte ``target`` of the request.

        ``offset`` is the byte position of the first attribute in ``attrs``.
        Returns a dotted path such as ".outer.inner", or None if no attribute
        starts exactly at ``target``.

        Raises:
            Exception: if ``target`` falls inside a non-nest attribute.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the nest's own 4-byte header and search inside it.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=()):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path.

        ``request`` is the raw request that was sent.  ``fixed_header_members``
        must describe the op's fixed header, if any, so that the request's
        attributes are located at the correct offset.
        """
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), fixed_header_members)
        # Attributes start after nlmsghdr (16) + genlmsghdr (4) + fixed header.
        attrs_off = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size
                             for m in fixed_header_members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_off, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking and queue any notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        # find_operation() is not defined in this file; presumably provided
        # by SpecFamily -- verify.
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send operation ``method`` with values ``vals`` and collect the replies.

        Keys of ``vals`` naming fixed-header members go into the fixed header
        (missing members default to 0); all other keys are encoded as
        attributes.  ``vals`` itself is not modified.  Notifications that
        arrive meanwhile are queued via handle_ntf().

        Returns:
            None if no reply carried data; the single reply dict for a
            non-dump request with exactly one reply; otherwise a list of dicts.

        Raises:
            NlError: if the kernel answers with an error.
        """
        op = self.ops[method]
        # Copy so that popping fixed-header members below does not alter the
        # caller's dict (re-using it would otherwise silently send zeros).
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name) if m.name in vals else 0
                format = NlAttr.get_format(m.type, m.byte_order)
                msg += format.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Issue a "do" request for operation ``method``; see _op()."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Issue a dump request for operation ``method``; see _op()."""
        return self._op(method, vals, dump=True)