# SPDX-License-Identifier: BSD-3-Clause

import functools
import os
import random
import socket
import struct
import yaml

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants from the netlink / genetlink / nlctrl UAPI used in this file."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Reply-side flags (only inspected here on NLMSG_ERROR / NLMSG_DONE);
    # note they share bit values with NLM_F_ROOT / NLM_F_MATCH above.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits stripped from nla_type to obtain the attribute type
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlAttr:
    """One netlink attribute (TLV) parsed from ``raw`` starting at ``offset``.

    ``type`` is the attribute type with the NLA_F_* flag bits masked off,
    ``raw`` is the payload without the 4-byte header or trailing padding, and
    ``full_len`` is the 4-byte aligned size used to step to the next attribute.
    """

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        # Despite the name, this is nla_len, which includes the 4-byte header.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    def as_u8(self):
        return struct.unpack("B", self.raw)[0]

    def as_u16(self):
        return struct.unpack("H", self.raw)[0]

    def as_u32(self):
        return struct.unpack("I", self.raw)[0]

    def as_u64(self):
        return struct.unpack("Q", self.raw)[0]

    def as_strz(self):
        # Drop the trailing NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All attributes packed back-to-back in ``msg``, in wire order.

    Raises:
        Exception: if an attribute declares a length shorter than its header,
            which would otherwise stall the parse loop forever.
    """

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len must at least cover the 4-byte header; a zero length
            # would never advance the offset.
            if attr._len < 4:
                raise Exception(f"Malformed netlink attribute (len {attr._len}) at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message parsed from ``msg`` at ``offset``.

    Sets ``error`` (signed errno from NLMSG_ERROR, 0 otherwise), ``done``
    (1 for NLMSG_ERROR / NLMSG_DONE) and ``extack`` (dict of decoded extended
    ACK attributes when NLM_F_ACK_TLVS is set, else None).  When ``attr_space``
    is given, a top-level 'miss-type' is translated into the attribute's name.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4 bytes of errno + 16-byte header of the original request.
            # Assumes the request payload was not echoed back (the sockets
            # in this file enable NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_u32()
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one ``recv()`` buffer, in order.

    Raises:
        Exception: if a message declares a length shorter than its 16-byte
            header, which would otherwise stall the parse loop forever.
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message (len {msg.nl_len}) at offset {offset}")
            # Consecutive messages start on NLMSG_ALIGN (4-byte) boundaries.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and, when reported by nlctrl, 'maxattr' and 'mcast' (group name -> id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (without nlmsg_len) + genlmsghdr for a request.

    A random sequence number in [1, 1024] is used when ``seq`` is None.
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    # cmd and version are u8 in struct genlmsghdr; pack unsigned so values
    # above 127 are accepted.
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump every generic netlink family via nlctrl into genl_family_name_to_id.

    On a netlink error the error is printed and loading stops early, leaving
    whatever families were collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_u16()
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_u32()
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        # Nest of nests: one entry per multicast group.
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_u32()
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr fields plus parsed attributes."""

    def __init__(self, nl_msg):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        self.raw = nl_msg.raw[4:]

        # cmd and version are u8 on the wire; unpack unsigned so commands
        # above 127 are not reported as negative numbers.
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """Resolve a generic netlink family name to its nlctrl description and id.

    The family table is loaded from the kernel on first use and cached.

    Raises:
        KeyError: if the kernel does not know ``family_name``.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Generic netlink client driven by a YNL spec.

    Every operation in the spec is exposed as a bound method named after the
    operation's ``ident_name``.  Calling it with a dict of attribute values
    runs _op() for that operation.  Notifications for async messages are
    queued in ``async_msg_queue``.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg / GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # Spec 'definitions' (e.g. enums) indexed by name
        self._types = dict()

        for elem in self.yaml.get('definitions', []):
            self._types[elem['name']] = elem

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        self.family = GenlFamily(self.yaml['name'])

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises:
            Exception: if the family does not advertise that group (including
                families that advertise no multicast groups at all).
        """
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        # The kernel rejects bind() with pid 0 once the socket already owns a
        # port id (auto-assigned by an earlier send() or a previous
        # subscription), so only bind a socket that is still unbound.
        if self.sock.getsockname()[0] == 0:
            self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as a padded TLV.

        Nests are encoded recursively from a dict of sub-attribute values.

        Raises:
            Exception: for attribute types this encoder does not support.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'u8':
            attr_payload = struct.pack("B", int(value))
        elif attr["type"] == 'u16':
            attr_payload = struct.pack("H", int(value))
        elif attr["type"] == 'u32':
            attr_payload = struct.pack("I", int(value))
        elif attr["type"] == 'u64':
            attr_payload = struct.pack("Q", int(value))
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = value
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Pad the payload to a 4-byte boundary; nla_len excludes the padding.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum_value(self, raw, attr_spec):
        """Translate one raw integer into enum entry name(s) per ``attr_spec``.

        With 'enum-as-flags' the result is the set of entries whose bits are
        set (bit 0 maps to entry 'value-start'); otherwise it is the single
        entry at index ``raw - value-start``.
        """
        enum = self._types[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum['entries'][i])
                raw >>= 1
                i += 1
        else:
            value = enum['entries'][raw - i]
        return value

    def _decode_enum(self, rsp, attr_spec):
        """Replace the raw integer stored in ``rsp`` for ``attr_spec`` with its enum name(s)."""
        name = attr_spec['name']
        rsp[name] = self._decode_enum_value(rsp[name], attr_spec)

    def _decode(self, attrs, space):
        """Decode ``attrs`` against attribute set ``space`` into a dict.

        Multi-attributes are collected into lists; enum-typed values are
        translated to entry names one value at a time.

        Raises:
            Exception: for attribute types this decoder does not support.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec['type'] == 'u8':
                decoded = attr.as_u8()
            elif attr_spec['type'] == 'u16':
                decoded = attr.as_u16()
            elif attr_spec['type'] == 'u32':
                decoded = attr.as_u32()
            elif attr_spec['type'] == 'u64':
                decoded = attr.as_u64()
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = attr.as_bin()
            elif attr_spec["type"] == 'flag':
                decoded = True
            else:
                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')

            # Translate before storing: for multi-attributes the stored value
            # is a list, which the enum translation cannot handle.
            if 'enum' in attr_spec:
                decoded = self._decode_enum_value(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Map byte offset ``target`` in the request to a dotted attribute path.

        ``offset`` is the byte position of the first attribute in ``attrs``.
        Descends into nests that contain ``target``.  Returns e.g. '.a.b', or
        None when no attribute starts exactly at ``target``.

        Raises:
            Exception: if ``target`` falls inside a non-nest attribute.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip this nest's 4-byte header to reach its first child.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack):
        """Replace 'bad-attr-offs' in ``extack`` with a 'bad-attr' path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
        # Attributes start after the 16-byte nlmsghdr and 4-byte genlmsghdr.
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        20, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing any notifications.

        Errors, DONE messages and unexpected commands are printed and skipped.
        """
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def _op(self, method, vals, dump=False):
        """Send a request for operation ``method`` and collect its replies.

        ``vals`` maps attribute names of the operation's attribute set to
        values.  Notifications that arrive meanwhile are queued via
        handle_ntf().  Returns None when no replies were decoded or on error
        (after printing it), a single dict for a non-dump op with exactly one
        reply, and a list of dicts otherwise.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack)

                if nl_msg.error:
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    return
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name))

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run operation ``method`` as a single (non-dump) request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run operation ``method`` as a dump request (always returns a list or None)."""
        return self._op(method, vals, dump=True)