| # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause |
| |
| from collections import namedtuple |
| from enum import Enum |
| import functools |
| import os |
| import random |
| import socket |
| import struct |
| from struct import Struct |
| import sys |
| import yaml |
| import ipaddress |
| import uuid |
| |
| from .nlspec import SpecFamily |
| |
| # |
| # Generic Netlink code which should really be in some library, but I can't quickly find one. |
| # |
| |
| |
| class Netlink: |
| # Netlink socket |
| SOL_NETLINK = 270 |
| |
| NETLINK_ADD_MEMBERSHIP = 1 |
| NETLINK_CAP_ACK = 10 |
| NETLINK_EXT_ACK = 11 |
| NETLINK_GET_STRICT_CHK = 12 |
| |
| # Netlink message |
| NLMSG_ERROR = 2 |
| NLMSG_DONE = 3 |
| |
| NLM_F_REQUEST = 1 |
| NLM_F_ACK = 4 |
| NLM_F_ROOT = 0x100 |
| NLM_F_MATCH = 0x200 |
| |
| NLM_F_REPLACE = 0x100 |
| NLM_F_EXCL = 0x200 |
| NLM_F_CREATE = 0x400 |
| NLM_F_APPEND = 0x800 |
| |
| NLM_F_CAPPED = 0x100 |
| NLM_F_ACK_TLVS = 0x200 |
| |
| NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH |
| |
| NLA_F_NESTED = 0x8000 |
| NLA_F_NET_BYTEORDER = 0x4000 |
| |
| NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER |
| |
| # Genetlink defines |
| NETLINK_GENERIC = 16 |
| |
| GENL_ID_CTRL = 0x10 |
| |
| # nlctrl |
| CTRL_CMD_GETFAMILY = 3 |
| |
| CTRL_ATTR_FAMILY_ID = 1 |
| CTRL_ATTR_FAMILY_NAME = 2 |
| CTRL_ATTR_MAXATTR = 5 |
| CTRL_ATTR_MCAST_GROUPS = 7 |
| |
| CTRL_ATTR_MCAST_GRP_NAME = 1 |
| CTRL_ATTR_MCAST_GRP_ID = 2 |
| |
| # Extack types |
| NLMSGERR_ATTR_MSG = 1 |
| NLMSGERR_ATTR_OFFS = 2 |
| NLMSGERR_ATTR_COOKIE = 3 |
| NLMSGERR_ATTR_POLICY = 4 |
| NLMSGERR_ATTR_MISS_TYPE = 5 |
| NLMSGERR_ATTR_MISS_NEST = 6 |
| |
| # Policy types |
| NL_POLICY_TYPE_ATTR_TYPE = 1 |
| NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 |
| NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 |
| NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 |
| NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 |
| NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 |
| NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 |
| NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 |
| NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 |
| NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 |
| NL_POLICY_TYPE_ATTR_PAD = 11 |
| NL_POLICY_TYPE_ATTR_MASK = 12 |
| |
| AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', |
| 's8', 's16', 's32', 's64', |
| 'binary', 'string', 'nul-string', |
| 'nested', 'nested-array', |
| 'bitfield32', 'sint', 'uint']) |
| |
| class NlError(Exception): |
| def __init__(self, nl_msg): |
| self.nl_msg = nl_msg |
| self.error = -nl_msg.error |
| |
| def __str__(self): |
| return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}" |
| |
| |
| class ConfigError(Exception): |
| pass |
| |
| |
| class NlAttr: |
| ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) |
| type_formats = { |
| 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), |
| 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), |
| 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), |
| 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), |
| 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), |
| 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), |
| 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), |
| 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) |
| } |
| |
| def __init__(self, raw, offset): |
| self._len, self._type = struct.unpack("HH", raw[offset : offset + 4]) |
| self.type = self._type & ~Netlink.NLA_TYPE_MASK |
| self.is_nest = self._type & Netlink.NLA_F_NESTED |
| self.payload_len = self._len |
| self.full_len = (self.payload_len + 3) & ~3 |
| self.raw = raw[offset + 4 : offset + self.payload_len] |
| |
| @classmethod |
| def get_format(cls, attr_type, byte_order=None): |
| format = cls.type_formats[attr_type] |
| if byte_order: |
| return format.big if byte_order == "big-endian" \ |
| else format.little |
| return format.native |
| |
| def as_scalar(self, attr_type, byte_order=None): |
| format = self.get_format(attr_type, byte_order) |
| return format.unpack(self.raw)[0] |
| |
| def as_auto_scalar(self, attr_type, byte_order=None): |
| if len(self.raw) != 4 and len(self.raw) != 8: |
| raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}") |
| real_type = attr_type[0] + str(len(self.raw) * 8) |
| format = self.get_format(real_type, byte_order) |
| return format.unpack(self.raw)[0] |
| |
| def as_strz(self): |
| return self.raw.decode('ascii')[:-1] |
| |
| def as_bin(self): |
| return self.raw |
| |
| def as_c_array(self, type): |
| format = self.get_format(type) |
| return [ x[0] for x in format.iter_unpack(self.raw) ] |
| |
| def __repr__(self): |
| return f"[type:{self.type} len:{self._len}] {self.raw}" |
| |
| |
| class NlAttrs: |
| def __init__(self, msg, offset=0): |
| self.attrs = [] |
| |
| while offset < len(msg): |
| attr = NlAttr(msg, offset) |
| offset += attr.full_len |
| self.attrs.append(attr) |
| |
| def __iter__(self): |
| yield from self.attrs |
| |
| def __repr__(self): |
| msg = '' |
| for a in self.attrs: |
| if msg: |
| msg += '\n' |
| msg += repr(a) |
| return msg |
| |
| |
| class NlMsg: |
| def __init__(self, msg, offset, attr_space=None): |
| self.hdr = msg[offset : offset + 16] |
| |
| self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ |
| struct.unpack("IHHII", self.hdr) |
| |
| self.raw = msg[offset + 16 : offset + self.nl_len] |
| |
| self.error = 0 |
| self.done = 0 |
| |
| extack_off = None |
| if self.nl_type == Netlink.NLMSG_ERROR: |
| self.error = struct.unpack("i", self.raw[0:4])[0] |
| self.done = 1 |
| extack_off = 20 |
| elif self.nl_type == Netlink.NLMSG_DONE: |
| self.error = struct.unpack("i", self.raw[0:4])[0] |
| self.done = 1 |
| extack_off = 4 |
| |
| self.extack = None |
| if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: |
| self.extack = dict() |
| extack_attrs = NlAttrs(self.raw[extack_off:]) |
| for extack in extack_attrs: |
| if extack.type == Netlink.NLMSGERR_ATTR_MSG: |
| self.extack['msg'] = extack.as_strz() |
| elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: |
| self.extack['miss-type'] = extack.as_scalar('u32') |
| elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: |
| self.extack['miss-nest'] = extack.as_scalar('u32') |
| elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: |
| self.extack['bad-attr-offs'] = extack.as_scalar('u32') |
| elif extack.type == Netlink.NLMSGERR_ATTR_POLICY: |
| self.extack['policy'] = self._decode_policy(extack.raw) |
| else: |
| if 'unknown' not in self.extack: |
| self.extack['unknown'] = [] |
| self.extack['unknown'].append(extack) |
| |
| if attr_space: |
| # We don't have the ability to parse nests yet, so only do global |
| if 'miss-type' in self.extack and 'miss-nest' not in self.extack: |
| miss_type = self.extack['miss-type'] |
| if miss_type in attr_space.attrs_by_val: |
| spec = attr_space.attrs_by_val[miss_type] |
| self.extack['miss-type'] = spec['name'] |
| if 'doc' in spec: |
| self.extack['miss-type-doc'] = spec['doc'] |
| |
| def _decode_policy(self, raw): |
| policy = {} |
| for attr in NlAttrs(raw): |
| if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE: |
| type = attr.as_scalar('u32') |
| policy['type'] = Netlink.AttrType(type).name |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: |
| policy['min-value'] = attr.as_scalar('s64') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: |
| policy['max-value'] = attr.as_scalar('s64') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: |
| policy['min-value'] = attr.as_scalar('u64') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: |
| policy['max-value'] = attr.as_scalar('u64') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: |
| policy['min-length'] = attr.as_scalar('u32') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: |
| policy['max-length'] = attr.as_scalar('u32') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: |
| policy['bitfield32-mask'] = attr.as_scalar('u32') |
| elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK: |
| policy['mask'] = attr.as_scalar('u64') |
| return policy |
| |
| def cmd(self): |
| return self.nl_type |
| |
| def __repr__(self): |
| msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}" |
| if self.error: |
| msg += '\n\terror: ' + str(self.error) |
| if self.extack: |
| msg += '\n\textack: ' + repr(self.extack) |
| return msg |
| |
| |
| class NlMsgs: |
| def __init__(self, data, attr_space=None): |
| self.msgs = [] |
| |
| offset = 0 |
| while offset < len(data): |
| msg = NlMsg(data, offset, attr_space=attr_space) |
| offset += msg.nl_len |
| self.msgs.append(msg) |
| |
| def __iter__(self): |
| yield from self.msgs |
| |
| |
| genl_family_name_to_id = None |
| |
| |
| def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): |
| # we prepend length in _genl_msg_finalize() |
| if seq is None: |
| seq = random.randint(1, 1024) |
| nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) |
| genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) |
| return nlmsg + genlmsg |
| |
| |
| def _genl_msg_finalize(msg): |
| return struct.pack("I", len(msg) + 4) + msg |
| |
| |
| def _genl_load_families(): |
| with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: |
| sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) |
| |
| msg = _genl_msg(Netlink.GENL_ID_CTRL, |
| Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, |
| Netlink.CTRL_CMD_GETFAMILY, 1) |
| msg = _genl_msg_finalize(msg) |
| |
| sock.send(msg, 0) |
| |
| global genl_family_name_to_id |
| genl_family_name_to_id = dict() |
| |
| while True: |
| reply = sock.recv(128 * 1024) |
| nms = NlMsgs(reply) |
| for nl_msg in nms: |
| if nl_msg.error: |
| print("Netlink error:", nl_msg.error) |
| return |
| if nl_msg.done: |
| return |
| |
| gm = GenlMsg(nl_msg) |
| fam = dict() |
| for attr in NlAttrs(gm.raw): |
| if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: |
| fam['id'] = attr.as_scalar('u16') |
| elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: |
| fam['name'] = attr.as_strz() |
| elif attr.type == Netlink.CTRL_ATTR_MAXATTR: |
| fam['maxattr'] = attr.as_scalar('u32') |
| elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: |
| fam['mcast'] = dict() |
| for entry in NlAttrs(attr.raw): |
| mcast_name = None |
| mcast_id = None |
| for entry_attr in NlAttrs(entry.raw): |
| if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: |
| mcast_name = entry_attr.as_strz() |
| elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: |
| mcast_id = entry_attr.as_scalar('u32') |
| if mcast_name and mcast_id is not None: |
| fam['mcast'][mcast_name] = mcast_id |
| if 'name' in fam and 'id' in fam: |
| genl_family_name_to_id[fam['name']] = fam |
| |
| |
| class GenlMsg: |
| def __init__(self, nl_msg): |
| self.nl = nl_msg |
| self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) |
| self.raw = nl_msg.raw[4:] |
| |
| def cmd(self): |
| return self.genl_cmd |
| |
| def __repr__(self): |
| msg = repr(self.nl) |
| msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" |
| for a in self.raw_attrs: |
| msg += '\t\t' + repr(a) + '\n' |
| return msg |
| |
| |
| class NetlinkProtocol: |
| def __init__(self, family_name, proto_num): |
| self.family_name = family_name |
| self.proto_num = proto_num |
| |
| def _message(self, nl_type, nl_flags, seq=None): |
| if seq is None: |
| seq = random.randint(1, 1024) |
| nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) |
| return nlmsg |
| |
| def message(self, flags, command, version, seq=None): |
| return self._message(command, flags, seq) |
| |
| def _decode(self, nl_msg): |
| return nl_msg |
| |
| def decode(self, ynl, nl_msg, op): |
| msg = self._decode(nl_msg) |
| fixed_header_size = ynl._struct_size(op.fixed_header) |
| msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size) |
| return msg |
| |
| def get_mcast_id(self, mcast_name, mcast_groups): |
| if mcast_name not in mcast_groups: |
| raise Exception(f'Multicast group "{mcast_name}" not present in the spec') |
| return mcast_groups[mcast_name].value |
| |
| def msghdr_size(self): |
| return 16 |
| |
| |
| class GenlProtocol(NetlinkProtocol): |
| def __init__(self, family_name): |
| super().__init__(family_name, Netlink.NETLINK_GENERIC) |
| |
| global genl_family_name_to_id |
| if genl_family_name_to_id is None: |
| _genl_load_families() |
| |
| self.genl_family = genl_family_name_to_id[family_name] |
| self.family_id = genl_family_name_to_id[family_name]['id'] |
| |
| def message(self, flags, command, version, seq=None): |
| nlmsg = self._message(self.family_id, flags, seq) |
| genlmsg = struct.pack("BBH", command, version, 0) |
| return nlmsg + genlmsg |
| |
| def _decode(self, nl_msg): |
| return GenlMsg(nl_msg) |
| |
| def get_mcast_id(self, mcast_name, mcast_groups): |
| if mcast_name not in self.genl_family['mcast']: |
| raise Exception(f'Multicast group "{mcast_name}" not present in the family') |
| return self.genl_family['mcast'][mcast_name] |
| |
| def msghdr_size(self): |
| return super().msghdr_size() + 4 |
| |
| |
| class SpaceAttrs: |
| SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values']) |
| |
| def __init__(self, attr_space, attrs, outer = None): |
| outer_scopes = outer.scopes if outer else [] |
| inner_scope = self.SpecValuesPair(attr_space, attrs) |
| self.scopes = [inner_scope] + outer_scopes |
| |
| def lookup(self, name): |
| for scope in self.scopes: |
| if name in scope.spec: |
| if name in scope.values: |
| return scope.values[name] |
| spec_name = scope.spec.yaml['name'] |
| raise Exception( |
| f"No value for '{name}' in attribute space '{spec_name}'") |
| raise Exception(f"Attribute '{name}' not defined in any attribute-set") |
| |
| |
| # |
| # YNL implementation details. |
| # |
| |
| |
| class YnlFamily(SpecFamily): |
| def __init__(self, def_path, schema=None, process_unknown=False, |
| recv_size=0): |
| super().__init__(def_path, schema) |
| |
| self.include_raw = False |
| self.process_unknown = process_unknown |
| |
| try: |
| if self.proto == "netlink-raw": |
| self.nlproto = NetlinkProtocol(self.yaml['name'], |
| self.yaml['protonum']) |
| else: |
| self.nlproto = GenlProtocol(self.yaml['name']) |
| except KeyError: |
| raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") |
| |
| self._recv_dbg = False |
| # Note that netlink will use conservative (min) message size for |
| # the first dump recv() on the socket, our setting will only matter |
| # from the second recv() on. |
| self._recv_size = recv_size if recv_size else 131072 |
| # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) |
| # for a message, so smaller receive sizes will lead to truncation. |
| # Note that the min size for other families may be larger than 4k! |
| if self._recv_size < 4000: |
| raise ConfigError() |
| |
| self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) |
| self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) |
| self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) |
| self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) |
| |
| self.async_msg_ids = set() |
| self.async_msg_queue = [] |
| |
| for msg in self.msgs.values(): |
| if msg.is_async: |
| self.async_msg_ids.add(msg.rsp_value) |
| |
| for op_name, op in self.ops.items(): |
| bound_f = functools.partial(self._op, op_name) |
| setattr(self, op.ident_name, bound_f) |
| |
| |
| def ntf_subscribe(self, mcast_name): |
| mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) |
| self.sock.bind((0, 0)) |
| self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, |
| mcast_id) |
| |
| def set_recv_dbg(self, enabled): |
| self._recv_dbg = enabled |
| |
| def _recv_dbg_print(self, reply, nl_msgs): |
| if not self._recv_dbg: |
| return |
| print("Recv: read", len(reply), "bytes,", |
| len(nl_msgs.msgs), "messages", file=sys.stderr) |
| for nl_msg in nl_msgs: |
| print(" ", nl_msg, file=sys.stderr) |
| |
| def _encode_enum(self, attr_spec, value): |
| enum = self.consts[attr_spec['enum']] |
| if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): |
| scalar = 0 |
| if isinstance(value, str): |
| value = [value] |
| for single_value in value: |
| scalar += enum.entries[single_value].user_value(as_flags = True) |
| return scalar |
| else: |
| return enum.entries[value].user_value() |
| |
| def _get_scalar(self, attr_spec, value): |
| try: |
| return int(value) |
| except (ValueError, TypeError) as e: |
| if 'enum' not in attr_spec: |
| raise e |
| return self._encode_enum(attr_spec, value) |
| |
| def _add_attr(self, space, name, value, search_attrs): |
| try: |
| attr = self.attr_sets[space][name] |
| except KeyError: |
| raise Exception(f"Space '{space}' has no attribute '{name}'") |
| nl_type = attr.value |
| |
| if attr.is_multi and isinstance(value, list): |
| attr_payload = b'' |
| for subvalue in value: |
| attr_payload += self._add_attr(space, name, subvalue, search_attrs) |
| return attr_payload |
| |
| if attr["type"] == 'nest': |
| nl_type |= Netlink.NLA_F_NESTED |
| attr_payload = b'' |
| sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs) |
| for subname, subvalue in value.items(): |
| attr_payload += self._add_attr(attr['nested-attributes'], |
| subname, subvalue, sub_attrs) |
| elif attr["type"] == 'flag': |
| if not value: |
| # If value is absent or false then skip attribute creation. |
| return b'' |
| attr_payload = b'' |
| elif attr["type"] == 'string': |
| attr_payload = str(value).encode('ascii') + b'\x00' |
| elif attr["type"] == 'binary': |
| if isinstance(value, bytes): |
| attr_payload = value |
| elif isinstance(value, str): |
| attr_payload = bytes.fromhex(value) |
| elif isinstance(value, dict) and attr.struct_name: |
| attr_payload = self._encode_struct(attr.struct_name, value) |
| else: |
| raise Exception(f'Unknown type for binary attribute, value: {value}') |
| elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: |
| scalar = self._get_scalar(attr, value) |
| if attr.is_auto_scalar: |
| attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') |
| else: |
| attr_type = attr["type"] |
| format = NlAttr.get_format(attr_type, attr.byte_order) |
| attr_payload = format.pack(scalar) |
| elif attr['type'] in "bitfield32": |
| scalar_value = self._get_scalar(attr, value["value"]) |
| scalar_selector = self._get_scalar(attr, value["selector"]) |
| attr_payload = struct.pack("II", scalar_value, scalar_selector) |
| elif attr['type'] == 'sub-message': |
| msg_format = self._resolve_selector(attr, search_attrs) |
| attr_payload = b'' |
| if msg_format.fixed_header: |
| attr_payload += self._encode_struct(msg_format.fixed_header, value) |
| if msg_format.attr_set: |
| if msg_format.attr_set in self.attr_sets: |
| nl_type |= Netlink.NLA_F_NESTED |
| sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) |
| for subname, subvalue in value.items(): |
| attr_payload += self._add_attr(msg_format.attr_set, |
| subname, subvalue, sub_attrs) |
| else: |
| raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") |
| else: |
| raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') |
| |
| pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) |
| return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad |
| |
| def _decode_enum(self, raw, attr_spec): |
| enum = self.consts[attr_spec['enum']] |
| if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): |
| i = 0 |
| value = set() |
| while raw: |
| if raw & 1: |
| value.add(enum.entries_by_val[i].name) |
| raw >>= 1 |
| i += 1 |
| else: |
| value = enum.entries_by_val[raw].name |
| return value |
| |
| def _decode_binary(self, attr, attr_spec): |
| if attr_spec.struct_name: |
| decoded = self._decode_struct(attr.raw, attr_spec.struct_name) |
| elif attr_spec.sub_type: |
| decoded = attr.as_c_array(attr_spec.sub_type) |
| else: |
| decoded = attr.as_bin() |
| if attr_spec.display_hint: |
| decoded = self._formatted_string(decoded, attr_spec.display_hint) |
| return decoded |
| |
| def _decode_array_attr(self, attr, attr_spec): |
| decoded = [] |
| offset = 0 |
| while offset < len(attr.raw): |
| item = NlAttr(attr.raw, offset) |
| offset += item.full_len |
| |
| if attr_spec["sub-type"] == 'nest': |
| subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) |
| decoded.append({ item.type: subattrs }) |
| elif attr_spec["sub-type"] == 'binary': |
| subattrs = item.as_bin() |
| if attr_spec.display_hint: |
| subattrs = self._formatted_string(subattrs, attr_spec.display_hint) |
| decoded.append(subattrs) |
| elif attr_spec["sub-type"] in NlAttr.type_formats: |
| subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) |
| if attr_spec.display_hint: |
| subattrs = self._formatted_string(subattrs, attr_spec.display_hint) |
| decoded.append(subattrs) |
| else: |
| raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') |
| return decoded |
| |
| def _decode_nest_type_value(self, attr, attr_spec): |
| decoded = {} |
| value = attr |
| for name in attr_spec['type-value']: |
| value = NlAttr(value.raw, 0) |
| decoded[name] = value.type |
| subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) |
| decoded.update(subattrs) |
| return decoded |
| |
| def _decode_unknown(self, attr): |
| if attr.is_nest: |
| return self._decode(NlAttrs(attr.raw), None) |
| else: |
| return attr.as_bin() |
| |
| def _rsp_add(self, rsp, name, is_multi, decoded): |
| if is_multi == None: |
| if name in rsp and type(rsp[name]) is not list: |
| rsp[name] = [rsp[name]] |
| is_multi = True |
| else: |
| is_multi = False |
| |
| if not is_multi: |
| rsp[name] = decoded |
| elif name in rsp: |
| rsp[name].append(decoded) |
| else: |
| rsp[name] = [decoded] |
| |
| def _resolve_selector(self, attr_spec, search_attrs): |
| sub_msg = attr_spec.sub_message |
| if sub_msg not in self.sub_msgs: |
| raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") |
| sub_msg_spec = self.sub_msgs[sub_msg] |
| |
| selector = attr_spec.selector |
| value = search_attrs.lookup(selector) |
| if value not in sub_msg_spec.formats: |
| raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") |
| |
| spec = sub_msg_spec.formats[value] |
| return spec |
| |
| def _decode_sub_msg(self, attr, attr_spec, search_attrs): |
| msg_format = self._resolve_selector(attr_spec, search_attrs) |
| decoded = {} |
| offset = 0 |
| if msg_format.fixed_header: |
| decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); |
| offset = self._struct_size(msg_format.fixed_header) |
| if msg_format.attr_set: |
| if msg_format.attr_set in self.attr_sets: |
| subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) |
| decoded.update(subdict) |
| else: |
| raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") |
| return decoded |
| |
| def _decode(self, attrs, space, outer_attrs = None): |
| rsp = dict() |
| if space: |
| attr_space = self.attr_sets[space] |
| search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) |
| |
| for attr in attrs: |
| try: |
| attr_spec = attr_space.attrs_by_val[attr.type] |
| except (KeyError, UnboundLocalError): |
| if not self.process_unknown: |
| raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") |
| attr_name = f"UnknownAttr({attr.type})" |
| self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) |
| continue |
| |
| if attr_spec["type"] == 'nest': |
| subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) |
| decoded = subdict |
| elif attr_spec["type"] == 'string': |
| decoded = attr.as_strz() |
| elif attr_spec["type"] == 'binary': |
| decoded = self._decode_binary(attr, attr_spec) |
| elif attr_spec["type"] == 'flag': |
| decoded = True |
| elif attr_spec.is_auto_scalar: |
| decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) |
| elif attr_spec["type"] in NlAttr.type_formats: |
| decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) |
| if 'enum' in attr_spec: |
| decoded = self._decode_enum(decoded, attr_spec) |
| elif attr_spec.display_hint: |
| decoded = self._formatted_string(decoded, attr_spec.display_hint) |
| elif attr_spec["type"] == 'indexed-array': |
| decoded = self._decode_array_attr(attr, attr_spec) |
| elif attr_spec["type"] == 'bitfield32': |
| value, selector = struct.unpack("II", attr.raw) |
| if 'enum' in attr_spec: |
| value = self._decode_enum(value, attr_spec) |
| selector = self._decode_enum(selector, attr_spec) |
| decoded = {"value": value, "selector": selector} |
| elif attr_spec["type"] == 'sub-message': |
| decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) |
| elif attr_spec["type"] == 'nest-type-value': |
| decoded = self._decode_nest_type_value(attr, attr_spec) |
| else: |
| if not self.process_unknown: |
| raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') |
| decoded = self._decode_unknown(attr) |
| |
| self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) |
| |
| return rsp |
| |
| def _decode_extack_path(self, attrs, attr_set, offset, target): |
| for attr in attrs: |
| try: |
| attr_spec = attr_set.attrs_by_val[attr.type] |
| except KeyError: |
| raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") |
| if offset > target: |
| break |
| if offset == target: |
| return '.' + attr_spec.name |
| |
| if offset + attr.full_len <= target: |
| offset += attr.full_len |
| continue |
| if attr_spec['type'] != 'nest': |
| raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") |
| offset += 4 |
| subpath = self._decode_extack_path(NlAttrs(attr.raw), |
| self.attr_sets[attr_spec['nested-attributes']], |
| offset, target) |
| if subpath is None: |
| return None |
| return '.' + attr_spec.name + subpath |
| |
| return None |
| |
| def _decode_extack(self, request, op, extack): |
| if 'bad-attr-offs' not in extack: |
| return |
| |
| msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) |
| offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) |
| path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, |
| extack['bad-attr-offs']) |
| if path: |
| del extack['bad-attr-offs'] |
| extack['bad-attr'] = path |
| |
| def _struct_size(self, name): |
| if name: |
| members = self.consts[name].members |
| size = 0 |
| for m in members: |
| if m.type in ['pad', 'binary']: |
| if m.struct: |
| size += self._struct_size(m.struct) |
| else: |
| size += m.len |
| else: |
| format = NlAttr.get_format(m.type, m.byte_order) |
| size += format.size |
| return size |
| else: |
| return 0 |
| |
| def _decode_struct(self, data, name): |
| members = self.consts[name].members |
| attrs = dict() |
| offset = 0 |
| for m in members: |
| value = None |
| if m.type == 'pad': |
| offset += m.len |
| elif m.type == 'binary': |
| if m.struct: |
| len = self._struct_size(m.struct) |
| value = self._decode_struct(data[offset : offset + len], |
| m.struct) |
| offset += len |
| else: |
| value = data[offset : offset + m.len] |
| offset += m.len |
| else: |
| format = NlAttr.get_format(m.type, m.byte_order) |
| [ value ] = format.unpack_from(data, offset) |
| offset += format.size |
| if value is not None: |
| if m.enum: |
| value = self._decode_enum(value, m) |
| elif m.display_hint: |
| value = self._formatted_string(value, m.display_hint) |
| attrs[m.name] = value |
| return attrs |
| |
| def _encode_struct(self, name, vals): |
| members = self.consts[name].members |
| attr_payload = b'' |
| for m in members: |
| value = vals.pop(m.name) if m.name in vals else None |
| if m.type == 'pad': |
| attr_payload += bytearray(m.len) |
| elif m.type == 'binary': |
| if m.struct: |
| if value is None: |
| value = dict() |
| attr_payload += self._encode_struct(m.struct, value) |
| else: |
| if value is None: |
| attr_payload += bytearray(m.len) |
| else: |
| attr_payload += bytes.fromhex(value) |
| else: |
| if value is None: |
| value = 0 |
| format = NlAttr.get_format(m.type, m.byte_order) |
| attr_payload += format.pack(value) |
| return attr_payload |
| |
| def _formatted_string(self, raw, display_hint): |
| if display_hint == 'mac': |
| formatted = ':'.join('%02x' % b for b in raw) |
| elif display_hint == 'hex': |
| if isinstance(raw, int): |
| formatted = hex(raw) |
| else: |
| formatted = bytes.hex(raw, ' ') |
| elif display_hint in [ 'ipv4', 'ipv6' ]: |
| formatted = format(ipaddress.ip_address(raw)) |
| elif display_hint == 'uuid': |
| formatted = str(uuid.UUID(bytes=raw)) |
| else: |
| formatted = raw |
| return formatted |
| |
| def handle_ntf(self, decoded): |
| msg = dict() |
| if self.include_raw: |
| msg['raw'] = decoded |
| op = self.rsp_by_value[decoded.cmd()] |
| attrs = self._decode(decoded.raw_attrs, op.attr_set.name) |
| if op.fixed_header: |
| attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) |
| |
| msg['name'] = op['name'] |
| msg['msg'] = attrs |
| self.async_msg_queue.append(msg) |
| |
| def check_ntf(self): |
| while True: |
| try: |
| reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) |
| except BlockingIOError: |
| return |
| |
| nms = NlMsgs(reply) |
| self._recv_dbg_print(reply, nms) |
| for nl_msg in nms: |
| if nl_msg.error: |
| print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) |
| print(nl_msg) |
| continue |
| if nl_msg.done: |
| print("Netlink done while checking for ntf!?") |
| continue |
| |
| op = self.rsp_by_value[nl_msg.cmd()] |
| decoded = self.nlproto.decode(self, nl_msg, op) |
| if decoded.cmd() not in self.async_msg_ids: |
| print("Unexpected msg id done while checking for ntf", decoded) |
| continue |
| |
| self.handle_ntf(decoded) |
| |
| def operation_do_attributes(self, name): |
| """ |
| For a given operation name, find and return a supported |
| set of attributes (as a dict). |
| """ |
| op = self.find_operation(name) |
| if not op: |
| return None |
| |
| return op['do']['request']['attributes'].copy() |
| |
| def _encode_message(self, op, vals, flags, req_seq): |
| nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK |
| for flag in flags or []: |
| nl_flags |= flag |
| |
| msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) |
| if op.fixed_header: |
| msg += self._encode_struct(op.fixed_header, vals) |
| search_attrs = SpaceAttrs(op.attr_set, vals) |
| for name, value in vals.items(): |
| msg += self._add_attr(op.attr_set.name, name, value, search_attrs) |
| msg = _genl_msg_finalize(msg) |
| return msg |
| |
| def _ops(self, ops): |
| reqs_by_seq = {} |
| req_seq = random.randint(1024, 65535) |
| payload = b'' |
| for (method, vals, flags) in ops: |
| op = self.ops[method] |
| msg = self._encode_message(op, vals, flags, req_seq) |
| reqs_by_seq[req_seq] = (op, msg, flags) |
| payload += msg |
| req_seq += 1 |
| |
| self.sock.send(payload, 0) |
| |
| done = False |
| rsp = [] |
| op_rsp = [] |
| while not done: |
| reply = self.sock.recv(self._recv_size) |
| nms = NlMsgs(reply, attr_space=op.attr_set) |
| self._recv_dbg_print(reply, nms) |
| for nl_msg in nms: |
| if nl_msg.nl_seq in reqs_by_seq: |
| (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] |
| if nl_msg.extack: |
| self._decode_extack(req_msg, op, nl_msg.extack) |
| else: |
| op = self.rsp_by_value[nl_msg.cmd()] |
| req_flags = [] |
| |
| if nl_msg.error: |
| raise NlError(nl_msg) |
| if nl_msg.done: |
| if nl_msg.extack: |
| print("Netlink warning:") |
| print(nl_msg) |
| |
| if Netlink.NLM_F_DUMP in req_flags: |
| rsp.append(op_rsp) |
| elif not op_rsp: |
| rsp.append(None) |
| elif len(op_rsp) == 1: |
| rsp.append(op_rsp[0]) |
| else: |
| rsp.append(op_rsp) |
| op_rsp = [] |
| |
| del reqs_by_seq[nl_msg.nl_seq] |
| done = len(reqs_by_seq) == 0 |
| break |
| |
| decoded = self.nlproto.decode(self, nl_msg, op) |
| |
| # Check if this is a reply to our request |
| if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: |
| if decoded.cmd() in self.async_msg_ids: |
| self.handle_ntf(decoded) |
| continue |
| else: |
| print('Unexpected message: ' + repr(decoded)) |
| continue |
| |
| rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) |
| if op.fixed_header: |
| rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) |
| op_rsp.append(rsp_msg) |
| |
| return rsp |
| |
| def _op(self, method, vals, flags=None, dump=False): |
| req_flags = flags or [] |
| if dump: |
| req_flags.append(Netlink.NLM_F_DUMP) |
| |
| ops = [(method, vals, req_flags)] |
| return self._ops(ops)[0] |
| |
| def do(self, method, vals, flags=None): |
| return self._op(method, vals, flags) |
| |
| def dump(self, method, vals): |
| return self._op(method, vals, dump=True) |
| |
| def do_multi(self, ops): |
| return self._ops(ops) |