summaryrefslogtreecommitdiff
path: root/tools/net/ynl/lib/ynl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/net/ynl/lib/ynl.py')
-rw-r--r--tools/net/ynl/lib/ynl.py711
1 files changed, 563 insertions, 148 deletions
diff --git a/tools/net/ynl/lib/ynl.py b/tools/net/ynl/lib/ynl.py
index 3ca28d4bcb18..35e666928119 100644
--- a/tools/net/ynl/lib/ynl.py
+++ b/tools/net/ynl/lib/ynl.py
@@ -1,12 +1,14 @@
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
from collections import namedtuple
+from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
+import sys
import yaml
import ipaddress
import uuid
@@ -25,6 +27,7 @@ class Netlink:
NETLINK_ADD_MEMBERSHIP = 1
NETLINK_CAP_ACK = 10
NETLINK_EXT_ACK = 11
+ NETLINK_GET_STRICT_CHK = 12
# Netlink message
NLMSG_ERROR = 2
@@ -34,6 +37,10 @@ class Netlink:
NLM_F_ACK = 4
NLM_F_ROOT = 0x100
NLM_F_MATCH = 0x200
+
+ NLM_F_REPLACE = 0x100
+ NLM_F_EXCL = 0x200
+ NLM_F_CREATE = 0x400
NLM_F_APPEND = 0x800
NLM_F_CAPPED = 0x100
@@ -70,13 +77,37 @@ class Netlink:
NLMSGERR_ATTR_MISS_TYPE = 5
NLMSGERR_ATTR_MISS_NEST = 6
+ # Policy types
+ NL_POLICY_TYPE_ATTR_TYPE = 1
+ NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
+ NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
+ NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
+ NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
+ NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
+ NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
+ NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
+ NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
+ NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
+ NL_POLICY_TYPE_ATTR_PAD = 11
+ NL_POLICY_TYPE_ATTR_MASK = 12
+
+ AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
+ 's8', 's16', 's32', 's64',
+ 'binary', 'string', 'nul-string',
+ 'nested', 'nested-array',
+ 'bitfield32', 'sint', 'uint'])
class NlError(Exception):
def __init__(self, nl_msg):
self.nl_msg = nl_msg
+ self.error = -nl_msg.error
def __str__(self):
- return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
+ return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
+
+
+class ConfigError(Exception):
+ pass
class NlAttr:
@@ -93,11 +124,12 @@ class NlAttr:
}
def __init__(self, raw, offset):
- self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
+ self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
self.type = self._type & ~Netlink.NLA_TYPE_MASK
+ self.is_nest = self._type & Netlink.NLA_F_NESTED
self.payload_len = self._len
self.full_len = (self.payload_len + 3) & ~3
- self.raw = raw[offset + 4:offset + self.payload_len]
+ self.raw = raw[offset + 4 : offset + self.payload_len]
@classmethod
def get_format(cls, attr_type, byte_order=None):
@@ -107,24 +139,17 @@ class NlAttr:
else format.little
return format.native
- @classmethod
- def formatted_string(cls, raw, display_hint):
- if display_hint == 'mac':
- formatted = ':'.join('%02x' % b for b in raw)
- elif display_hint == 'hex':
- formatted = bytes.hex(raw, ' ')
- elif display_hint in [ 'ipv4', 'ipv6' ]:
- formatted = format(ipaddress.ip_address(raw))
- elif display_hint == 'uuid':
- formatted = str(uuid.UUID(bytes=raw))
- else:
- formatted = raw
- return formatted
-
def as_scalar(self, attr_type, byte_order=None):
format = self.get_format(attr_type, byte_order)
return format.unpack(self.raw)[0]
+ def as_auto_scalar(self, attr_type, byte_order=None):
+ if len(self.raw) != 4 and len(self.raw) != 8:
+ raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
+ real_type = attr_type[0] + str(len(self.raw) * 8)
+ format = self.get_format(real_type, byte_order)
+ return format.unpack(self.raw)[0]
+
def as_strz(self):
return self.raw.decode('ascii')[:-1]
@@ -135,32 +160,14 @@ class NlAttr:
format = self.get_format(type)
return [ x[0] for x in format.iter_unpack(self.raw) ]
- def as_struct(self, members):
- value = dict()
- offset = 0
- for m in members:
- # TODO: handle non-scalar members
- if m.type == 'binary':
- decoded = self.raw[offset:offset+m['len']]
- offset += m['len']
- elif m.type in NlAttr.type_formats:
- format = self.get_format(m.type, m.byte_order)
- [ decoded ] = format.unpack_from(self.raw, offset)
- offset += format.size
- if m.display_hint:
- decoded = self.formatted_string(decoded, m.display_hint)
- value[m.name] = decoded
- return value
-
def __repr__(self):
return f"[type:{self.type} len:{self._len}] {self.raw}"
class NlAttrs:
- def __init__(self, msg):
+ def __init__(self, msg, offset=0):
self.attrs = []
- offset = 0
while offset < len(msg):
attr = NlAttr(msg, offset)
offset += attr.full_len
@@ -180,12 +187,12 @@ class NlAttrs:
class NlMsg:
def __init__(self, msg, offset, attr_space=None):
- self.hdr = msg[offset:offset + 16]
+ self.hdr = msg[offset : offset + 16]
self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
struct.unpack("IHHII", self.hdr)
- self.raw = msg[offset + 16:offset + self.nl_len]
+ self.raw = msg[offset + 16 : offset + self.nl_len]
self.error = 0
self.done = 0
@@ -196,6 +203,7 @@ class NlMsg:
self.done = 1
extack_off = 20
elif self.nl_type == Netlink.NLMSG_DONE:
+ self.error = struct.unpack("i", self.raw[0:4])[0]
self.done = 1
extack_off = 4
@@ -212,6 +220,8 @@ class NlMsg:
self.extack['miss-nest'] = extack.as_scalar('u32')
elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
self.extack['bad-attr-offs'] = extack.as_scalar('u32')
+ elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
+ self.extack['policy'] = self._decode_policy(extack.raw)
else:
if 'unknown' not in self.extack:
self.extack['unknown'] = []
@@ -223,17 +233,43 @@ class NlMsg:
miss_type = self.extack['miss-type']
if miss_type in attr_space.attrs_by_val:
spec = attr_space.attrs_by_val[miss_type]
- desc = spec['name']
+ self.extack['miss-type'] = spec['name']
if 'doc' in spec:
- desc += f" ({spec['doc']})"
- self.extack['miss-type'] = desc
+ self.extack['miss-type-doc'] = spec['doc']
+
+ def _decode_policy(self, raw):
+ policy = {}
+ for attr in NlAttrs(raw):
+ if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
+ type = attr.as_scalar('u32')
+ policy['type'] = Netlink.AttrType(type).name
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
+ policy['min-value'] = attr.as_scalar('s64')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
+ policy['max-value'] = attr.as_scalar('s64')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
+ policy['min-value'] = attr.as_scalar('u64')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
+ policy['max-value'] = attr.as_scalar('u64')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
+ policy['min-length'] = attr.as_scalar('u32')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
+ policy['max-length'] = attr.as_scalar('u32')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
+ policy['bitfield32-mask'] = attr.as_scalar('u32')
+ elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
+ policy['mask'] = attr.as_scalar('u64')
+ return policy
+
+ def cmd(self):
+ return self.nl_type
def __repr__(self):
- msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
+ msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
if self.error:
- msg += '\terror: ' + str(self.error)
+ msg += '\n\terror: ' + str(self.error)
if self.extack:
- msg += '\textack: ' + repr(self.extack)
+ msg += '\n\textack: ' + repr(self.extack)
return msg
@@ -293,7 +329,7 @@ def _genl_load_families():
gm = GenlMsg(nl_msg)
fam = dict()
- for attr in gm.raw_attrs:
+ for attr in NlAttrs(gm.raw):
if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
fam['id'] = attr.as_scalar('u16')
elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
@@ -317,23 +353,13 @@ def _genl_load_families():
class GenlMsg:
- def __init__(self, nl_msg, fixed_header_members=[]):
+ def __init__(self, nl_msg):
self.nl = nl_msg
+ self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
+ self.raw = nl_msg.raw[4:]
- self.hdr = nl_msg.raw[0:4]
- offset = 4
-
- self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)
-
- self.fixed_header_attrs = dict()
- for m in fixed_header_members:
- format = NlAttr.get_format(m.type, m.byte_order)
- decoded = format.unpack_from(nl_msg.raw, offset)
- offset += format.size
- self.fixed_header_attrs[m.name] = decoded[0]
-
- self.raw = nl_msg.raw[offset:]
- self.raw_attrs = NlAttrs(self.raw)
+ def cmd(self):
+ return self.genl_cmd
def __repr__(self):
msg = repr(self.nl)
@@ -343,9 +369,41 @@ class GenlMsg:
return msg
-class GenlFamily:
- def __init__(self, family_name):
+class NetlinkProtocol:
+ def __init__(self, family_name, proto_num):
self.family_name = family_name
+ self.proto_num = proto_num
+
+ def _message(self, nl_type, nl_flags, seq=None):
+ if seq is None:
+ seq = random.randint(1, 1024)
+ nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
+ return nlmsg
+
+ def message(self, flags, command, version, seq=None):
+ return self._message(command, flags, seq)
+
+ def _decode(self, nl_msg):
+ return nl_msg
+
+ def decode(self, ynl, nl_msg, op):
+ msg = self._decode(nl_msg)
+ fixed_header_size = ynl._struct_size(op.fixed_header)
+ msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
+ return msg
+
+ def get_mcast_id(self, mcast_name, mcast_groups):
+ if mcast_name not in mcast_groups:
+ raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
+ return mcast_groups[mcast_name].value
+
+ def msghdr_size(self):
+ return 16
+
+
+class GenlProtocol(NetlinkProtocol):
+ def __init__(self, family_name):
+ super().__init__(family_name, Netlink.NETLINK_GENERIC)
global genl_family_name_to_id
if genl_family_name_to_id is None:
@@ -354,6 +412,41 @@ class GenlFamily:
self.genl_family = genl_family_name_to_id[family_name]
self.family_id = genl_family_name_to_id[family_name]['id']
+ def message(self, flags, command, version, seq=None):
+ nlmsg = self._message(self.family_id, flags, seq)
+ genlmsg = struct.pack("BBH", command, version, 0)
+ return nlmsg + genlmsg
+
+ def _decode(self, nl_msg):
+ return GenlMsg(nl_msg)
+
+ def get_mcast_id(self, mcast_name, mcast_groups):
+ if mcast_name not in self.genl_family['mcast']:
+ raise Exception(f'Multicast group "{mcast_name}" not present in the family')
+ return self.genl_family['mcast'][mcast_name]
+
+ def msghdr_size(self):
+ return super().msghdr_size() + 4
+
+
+class SpaceAttrs:
+ SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])
+
+ def __init__(self, attr_space, attrs, outer = None):
+ outer_scopes = outer.scopes if outer else []
+ inner_scope = self.SpecValuesPair(attr_space, attrs)
+ self.scopes = [inner_scope] + outer_scopes
+
+ def lookup(self, name):
+ for scope in self.scopes:
+ if name in scope.spec:
+ if name in scope.values:
+ return scope.values[name]
+ spec_name = scope.spec.yaml['name']
+ raise Exception(
+ f"No value for '{name}' in attribute space '{spec_name}'")
+ raise Exception(f"Attribute '{name}' not defined in any attribute-set")
+
#
# YNL implementation details.
@@ -361,14 +454,37 @@ class GenlFamily:
class YnlFamily(SpecFamily):
- def __init__(self, def_path, schema=None):
+ def __init__(self, def_path, schema=None, process_unknown=False,
+ recv_size=0):
super().__init__(def_path, schema)
self.include_raw = False
+ self.process_unknown = process_unknown
- self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
+ try:
+ if self.proto == "netlink-raw":
+ self.nlproto = NetlinkProtocol(self.yaml['name'],
+ self.yaml['protonum'])
+ else:
+ self.nlproto = GenlProtocol(self.yaml['name'])
+ except KeyError:
+ raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
+
+ self._recv_dbg = False
+ # Note that netlink will use conservative (min) message size for
+ # the first dump recv() on the socket, our setting will only matter
+ # from the second recv() on.
+ self._recv_size = recv_size if recv_size else 131072
+ # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
+ # for a message, so smaller receive sizes will lead to truncation.
+ # Note that the min size for other families may be larger than 4k!
+ if self._recv_size < 4000:
+ raise ConfigError()
+
+ self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
+ self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
self.async_msg_ids = set()
self.async_msg_queue = []
@@ -381,36 +497,106 @@ class YnlFamily(SpecFamily):
bound_f = functools.partial(self._op, op_name)
setattr(self, op.ident_name, bound_f)
- try:
- self.family = GenlFamily(self.yaml['name'])
- except KeyError:
- raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
def ntf_subscribe(self, mcast_name):
- if mcast_name not in self.family.genl_family['mcast']:
- raise Exception(f'Multicast group "{mcast_name}" not present in the family')
-
+ mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
self.sock.bind((0, 0))
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
- self.family.genl_family['mcast'][mcast_name])
+ mcast_id)
- def _add_attr(self, space, name, value):
- attr = self.attr_sets[space][name]
+ def set_recv_dbg(self, enabled):
+ self._recv_dbg = enabled
+
+ def _recv_dbg_print(self, reply, nl_msgs):
+ if not self._recv_dbg:
+ return
+ print("Recv: read", len(reply), "bytes,",
+ len(nl_msgs.msgs), "messages", file=sys.stderr)
+ for nl_msg in nl_msgs:
+ print(" ", nl_msg, file=sys.stderr)
+
+ def _encode_enum(self, attr_spec, value):
+ enum = self.consts[attr_spec['enum']]
+ if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
+ scalar = 0
+ if isinstance(value, str):
+ value = [value]
+ for single_value in value:
+ scalar += enum.entries[single_value].user_value(as_flags = True)
+ return scalar
+ else:
+ return enum.entries[value].user_value()
+
+ def _get_scalar(self, attr_spec, value):
+ try:
+ return int(value)
+ except (ValueError, TypeError) as e:
+ if 'enum' not in attr_spec:
+ raise e
+ return self._encode_enum(attr_spec, value)
+
+ def _add_attr(self, space, name, value, search_attrs):
+ try:
+ attr = self.attr_sets[space][name]
+ except KeyError:
+ raise Exception(f"Space '{space}' has no attribute '{name}'")
nl_type = attr.value
+
+ if attr.is_multi and isinstance(value, list):
+ attr_payload = b''
+ for subvalue in value:
+ attr_payload += self._add_attr(space, name, subvalue, search_attrs)
+ return attr_payload
+
if attr["type"] == 'nest':
nl_type |= Netlink.NLA_F_NESTED
attr_payload = b''
+ sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
for subname, subvalue in value.items():
- attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
+ attr_payload += self._add_attr(attr['nested-attributes'],
+ subname, subvalue, sub_attrs)
elif attr["type"] == 'flag':
+ if not value:
+ # If value is absent or false then skip attribute creation.
+ return b''
attr_payload = b''
elif attr["type"] == 'string':
attr_payload = str(value).encode('ascii') + b'\x00'
elif attr["type"] == 'binary':
- attr_payload = bytes.fromhex(value)
- elif attr['type'] in NlAttr.type_formats:
- format = NlAttr.get_format(attr['type'], attr.byte_order)
- attr_payload = format.pack(int(value))
+ if isinstance(value, bytes):
+ attr_payload = value
+ elif isinstance(value, str):
+ attr_payload = bytes.fromhex(value)
+ elif isinstance(value, dict) and attr.struct_name:
+ attr_payload = self._encode_struct(attr.struct_name, value)
+ else:
+ raise Exception(f'Unknown type for binary attribute, value: {value}')
+ elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
+ scalar = self._get_scalar(attr, value)
+ if attr.is_auto_scalar:
+ attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
+ else:
+ attr_type = attr["type"]
+ format = NlAttr.get_format(attr_type, attr.byte_order)
+ attr_payload = format.pack(scalar)
+ elif attr['type'] in "bitfield32":
+ scalar_value = self._get_scalar(attr, value["value"])
+ scalar_selector = self._get_scalar(attr, value["selector"])
+ attr_payload = struct.pack("II", scalar_value, scalar_selector)
+ elif attr['type'] == 'sub-message':
+ msg_format = self._resolve_selector(attr, search_attrs)
+ attr_payload = b''
+ if msg_format.fixed_header:
+ attr_payload += self._encode_struct(msg_format.fixed_header, value)
+ if msg_format.attr_set:
+ if msg_format.attr_set in self.attr_sets:
+ nl_type |= Netlink.NLA_F_NESTED
+ sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
+ for subname, subvalue in value.items():
+ attr_payload += self._add_attr(msg_format.attr_set,
+ subname, subvalue, sub_attrs)
+ else:
+ raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
else:
raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
@@ -419,7 +605,7 @@ class YnlFamily(SpecFamily):
def _decode_enum(self, raw, attr_spec):
enum = self.consts[attr_spec['enum']]
- if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
+ if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
i = 0
value = set()
while raw:
@@ -433,26 +619,117 @@ class YnlFamily(SpecFamily):
def _decode_binary(self, attr, attr_spec):
if attr_spec.struct_name:
- members = self.consts[attr_spec.struct_name]
- decoded = attr.as_struct(members)
- for m in members:
- if m.enum:
- decoded[m.name] = self._decode_enum(decoded[m.name], m)
+ decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
elif attr_spec.sub_type:
decoded = attr.as_c_array(attr_spec.sub_type)
else:
decoded = attr.as_bin()
if attr_spec.display_hint:
- decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
+ decoded = self._formatted_string(decoded, attr_spec.display_hint)
return decoded
- def _decode(self, attrs, space):
- attr_space = self.attr_sets[space]
+ def _decode_array_attr(self, attr, attr_spec):
+ decoded = []
+ offset = 0
+ while offset < len(attr.raw):
+ item = NlAttr(attr.raw, offset)
+ offset += item.full_len
+
+ if attr_spec["sub-type"] == 'nest':
+ subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
+ decoded.append({ item.type: subattrs })
+ elif attr_spec["sub-type"] == 'binary':
+ subattrs = item.as_bin()
+ if attr_spec.display_hint:
+ subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
+ decoded.append(subattrs)
+ elif attr_spec["sub-type"] in NlAttr.type_formats:
+ subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
+ if attr_spec.display_hint:
+ subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
+ decoded.append(subattrs)
+ else:
+ raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
+ return decoded
+
+ def _decode_nest_type_value(self, attr, attr_spec):
+ decoded = {}
+ value = attr
+ for name in attr_spec['type-value']:
+ value = NlAttr(value.raw, 0)
+ decoded[name] = value.type
+ subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
+ decoded.update(subattrs)
+ return decoded
+
+ def _decode_unknown(self, attr):
+ if attr.is_nest:
+ return self._decode(NlAttrs(attr.raw), None)
+ else:
+ return attr.as_bin()
+
+ def _rsp_add(self, rsp, name, is_multi, decoded):
+ if is_multi == None:
+ if name in rsp and type(rsp[name]) is not list:
+ rsp[name] = [rsp[name]]
+ is_multi = True
+ else:
+ is_multi = False
+
+ if not is_multi:
+ rsp[name] = decoded
+ elif name in rsp:
+ rsp[name].append(decoded)
+ else:
+ rsp[name] = [decoded]
+
+ def _resolve_selector(self, attr_spec, search_attrs):
+ sub_msg = attr_spec.sub_message
+ if sub_msg not in self.sub_msgs:
+ raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
+ sub_msg_spec = self.sub_msgs[sub_msg]
+
+ selector = attr_spec.selector
+ value = search_attrs.lookup(selector)
+ if value not in sub_msg_spec.formats:
+ raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
+
+ spec = sub_msg_spec.formats[value]
+ return spec
+
+ def _decode_sub_msg(self, attr, attr_spec, search_attrs):
+ msg_format = self._resolve_selector(attr_spec, search_attrs)
+ decoded = {}
+ offset = 0
+ if msg_format.fixed_header:
+ decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
+ offset = self._struct_size(msg_format.fixed_header)
+ if msg_format.attr_set:
+ if msg_format.attr_set in self.attr_sets:
+ subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
+ decoded.update(subdict)
+ else:
+ raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
+ return decoded
+
+ def _decode(self, attrs, space, outer_attrs = None):
rsp = dict()
+ if space:
+ attr_space = self.attr_sets[space]
+ search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
+
for attr in attrs:
- attr_spec = attr_space.attrs_by_val[attr.type]
+ try:
+ attr_spec = attr_space.attrs_by_val[attr.type]
+ except (KeyError, UnboundLocalError):
+ if not self.process_unknown:
+ raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
+ attr_name = f"UnknownAttr({attr.type})"
+ self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
+ continue
+
if attr_spec["type"] == 'nest':
- subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
+ subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
decoded = subdict
elif attr_spec["type"] == 'string':
decoded = attr.as_strz()
@@ -460,26 +737,39 @@ class YnlFamily(SpecFamily):
decoded = self._decode_binary(attr, attr_spec)
elif attr_spec["type"] == 'flag':
decoded = True
+ elif attr_spec.is_auto_scalar:
+ decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
elif attr_spec["type"] in NlAttr.type_formats:
decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
+ if 'enum' in attr_spec:
+ decoded = self._decode_enum(decoded, attr_spec)
+ elif attr_spec["type"] == 'indexed-array':
+ decoded = self._decode_array_attr(attr, attr_spec)
+ elif attr_spec["type"] == 'bitfield32':
+ value, selector = struct.unpack("II", attr.raw)
+ if 'enum' in attr_spec:
+ value = self._decode_enum(value, attr_spec)
+ selector = self._decode_enum(selector, attr_spec)
+ decoded = {"value": value, "selector": selector}
+ elif attr_spec["type"] == 'sub-message':
+ decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
+ elif attr_spec["type"] == 'nest-type-value':
+ decoded = self._decode_nest_type_value(attr, attr_spec)
else:
- raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
-
- if 'enum' in attr_spec:
- decoded = self._decode_enum(decoded, attr_spec)
+ if not self.process_unknown:
+ raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
+ decoded = self._decode_unknown(attr)
- if not attr_spec.is_multi:
- rsp[attr_spec['name']] = decoded
- elif attr_spec.name in rsp:
- rsp[attr_spec.name].append(decoded)
- else:
- rsp[attr_spec.name] = [decoded]
+ self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
return rsp
def _decode_extack_path(self, attrs, attr_set, offset, target):
for attr in attrs:
- attr_spec = attr_set.attrs_by_val[attr.type]
+ try:
+ attr_spec = attr_set.attrs_by_val[attr.type]
+ except KeyError:
+ raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
if offset > target:
break
if offset == target:
@@ -500,35 +790,126 @@ class YnlFamily(SpecFamily):
return None
- def _decode_extack(self, request, attr_space, extack):
+ def _decode_extack(self, request, op, extack):
if 'bad-attr-offs' not in extack:
return
- genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
- path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
- 20, extack['bad-attr-offs'])
+ msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
+ offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
+ path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
+ extack['bad-attr-offs'])
if path:
del extack['bad-attr-offs']
extack['bad-attr'] = path
- def handle_ntf(self, nl_msg, genl_msg):
+ def _struct_size(self, name):
+ if name:
+ members = self.consts[name].members
+ size = 0
+ for m in members:
+ if m.type in ['pad', 'binary']:
+ if m.struct:
+ size += self._struct_size(m.struct)
+ else:
+ size += m.len
+ else:
+ format = NlAttr.get_format(m.type, m.byte_order)
+ size += format.size
+ return size
+ else:
+ return 0
+
+ def _decode_struct(self, data, name):
+ members = self.consts[name].members
+ attrs = dict()
+ offset = 0
+ for m in members:
+ value = None
+ if m.type == 'pad':
+ offset += m.len
+ elif m.type == 'binary':
+ if m.struct:
+ len = self._struct_size(m.struct)
+ value = self._decode_struct(data[offset : offset + len],
+ m.struct)
+ offset += len
+ else:
+ value = data[offset : offset + m.len]
+ offset += m.len
+ else:
+ format = NlAttr.get_format(m.type, m.byte_order)
+ [ value ] = format.unpack_from(data, offset)
+ offset += format.size
+ if value is not None:
+ if m.enum:
+ value = self._decode_enum(value, m)
+ elif m.display_hint:
+ value = self._formatted_string(value, m.display_hint)
+ attrs[m.name] = value
+ return attrs
+
+ def _encode_struct(self, name, vals):
+ members = self.consts[name].members
+ attr_payload = b''
+ for m in members:
+ value = vals.pop(m.name) if m.name in vals else None
+ if m.type == 'pad':
+ attr_payload += bytearray(m.len)
+ elif m.type == 'binary':
+ if m.struct:
+ if value is None:
+ value = dict()
+ attr_payload += self._encode_struct(m.struct, value)
+ else:
+ if value is None:
+ attr_payload += bytearray(m.len)
+ else:
+ attr_payload += bytes.fromhex(value)
+ else:
+ if value is None:
+ value = 0
+ format = NlAttr.get_format(m.type, m.byte_order)
+ attr_payload += format.pack(value)
+ return attr_payload
+
+ def _formatted_string(self, raw, display_hint):
+ if display_hint == 'mac':
+ formatted = ':'.join('%02x' % b for b in raw)
+ elif display_hint == 'hex':
+ if isinstance(raw, int):
+ formatted = hex(raw)
+ else:
+ formatted = bytes.hex(raw, ' ')
+ elif display_hint in [ 'ipv4', 'ipv6' ]:
+ formatted = format(ipaddress.ip_address(raw))
+ elif display_hint == 'uuid':
+ formatted = str(uuid.UUID(bytes=raw))
+ else:
+ formatted = raw
+ return formatted
+
+ def handle_ntf(self, decoded):
msg = dict()
if self.include_raw:
- msg['nlmsg'] = nl_msg
- msg['genlmsg'] = genl_msg
- op = self.rsp_by_value[genl_msg.genl_cmd]
+ msg['raw'] = decoded
+ op = self.rsp_by_value[decoded.cmd()]
+ attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
+ if op.fixed_header:
+ attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
+
msg['name'] = op['name']
- msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
+ msg['msg'] = attrs
self.async_msg_queue.append(msg)
def check_ntf(self):
while True:
try:
- reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
+ reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
except BlockingIOError:
return
nms = NlMsgs(reply)
+ self._recv_dbg_print(reply, nms)
for nl_msg in nms:
if nl_msg.error:
print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
@@ -538,12 +919,13 @@ class YnlFamily(SpecFamily):
print("Netlink done while checking for ntf!?")
continue
- gm = GenlMsg(nl_msg)
- if gm.genl_cmd not in self.async_msg_ids:
- print("Unexpected msg id done while checking for ntf", gm)
+ op = self.rsp_by_value[nl_msg.cmd()]
+ decoded = self.nlproto.decode(self, nl_msg, op)
+ if decoded.cmd() not in self.async_msg_ids:
+ print("Unexpected msg id done while checking for ntf", decoded)
continue
- self.handle_ntf(nl_msg, gm)
+ self.handle_ntf(decoded)
def operation_do_attributes(self, name):
"""
@@ -556,36 +938,48 @@ class YnlFamily(SpecFamily):
return op['do']['request']['attributes'].copy()
- def _op(self, method, vals, dump=False):
- op = self.ops[method]
-
+ def _encode_message(self, op, vals, flags, req_seq):
nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
- if dump:
- nl_flags |= Netlink.NLM_F_DUMP
+ for flag in flags or []:
+ nl_flags |= flag
- req_seq = random.randint(1024, 65535)
- msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
- fixed_header_members = []
+ msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
if op.fixed_header:
- fixed_header_members = self.consts[op.fixed_header].members
- for m in fixed_header_members:
- value = vals.pop(m.name) if m.name in vals else 0
- format = NlAttr.get_format(m.type, m.byte_order)
- msg += format.pack(value)
+ msg += self._encode_struct(op.fixed_header, vals)
+ search_attrs = SpaceAttrs(op.attr_set, vals)
for name, value in vals.items():
- msg += self._add_attr(op.attr_set.name, name, value)
+ msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
msg = _genl_msg_finalize(msg)
+ return msg
- self.sock.send(msg, 0)
+ def _ops(self, ops):
+ reqs_by_seq = {}
+ req_seq = random.randint(1024, 65535)
+ payload = b''
+ for (method, vals, flags) in ops:
+ op = self.ops[method]
+ msg = self._encode_message(op, vals, flags, req_seq)
+ reqs_by_seq[req_seq] = (op, msg, flags)
+ payload += msg
+ req_seq += 1
+
+ self.sock.send(payload, 0)
done = False
rsp = []
+ op_rsp = []
while not done:
- reply = self.sock.recv(128 * 1024)
+ reply = self.sock.recv(self._recv_size)
nms = NlMsgs(reply, attr_space=op.attr_set)
+ self._recv_dbg_print(reply, nms)
for nl_msg in nms:
- if nl_msg.extack:
- self._decode_extack(msg, op.attr_set, nl_msg.extack)
+ if nl_msg.nl_seq in reqs_by_seq:
+ (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
+ if nl_msg.extack:
+ self._decode_extack(req_msg, op, nl_msg.extack)
+ else:
+ op = self.rsp_by_value[nl_msg.cmd()]
+ req_flags = []
if nl_msg.error:
raise NlError(nl_msg)
@@ -593,31 +987,52 @@ class YnlFamily(SpecFamily):
if nl_msg.extack:
print("Netlink warning:")
print(nl_msg)
- done = True
+
+ if Netlink.NLM_F_DUMP in req_flags:
+ rsp.append(op_rsp)
+ elif not op_rsp:
+ rsp.append(None)
+ elif len(op_rsp) == 1:
+ rsp.append(op_rsp[0])
+ else:
+ rsp.append(op_rsp)
+ op_rsp = []
+
+ del reqs_by_seq[nl_msg.nl_seq]
+ done = len(reqs_by_seq) == 0
break
- gm = GenlMsg(nl_msg, fixed_header_members)
+ decoded = self.nlproto.decode(self, nl_msg, op)
+
# Check if this is a reply to our request
- if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
- if gm.genl_cmd in self.async_msg_ids:
- self.handle_ntf(nl_msg, gm)
+ if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
+ if decoded.cmd() in self.async_msg_ids:
+ self.handle_ntf(decoded)
continue
else:
- print('Unexpected message: ' + repr(gm))
+ print('Unexpected message: ' + repr(decoded))
continue
- rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
- rsp_msg.update(gm.fixed_header_attrs)
- rsp.append(rsp_msg)
+ rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
+ if op.fixed_header:
+ rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
+ op_rsp.append(rsp_msg)
- if not rsp:
- return None
- if not dump and len(rsp) == 1:
- return rsp[0]
return rsp
- def do(self, method, vals):
- return self._op(method, vals)
+ def _op(self, method, vals, flags=None, dump=False):
+ req_flags = flags or []
+ if dump:
+ req_flags.append(Netlink.NLM_F_DUMP)
+
+ ops = [(method, vals, req_flags)]
+ return self._ops(ops)[0]
+
+ def do(self, method, vals, flags=None):
+ return self._op(method, vals, flags)
def dump(self, method, vals):
return self._op(method, vals, dump=True)
+
+ def do_multi(self, ops):
+ return self._ops(ops)