# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

import functools
import os
import random
import socket
import struct
import yaml

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants for raw Netlink, Generic Netlink and the nlctrl family."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Reply-side flags; note they reuse the bit values of NLM_F_ROOT / NLM_F_MATCH,
    # so NlMsg only interprets NLM_F_ACK_TLVS on ERROR / DONE messages.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlAttr:
    """A single Netlink attribute (TLV) parsed out of a byte buffer.

    Attributes set:
        type: attribute type with the NESTED / NET_BYTEORDER flag bits removed.
        payload_len: the raw nla_len header field (includes the 4-byte header).
        full_len: nla_len rounded up to 4 bytes, i.e. distance to the next attribute.
        raw: payload bytes, without the header and without trailing padding.

    Raises:
        Exception: if nla_len is smaller than the 4-byte header (corrupt input).
    """
    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        if self._len < 4:
            # nla_len counts its own header; a zero length would give full_len == 0
            # and make NlAttrs spin forever on the same offset.
            raise Exception(f'Malformed netlink attribute at offset {offset}: nla_len {self._len}')
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    def as_u8(self):
        return struct.unpack("B", self.raw)[0]

    def as_u16(self):
        return struct.unpack("H", self.raw)[0]

    def as_u32(self):
        return struct.unpack("I", self.raw)[0]

    def as_u64(self):
        return struct.unpack("Q", self.raw)[0]

    def as_strz(self):
        # Payload is an ASCII string with a trailing NUL, which is dropped.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of consecutive NlAttr parsed from the whole of msg (iterable)."""
    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One Netlink message parsed from a buffer.

    Sets nl_len / nl_type / nl_flags / nl_seq / nl_portid from the 16-byte header
    and raw to the payload. For NLMSG_ERROR and NLMSG_DONE it also sets done = 1
    and error to the int status at the start of the payload. When NLM_F_ACK_TLVS
    is set on such a message, it sets extack to a dict of decoded extended-ACK
    attributes.

    Args:
        msg: buffer containing the message.
        offset: byte offset of the message header within msg.
        attr_space: optional attribute set of the request; used to translate a
            top-level missing-attribute type number into its spec name.

    Raises:
        Exception: if nlmsg_len is shorter than the header (corrupt input).
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            # nlmsg_len counts the header; a zero length would make NlMsgs loop forever.
            raise Exception(f'Malformed netlink message at offset {offset}: nlmsg_len {self.nl_len}')

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4 bytes) + echoed request header (16 bytes); presumably
            # relies on NETLINK_CAP_ACK (set on the sockets in this file) so the
            # request payload is not echoed - confirm if used elsewhere.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # DONE also starts with a 4-byte int status (extack TLVs follow it,
            # hence extack_off = 4). Surface it so a failed dump is not reported
            # as a clean end of dump.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_u32()
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All Netlink messages packed back to back in data (iterable)."""
    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id', 'name',
# optionally 'maxattr' and 'mcast' (group name -> group id). None until loaded.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus nlmsg_len) + genlmsghdr; seq defaults to a random value."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend nlmsg_len; +4 accounts for the length field itself."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all Generic Netlink families from nlctrl into genl_family_name_to_id.

    On a Netlink error it prints the error and returns, leaving whatever was
    collected so far in the cache.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_u16()
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_u32()
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        # Nest of nests: each entry holds one group's name and id.
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_u32()
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic Netlink view of an NlMsg: genl header fields plus parsed attributes."""
    def __init__(self, nl_msg):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        self.raw = nl_msg.raw[4:]

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """Resolve a Generic Netlink family by name (loads the family cache on first use).

    Raises:
        KeyError: if the kernel did not report a family with this name.
    """
    def __init__(self, family_name):
        self.family_name = family_name

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Generic Netlink client for one family described by a YNL spec.

    Every op in the spec is bound as a method named after op.ident_name; calling
    it with a dict of attribute values sends the request and returns the decoded
    reply. Async notifications seen on the socket are decoded and appended to
    async_msg_queue.
    """

    # struct format characters for the unsigned integer attribute types that
    # _add_attr can encode (native byte order, like the rest of this file).
    _INT_FMT = {'u8': 'B', 'u16': 'H', 'u32': 'I', 'u64': 'Q'}

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg / GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # genl_cmd values that identify notifications rather than replies.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        self.family = GenlFamily(self.yaml['name'])

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group mcast_name on self.sock.

        Raises:
            Exception: if the kernel reported no group with that name (including
                families that have no multicast groups at all).
        """
        groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode one attribute (recursively for nests) as padded TLV bytes.

        Supports nest, flag, u8/u16/u32/u64, string (ASCII, NUL-terminated) and
        binary (value passed through as bytes).

        Raises:
            Exception: for attribute types this encoder does not know.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        attr_type = attr["type"]
        if attr_type == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr_type == 'flag':
            attr_payload = b''
        elif attr_type in self._INT_FMT:
            attr_payload = struct.pack(self._INT_FMT[attr_type], int(value))
        elif attr_type == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_type == 'binary':
            attr_payload = value
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # nla_len excludes padding; padding aligns the next attribute to 4 bytes.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum_value(self, raw, attr_spec):
        """Translate one raw integer into its enum entry name.

        For 'enum-as-flags' specs, returns the set of names of the set bits.
        Entry lookup is offset by the spec's 'value-start' (default 0).
        """
        enum = self.consts[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
            return value
        return enum.entries_by_val[raw - i].name

    def _decode_enum(self, rsp, attr_spec):
        """Replace rsp[attr_spec['name']] in place with its decoded enum value."""
        rsp[attr_spec['name']] = self._decode_enum_value(rsp[attr_spec['name']], attr_spec)

    def _decode(self, attrs, space):
        """Decode parsed attributes of attribute set `space` into a dict keyed by name.

        Multi-attrs are collected into lists; enum attrs are translated per value
        (so multi-attr enums work too).

        Raises:
            Exception: for attribute types this decoder does not know.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                decoded = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
            elif attr_spec['type'] == 'u8':
                decoded = attr.as_u8()
            elif attr_spec['type'] == 'u16':
                decoded = attr.as_u16()
            elif attr_spec['type'] == 'u32':
                decoded = attr.as_u32()
            elif attr_spec['type'] == 'u64':
                decoded = attr.as_u64()
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = attr.as_bin()
            elif attr_spec["type"] == 'flag':
                decoded = True
            else:
                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')

            # Decode before storing: for multi-attrs the stored value is a list,
            # which the enum translation cannot operate on.
            if 'enum' in attr_spec:
                decoded = self._decode_enum_value(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a dotted path ('.a.b') to the attribute starting at byte `target`.

        `offset` is the byte position of the first attribute in `attrs` within
        the request. Recurses into the nest that contains `target`; returns None
        if `target` does not land on an attribute start.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4  # skip the nest's own TLV header
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack):
        """Replace extack['bad-attr-offs'] with a readable 'bad-attr' path when possible."""
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
        # Attributes start after nlmsghdr (16) + genlmsghdr (4).
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        20, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing any notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def _op(self, method, vals, dump=False):
        """Send op `method` with attribute values `vals` and collect the replies.

        Returns:
            None on a Netlink error (which is printed) or when nothing was
            returned; a single dict for a non-dump op with one reply; otherwise
            a list of dicts. Notifications received meanwhile are queued.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack)

                if nl_msg.error:
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    return
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                    else:
                        print('Unexpected message: ' + repr(gm))
                    continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name))

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run a non-dump request for op `method`."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run a dump request for op `method`."""
        return self._op(method, vals, dump=True)