diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 969f5a2..5572017 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -264,7 +264,11 @@ def deserialize(raw: BytesIO) -> TraceSetParameterMap: for _ in range(number_of_entries): name = read_parameter_name(raw) value = TraceSetParameter.deserialize(raw) - result[name] = value + # Writing `result[name] = value` would cause the overridden `__setitem__` + # method in the `TraceParameterMap` to be called. That overridden method + # does additional type checking. There is no need to do type checking + # when deserializing. So invoke the base class method explicitly. + StringKeyOrderedDict.__setitem__(result, name, value) return result def serialize(self) -> bytes: @@ -474,7 +478,11 @@ def deserialize(raw: bytes, definitions: TraceParameterDefinitionMap) -> TracePa for key, val in definitions.items(): io_bytes.seek(val.offset) param = val.param_type.param_class.deserialize(io_bytes, val.length) - result[key] = param + # Writing `result[name] = value` would cause the overridden `__setitem__` + # method in the `TraceParameterMap` to be called. That overridden method + # does additional type checking. There is no need to do type checking + # when deserializing. So invoke the base class method explicitly. + StringKeyOrderedDict.__setitem__(result, key, param) return result def serialize(self) -> bytearray: diff --git a/trsfile/traceparameter.py b/trsfile/traceparameter.py index 51b5184..a2400da 100644 --- a/trsfile/traceparameter.py +++ b/trsfile/traceparameter.py @@ -37,16 +37,17 @@ def serialize(self) -> bytes: def _has_expected_type(value: Any) -> bool: pass - def __init__(self, value): - if type(value) is ndarray and len(value.shape) > 1: - warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" - "Information about dimensions of this ndarray will be lost.") - value = value.flatten() - if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): - raise ValueError('The value for a TraceParameter cannot be empty') - if not type(self)._has_expected_type(value): - raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' - f', but it has a type of {type(value)}') + def __init__(self, value, skip_validation=False): + if not skip_validation: + if type(value) is ndarray and len(value.shape) > 1: + warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" + "Information about dimensions of this ndarray will be lost.") + value = value.flatten() + if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): + raise ValueError('The value for a TraceParameter cannot be empty') + if not type(self)._has_expected_type(value): + raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' + f', but it has a type of {type(value)}') self.value = value def __len__(self): @@ -84,7 +85,7 @@ def __len__(self): def deserialize(io_bytes: BytesIO, param_length: int) -> BooleanArrayParameter: raw_values = io_bytes.read(ParameterType.BOOL.byte_size * param_length) param_value = [bool(x) for x in list(raw_values)] - return BooleanArrayParameter(param_value) + return BooleanArrayParameter(param_value, skip_validation=True) def serialize(self) -> bytes: out = bytearray() @@ -112,7 +113,7 @@ def __eq__(self, other): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int): param_value = list(io_bytes.read(ParameterType.BYTE.byte_size * param_length)) - return ByteArrayParameter(param_value) + return ByteArrayParameter(param_value, skip_validation=True) def __str__(self): return '0x' + bytes(self.value).hex().upper() if self.value else '' @@ -139,7 +140,7 @@ class DoubleArrayParameter(TraceParameter): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int) -> DoubleArrayParameter: param_value = [struct.unpack(' bytes: out = bytearray() @@ -162,7 +163,7 @@ class FloatArrayParameter(TraceParameter): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int) -> FloatArrayParameter: param_value = [struct.unpack(' bytes: out = bytearray() @@ -185,7 +186,7 @@ class IntegerArrayParameter(TraceParameter): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int) -> IntegerArrayParameter: param_value = [struct.unpack(' bytes: out = bytearray() @@ -208,7 +209,7 @@ class LongArrayParameter(TraceParameter): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int) -> LongArrayParameter: param_value = [struct.unpack(' bytes: out = bytearray() @@ -231,7 +232,7 @@ class ShortArrayParameter(TraceParameter): @staticmethod def deserialize(io_bytes: BytesIO, param_length: int) -> ShortArrayParameter: param_value = [struct.unpack(' bytes: out = bytearray() @@ -261,7 +262,7 @@ def __eq__(self, other): def deserialize(io_bytes: BytesIO, param_length: int) -> StringParameter: bytes_read = io_bytes.read(ParameterType.STRING.byte_size * param_length) param_value = bytes_read.decode(UTF_8) - return StringParameter(param_value) + return StringParameter(param_value, skip_validation=True) def serialize(self) -> bytes: out = bytearray()