From e6e617dbdc4c3151f6aa20ff8e336a13064f4e86 Mon Sep 17 00:00:00 2001 From: Alina Lenk Date: Thu, 2 May 2024 21:11:01 +0200 Subject: [PATCH 1/2] generate_packets.py: always pass packet names to field type code See RM #527 Signed-off-by: Alina Lenk --- common/generate_packets.py | 392 ++++++++++++++++++------------------- 1 file changed, 189 insertions(+), 203 deletions(-) diff --git a/common/generate_packets.py b/common/generate_packets.py index 1d7ade7eba..507c68b39c 100755 --- a/common/generate_packets.py +++ b/common/generate_packets.py @@ -511,30 +511,18 @@ class SizeInfo: return self.declared return f"{packet}->{self._actual}" - @property - def real(self) -> str: - """The number of elements to transmit. Either the same as the - declared size, or a field of `*real_packet`.""" - return self.actual_for("real_packet") - - @property - def old(self) -> str: - """The number of elements transmitted last time. Either the same as - the declared size, or a field of `*old`.""" - return self.actual_for("old") - - def size_check_get(self, field_name: str) -> str: + def size_check_get(self, field_name: str, packet: str) -> str: """Generate a code snippet checking whether the received size is in range when receiving a packet.""" if self.constant: return "" return f"""\ -if ({self.real} > {self.declared}) {{ +if ({self.actual_for(packet)} > {self.declared}) {{ RECEIVE_PACKET_FIELD_ERROR({field_name}, ": array truncated"); }} """ - def size_check_index(self, field_name: str) -> str: + def size_check_index(self, field_name: str, packet: str) -> str: """Generate a code snippet asserting that indices can be correctly transmitted for array-diff.""" if self.constant: @@ -543,10 +531,10 @@ FC_STATIC_ASSERT({self.declared} <= MAX_UINT16, packet_array_too_long_{field_nam """ else: return f"""\ -fc_assert({self.real} <= MAX_UINT16); +fc_assert({self.actual_for(packet)} <= MAX_UINT16); """ - def index_put(self, index: str) -> str: + def index_put(self, packet: str, index: str) -> str: """Generate a code snippet writing the given value to the network output, encoded as the appropriate index type""" if self.constant: @@ -559,14 +547,14 @@ e |= DIO_PUT(uint16, &dout, &field_addr, {index}); """ else: return f"""\ -if ({self.real} <= MAX_UINT8) {{ +if ({self.actual_for(packet)} <= MAX_UINT8) {{ e |= DIO_PUT(uint8, &dout, &field_addr, {index}); }} else {{ e |= DIO_PUT(uint16, &dout, &field_addr, {index}); }} """ - def index_get(self, location: Location) -> str: + def index_get(self, packet: str, location: Location) -> str: """Generate a code snippet reading the next index from the network input decoded as the correct type""" if self.constant: @@ -581,7 +569,7 @@ if (!DIO_GET(uint16, &din, &field_addr, &{location.index})) {{ """ else: return f"""\ -if (({self.real} <= MAX_UINT8) +if (({self.actual_for(packet)} <= MAX_UINT8) ? !DIO_GET(uint8, &din, &field_addr, &{location.index}) : !DIO_GET(uint16, &din, &field_addr, &{location.index})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); @@ -768,14 +756,14 @@ class FieldType(RawFieldType): See also self.get_code_handle_arg()""" raise NotImplementedError - def get_code_handle_arg(self, location: Location) -> str: + def get_code_handle_arg(self, location: Location, packet: str) -> str: """Generate a code fragment passing an argument with this type to a handle function. See also self.get_code_param()""" - return str(location) + return f"{packet}->{location}" - def get_code_init(self, location: Location) -> str: + def get_code_init(self, location: Location, packet: str) -> str: """Generate a code snippet initializing a field of this type in the packet struct, after the struct has already been zeroed. @@ -796,14 +784,14 @@ class FieldType(RawFieldType): {dest}->{location} = {src}->{location}; """ - def get_code_fill(self, location: Location) -> str: + def get_code_fill(self, location: Location, packet: str) -> str: """Generate a code snippet shallow-copying a value of this type from dsend arguments into a packet struct.""" return f"""\ -real_packet->{location} = {location}; +{packet}->{location} = {location}; """ - def get_code_free(self, location: Location) -> str: + def get_code_free(self, location: Location, packet: str) -> str: """Generate a code snippet deinitializing a field of this type in the packet struct before it gets destroyed. @@ -814,40 +802,41 @@ real_packet->{location} = {location}; return "" @abstractmethod - def get_code_hash(self, location: Location) -> str: + def get_code_hash(self, location: Location, packet: str) -> str: """Generate a code snippet factoring a field of this type into a hash computation's `result`.""" raise NotImplementedError @abstractmethod - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: """Generate a code snippet comparing a field of this type between - the `old` and `real_packet` and setting `differ` accordingly.""" + the given packets and setting `differ` accordingly. The `old` + packet is one we know to have been initialized by our own code.""" raise NotImplementedError @abstractmethod - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: """Generate a code snippet writing a field of this type to the dataio stream.""" raise NotImplementedError @abstractmethod - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: """Generate a code snippet reading a field of this type from the dataio stream.""" raise NotImplementedError - def _compat_keys(self, location: Location): + def _compat_keys(self, location: Location, packet: str): """Internal helper function. Yield keys to compare for type compatibility. See is_type_compatible()""" yield self.get_code_declaration(location) yield self.get_code_param(location) - yield self.get_code_handle_arg(location) - yield self.get_code_fill(location) + yield self.get_code_handle_arg(location, packet) + yield self.get_code_fill(location, packet) yield self.complex if self.complex: - yield self.get_code_init(location) - yield self.get_code_free(location) + yield self.get_code_init(location, packet) + yield self.get_code_free(location, packet) def is_type_compatible(self, other: "FieldType") -> bool: """Determine whether two field types can be used interchangeably as @@ -855,11 +844,12 @@ real_packet->{location} = {location}; if other is self: return True loc = Location("compat_test_field_name") + pak = "compat_test_packet_name" return all( a == b for a, b in zip_longest( - self._compat_keys(loc), - other._compat_keys(loc), + self._compat_keys(loc, pak), + other._compat_keys(loc, pak), ) ) @@ -885,22 +875,22 @@ class BasicType(FieldType): def get_code_param(self, location: Location) -> str: return f"{self.public_type} {location}" - def get_code_hash(self, location: Location) -> str: + def get_code_hash(self, location: Location, packet: str) -> str: raise ValueError(f"hash not supported for type {self} in field {location.name}") - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = (old->{location} != real_packet->{location}); +differ = ({old}->{location} != {new}->{location}); """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, real_packet->{location}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {packet}->{location}); """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, &real_packet->{location})) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{packet}->{location})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -933,15 +923,15 @@ class IntType(BasicType): super().__init__(dataio_type, public_type) - def get_code_hash(self, location: Location) -> str: + def get_code_hash(self, location: Location, packet: str) -> str: return f"""\ -result += key->{location}; +result += {packet}->{location}; """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: if self.public_type in ("int", "bool"): # read directly - return super().get_code_get(location, deep_diff) + return super().get_code_get(location, packet, deep_diff) # read indirectly to make sure coercions between different integer # types happen correctly return f"""\ @@ -951,7 +941,7 @@ result += key->{location}; if (!DIO_GET({self.dataio_type}, &din, &field_addr, &readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - real_packet->{location} = readin; + {packet}->{location} = readin; }} """ @@ -1023,19 +1013,19 @@ class FloatType(BasicType): super().__init__(dataio_type, public_type) self.float_factor = int(float_factor) - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = ((int) (old->{location} * {self.float_factor}) != (int) (real_packet->{location} * {self.float_factor})); +differ = ((int) ({old}->{location} * {self.float_factor}) != (int) ({new}->{location} * {self.float_factor})); """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, real_packet->{location}, {self.float_factor:d}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {packet}->{location}, {self.float_factor:d}); """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, &real_packet->{location}, {self.float_factor:d})) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{packet}->{location}, {self.float_factor:d})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1055,19 +1045,19 @@ class BitvectorType(BasicType): super().__init__(dataio_type, public_type) - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !BV_ARE_EQUAL(old->{location}, real_packet->{location}); +differ = !BV_ARE_EQUAL({old}->{location}, {new}->{location}); """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_BV_PUT(&dout, &field_addr, packet->{location}); +e |= DIO_BV_PUT(&dout, &field_addr, {packet}->{location}); """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_BV_GET(&din, &field_addr, real_packet->{location})) {{ +if (!DIO_BV_GET(&din, &field_addr, {packet}->{location})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1101,20 +1091,19 @@ class StructType(BasicType): return "const " + super().get_code_param(location.deeper(f"*{location}")) return super().get_code_param(location) - def get_code_handle_arg(self, location: Location) -> str: - if not location.depth: - # top level: pass by-reference - return super().get_code_handle_arg(location.deeper(f"&{location}")) - return super().get_code_handle_arg(location) + def get_code_handle_arg(self, location: Location, packet: str) -> str: + # top level: pass by-reference + prefix = "&" if not location.depth else "" + return prefix + super().get_code_handle_arg(location, packet) - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !are_{self.dataio_type}s_equal(&old->{location}, &real_packet->{location}); +differ = !are_{self.dataio_type}s_equal(&{old}->{location}, &{new}->{location}); """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &real_packet->{location}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{packet}->{location}); """ DEFAULT_REGISTRY.public_patterns[StructType.TYPE_PATTERN] = StructType @@ -1132,9 +1121,9 @@ class CmParameterType(StructType): super().__init__(dataio_type, public_type) - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !cm_are_parameter_equal(&old->{location}, &real_packet->{location}); +differ = !cm_are_parameter_equal(&{old}->{location}, &{new}->{location}); """ DEFAULT_REGISTRY.dataio_types["cm_parameter"] = CmParameterType @@ -1157,9 +1146,9 @@ class WorklistType(StructType): worklist_copy(&{dest}->{location}, &{src}->{location}); """ - def get_code_fill(self, location: Location) -> str: + def get_code_fill(self, location: Location, packet: str) -> str: return f"""\ -worklist_copy(&real_packet->{location}, {location}); +worklist_copy(&{packet}->{location}, {location}); """ DEFAULT_REGISTRY.dataio_types["worklist"] = WorklistType @@ -1186,8 +1175,8 @@ class SizedType(BasicType): return pre + super().get_code_param(location.deeper(f"*{location}")) @abstractmethod - def get_code_fill(self, location: Location) -> str: - return super().get_code_fill(location) + def get_code_fill(self, location: Location, packet: str) -> str: + return super().get_code_fill(location, packet) @abstractmethod def get_code_copy(self, location: Location, dest: str, src: str) -> str: @@ -1209,9 +1198,9 @@ class StringType(SizedType): super().__init__(dataio_type, public_type, size) - def get_code_fill(self, location: Location) -> str: + def get_code_fill(self, location: Location, packet: str) -> str: return f"""\ -sz_strlcpy(real_packet->{location}, {location}); +sz_strlcpy({packet}->{location}, {location}); """ def get_code_copy(self, location: Location, dest: str, src: str) -> str: @@ -1219,14 +1208,14 @@ sz_strlcpy(real_packet->{location}, {location}); sz_strlcpy({dest}->{location}, {src}->{location}); """ - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = (strcmp(old->{location}, real_packet->{location}) != 0); +differ = (strcmp({old}->{location}, {new}->{location}) != 0); """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, real_packet->{location}, sizeof(real_packet->{location}))) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, {packet}->{location}, sizeof({packet}->{location}))) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1241,7 +1230,7 @@ class MemoryType(SizedType): super().__init__(dataio_type, public_type, size) - def get_code_fill(self, location: Location) -> str: + def get_code_fill(self, location: Location, packet: str) -> str: raise NotImplementedError("fill not supported for memory-type fields") def get_code_copy(self, location: Location, dest: str, src: str) -> str: @@ -1249,25 +1238,25 @@ class MemoryType(SizedType): memcpy({dest}->{location}, {src}->{location}, {self.size.actual_for(src)}); """ - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: if self.size.constant: return f"""\ -differ = (memcmp(old->{location}, real_packet->{location}, {self.size.real}) != 0); +differ = (memcmp({old}->{location}, {new}->{location}, {self.size.declared}) != 0); """ return f"""\ -differ = (({self.size.old} != {self.size.real}) - || (memcmp(old->{location}, real_packet->{location}, {self.size.real}) != 0)); +differ = (({self.size.actual_for(old)} != {self.size.actual_for(new)}) + || (memcmp({old}->{location}, {new}->{location}, {self.size.actual_for(new)}) != 0)); """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &real_packet->{location}, {self.size.real}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{packet}->{location}, {self.size.actual_for(packet)}); """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -{self.size.size_check_get(location.name)}\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, real_packet->{location}, {self.size.real})) {{ +{self.size.size_check_get(location.name, packet)}\ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, {packet}->{location}, {self.size.actual_for(packet)})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1310,10 +1299,10 @@ class ArrayType(FieldType): # the final * is already part of {location} return self.elem.get_code_param(location.deeper(f"*const {location}")) - def get_code_init(self, location: Location) -> str: + def get_code_init(self, location: Location, packet: str) -> str: if not self.complex: - return super().get_code_init(location) - inner_init = prefix(" ", self.elem.get_code_init(location.sub)) + return super().get_code_init(location, packet) + inner_init = prefix(" ", self.elem.get_code_init(location.sub, packet)) # Note: we're initializing and destroying *all* elements of the array, # not just those up to the actual size; otherwise we'd have to # dynamically initialize and destroy elements as the actual size changes @@ -1341,22 +1330,22 @@ class ArrayType(FieldType): }} """ - def get_code_fill(self, location: Location) -> str: - inner_fill = prefix(" ", self.elem.get_code_fill(location.sub)) + def get_code_fill(self, location: Location, packet: str) -> str: + inner_fill = prefix(" ", self.elem.get_code_fill(location.sub, packet)) return f"""\ {{ int {location.index}; - for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {self.size.actual_for(packet)}; {location.index}++) {{ {inner_fill}\ }} }} """ - def get_code_free(self, location: Location) -> str: + def get_code_free(self, location: Location, packet: str) -> str: if not self.complex: - return super().get_code_free(location) - inner_free = prefix(" ", self.elem.get_code_free(location.sub)) + return super().get_code_free(location, packet) + inner_free = prefix(" ", self.elem.get_code_free(location.sub, packet)) # Note: we're initializing and destroying *all* elements of the array, # not just those up to the actual size; otherwise we'd have to # dynamically initialize and destroy elements as the actual size changes @@ -1373,22 +1362,22 @@ class ArrayType(FieldType): def get_code_hash(self, location: Location) -> str: raise ValueError(f"hash not supported for array type {self} in field {location.name}") - def get_code_cmp(self, location: Location) -> str: + def get_code_cmp(self, location: Location, new: str, old: str) -> str: if not self.size.constant: # ends mid-line head = f"""\ -differ = ({self.size.old} != {self.size.real}); +differ = ({self.size.actual_for(old)} != {self.size.actual_for(new)}); if (!differ) """ else: head = f"""\ differ = FALSE; """ - inner_cmp = prefix(" ", self.elem.get_code_cmp(location.sub)) + inner_cmp = prefix(" ", self.elem.get_code_cmp(location.sub, new, old)) return f"""\ {head}{{ int {location.index}; - for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {self.size.actual_for(new)}; {location.index}++) {{ {inner_cmp}\ if (differ) {{ break; @@ -1397,22 +1386,22 @@ differ = FALSE; }} """ - def _get_code_put_full(self, location: Location) -> str: + def _get_code_put_full(self, location: Location, packet: str) -> str: """Helper method. Generate put code without array-diff.""" - inner_put = prefix(" ", self.elem.get_code_put(location.sub, False)) + inner_put = prefix(" ", self.elem.get_code_put(location.sub, packet)) return f"""\ {{ int {location.index}; #ifdef FREECIV_JSON_CONNECTION /* Create the array. */ - e |= DIO_PUT(farray, &dout, &field_addr, {self.size.real}); + e |= DIO_PUT(farray, &dout, &field_addr, {self.size.actual_for(packet)}); /* Enter the array. */ {location.json_subloc} = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {self.size.actual_for(packet)}; {location.index}++) {{ #ifdef FREECIV_JSON_CONNECTION /* Next array element. */ {location.json_subloc}->number = {location.index}; @@ -1428,20 +1417,20 @@ differ = FALSE; }} """ - def _get_code_put_diff(self, location: Location) -> str: + def _get_code_put_diff(self, location: Location, packet: str, diff_packet: str) -> str: """Helper method. Generate array-diff put code.""" # we're nesting two levels deep in the JSON structure sub = location.sub_full(2) # Note: At the moment, we're only deep-diffing our elements # if our array size is constant - value_put = prefix(" ", self.elem.get_code_put(sub, self.size.constant)) - inner_cmp = prefix(" ", self.elem.get_code_cmp(sub)) - index_put = prefix(" ", self.size.index_put(location.index)) - index_put_sentinel = prefix(" ", self.size.index_put(self.size.real)) + value_put = prefix(" ", self.elem.get_code_put(sub, packet, diff_packet if self.size.constant else None)) + inner_cmp = prefix(" ", self.elem.get_code_cmp(sub, packet, diff_packet)) + index_put = prefix(" ", self.size.index_put(packet, location.index)) + index_put_sentinel = prefix(" ", self.size.index_put(packet, self.size.actual_for(packet))) if not self.size.constant: inner_cmp = f"""\ - if ({location.index} < {self.size.old}) {{ + if ({location.index} < {self.size.actual_for(diff_packet)}) {{ {prefix(" ", inner_cmp)}\ }} else {{ /* Always transmit new elements */ @@ -1450,7 +1439,7 @@ differ = FALSE; """ return f"""\ -{self.size.size_check_index(location.name)}\ +{self.size.size_check_index(location.name, packet)}\ {{ int {location.index}; @@ -1464,7 +1453,7 @@ differ = FALSE; {location.json_subloc} = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {self.size.actual_for(packet)}; {location.index}++) {{ {inner_cmp}\ if (!differ) {{ @@ -1522,17 +1511,17 @@ differ = FALSE; }} """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: - if deep_diff: - return self._get_code_put_diff(location) + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: + if diff_packet is not None: + return self._get_code_put_diff(location, packet, diff_packet) else: - return self._get_code_put_full(location) + return self._get_code_put_full(location, packet) - def _get_code_get_full(self, location: Location) -> str: + def _get_code_get_full(self, location: Location, packet: str) -> str: """Helper method. Generate get code without array-diff.""" - inner_get = prefix(" ", self.elem.get_code_get(location.sub, False)) + inner_get = prefix(" ", self.elem.get_code_get(location.sub, packet)) return f"""\ -{self.size.size_check_get(location.name)}\ +{self.size.size_check_get(location.name, packet)}\ {{ int {location.index}; @@ -1541,7 +1530,7 @@ differ = FALSE; {location.json_subloc} = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {self.size.actual_for(packet)}; {location.index}++) {{ #ifdef FREECIV_JSON_CONNECTION {location.json_subloc}->number = {location.index}; #endif /* FREECIV_JSON_CONNECTION */ @@ -1556,14 +1545,14 @@ differ = FALSE; }} """ - def _get_code_get_diff(self, location: Location) -> str: + def _get_code_get_diff(self, location: Location, packet: str) -> str: """Helper method. Generate array-diff get code.""" # we're nested two levels deep in the JSON structure - value_get = prefix(" ", self.elem.get_code_get(location.sub_full(2), True)) - index_get = prefix(" ", self.size.index_get(location)) + value_get = prefix(" ", self.elem.get_code_get(location.sub_full(2), packet, True)) + index_get = prefix(" ", self.size.index_get(packet, location)) return f"""\ -{self.size.size_check_get(location.name)}\ -{self.size.size_check_index(location.name)}\ +{self.size.size_check_get(location.name, packet)}\ +{self.size.size_check_index(location.name, packet)}\ #ifdef FREECIV_JSON_CONNECTION /* Enter array (start at initial element). */ {location.json_subloc} = plocation_elem_new(0); @@ -1577,13 +1566,13 @@ while (TRUE) {{ /* Read next index */ {index_get}\ - if ({location.index} == {self.size.real}) {{ + if ({location.index} == {self.size.actual_for(packet)}) {{ break; }} - if ({location.index} > {self.size.real}) {{ + if ({location.index} > {self.size.actual_for(packet)}) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}, ": unexpected value %d " - "(> {self.size.real}) in array diff", + "(> {self.size.actual_for(packet)}) in array diff", {location.index}); }} @@ -1609,11 +1598,11 @@ FC_FREE({location.json_subloc}); #endif /* FREECIV_JSON_CONNECTION */ """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: if deep_diff: - return self._get_code_get_diff(location) + return self._get_code_get_diff(location, packet) else: - return self._get_code_get_full(location) + return self._get_code_get_full(location, packet) def __str__(self) -> str: return f"{self.elem}[{self.size}]" @@ -1666,19 +1655,19 @@ class StrvecType(FieldType): # initial const gets added from outside return f"{self.public_type} *const {location}" - def get_code_init(self, location: Location) -> str: + def get_code_init(self, location: Location, packet: str) -> str: # we're always allocating our vectors, even if they're empty return f"""\ -{location} = strvec_new(); +{packet}->{location} = strvec_new(); """ - def get_code_fill(self, location: Location) -> str: + def get_code_fill(self, location: Location, packet: str) -> str: """Generate a code snippet shallow-copying a value of this type from dsend arguments into a packet struct.""" # safety: the packet's contents will not be modified without cloning # it first, so discarding 'const' qualifier here is safe return f"""\ -real_packet->{location} = (struct strvec *) {location}; +{packet}->{location} = (struct strvec *) {location}; """ def get_code_copy(self, location: Location, dest: str, src: str) -> str: @@ -1692,28 +1681,28 @@ if ({src}->{location}) {{ }} """ - def get_code_free(self, location: Location) -> str: + def get_code_free(self, location: Location, packet: str) -> str: return f"""\ -if ({location}) {{ - strvec_destroy({location}); - {location} = nullptr; +if ({packet}->{location}) {{ + strvec_destroy({packet}->{location}); + {packet}->{location} = nullptr; }} """ def get_code_hash(self, location: Location) -> str: raise ValueError(f"hash not supported for strvec type {self} in field {location.name}") - def get_code_cmp(self, location: Location) -> str: - # real_packet vector might be null when sending + def get_code_cmp(self, location: Location, new: str, old: str) -> str: + # "new" packet passed in from outside might have null vector return f"""\ -if (real_packet->{location}) {{ - differ = !are_strvecs_equal(old->{location}, real_packet->{location}); +if ({new}->{location}) {{ + differ = !are_strvecs_equal({old}->{location}, {new}->{location}); }} else {{ - differ = (strvec_size(old->{location}) > 0); + differ = (strvec_size({old}->{location}) > 0); }} """ - def _get_code_put_full(self, location: Location) -> str: + def _get_code_put_full(self, location: Location, packet: str) -> str: # Note: strictly speaking, we could allow size == MAX_UINT16, # but we might want to use that in the future to signal overlong # vectors (like with jumbo packets) @@ -1721,22 +1710,22 @@ if (real_packet->{location}) {{ # which we're a long way from size = __class__._VecSize(location) return f"""\ -if (!real_packet->{location}) {{ +if (!{packet}->{location}) {{ /* Transmit null vector as empty vector */ e |= DIO_PUT(arraylen, &dout, &field_addr, 0); }} else {{ int {location.index}; - fc_assert({size.real} < MAX_UINT16); - e |= DIO_PUT(arraylen, &dout, &field_addr, {size.real}); + fc_assert({size.actual_for(packet)} < MAX_UINT16); + e |= DIO_PUT(arraylen, &dout, &field_addr, {size.actual_for(packet)}); #ifdef FREECIV_JSON_CONNECTION /* Enter array. */ {location.json_subloc} = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {size.real}; {location.index}++) {{ - const char *pstr = strvec_get(real_packet->{location}, {location.index}); + for ({location.index} = 0; {location.index} < {size.actual_for(packet)}; {location.index}++) {{ + const char *pstr = strvec_get({packet}->{location}, {location.index}); if (!pstr) {{ /* Transmit null strings as empty strings */ @@ -1758,13 +1747,13 @@ if (!real_packet->{location}) {{ }} """ - def _get_code_put_diff(self, location: Location) -> str: + def _get_code_put_diff(self, location: Location, packet: str, diff_packet: str) -> str: size = __class__._VecSize(location) - size_check = prefix(" ", size.size_check_index(location.name)) - index_put = prefix(" ", size.index_put(location.index)) - index_put_sentinel = prefix(" ", size.index_put(size.real)) + size_check = prefix(" ", size.size_check_index(location.name, packet)) + index_put = prefix(" ", size.index_put(packet, location.index)) + index_put_sentinel = prefix(" ", size.index_put(packet, size.actual_for(packet))) return f"""\ -if (!real_packet->{location} || 0 == {size.real}) {{ +if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ /* Special case for empty vector. */ #ifdef FREECIV_JSON_CONNECTION @@ -1798,7 +1787,7 @@ if (!real_packet->{location} || 0 == {size.real}) {{ #endif /* FREECIV_JSON_CONNECTION */ /* Write the new size */ - e |= DIO_PUT(uint16, &dout, &field_addr, {size.real}); + e |= DIO_PUT(uint16, &dout, &field_addr, {size.actual_for(packet)}); #ifdef FREECIV_JSON_CONNECTION /* Delta address. */ @@ -1811,16 +1800,16 @@ if (!real_packet->{location} || 0 == {size.real}) {{ {location.json_subloc}->sub_location = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {size.real}; {location.index}++) {{ - const char *pstr = strvec_get(real_packet->{location}, {location.index}); + for ({location.index} = 0; {location.index} < {size.actual_for(packet)}; {location.index}++) {{ + const char *pstr = strvec_get({packet}->{location}, {location.index}); if (!pstr) {{ /* Transmit null strings as empty strings */ pstr = ""; }} - if ({location.index} < {size.old}) {{ - const char *pstr_old = strvec_get(old->{location}, {location.index}); + if ({location.index} < {size.actual_for(diff_packet)}) {{ + const char *pstr_old = strvec_get({diff_packet}->{location}, {location.index}); differ = (strcmp(pstr_old ? pstr_old : "", pstr) != 0); }} else {{ @@ -1884,13 +1873,13 @@ if (!real_packet->{location} || 0 == {size.real}) {{ }} """ - def get_code_put(self, location: Location, deep_diff: bool = False) -> str: - if deep_diff: - return self._get_code_put_diff(location) + def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: + if diff_packet is not None: + return self._get_code_put_diff(location, packet, diff_packet) else: - return self._get_code_put_full(location) + return self._get_code_put_full(location, packet) - def _get_code_get_full(self, location: Location) -> str: + def _get_code_get_full(self, location: Location, packet: str) -> str: size = __class__._VecSize(location) return f"""\ {{ @@ -1899,13 +1888,13 @@ if (!real_packet->{location} || 0 == {size.real}) {{ if (!DIO_GET(arraylen, &din, &field_addr, &{location.index})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - strvec_reserve(real_packet->{location}, {location.index}); + strvec_reserve({packet}->{location}, {location.index}); #ifdef FREECIV_JSON_CONNECTION {location.json_subloc} = plocation_elem_new(0); #endif /* FREECIV_JSON_CONNECTION */ - for ({location.index} = 0; {location.index} < {size.real}; {location.index}++) {{ + for ({location.index} = 0; {location.index} < {size.actual_for(packet)}; {location.index}++) {{ char readin[MAX_LEN_PACKET]; #ifdef FREECIV_JSON_CONNECTION @@ -1914,7 +1903,7 @@ if (!real_packet->{location} || 0 == {size.real}) {{ #endif /* FREECIV_JSON_CONNECTION */ if (!DIO_GET({self.dataio_type}, &din, &field_addr, readin, sizeof(readin)) - || !strvec_set(real_packet->{location}, {location.index}, readin)) {{ + || !strvec_set({packet}->{location}, {location.index}, readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} }} @@ -1925,11 +1914,11 @@ if (!real_packet->{location} || 0 == {size.real}) {{ }} """ - def _get_code_get_diff(self, location: Location) -> str: + def _get_code_get_diff(self, location: Location, packet: str) -> str: size = __class__._VecSize(location) - index_get = prefix(" ", size.index_get(location)) + index_get = prefix(" ", size.index_get(packet, location)) return f"""\ -{size.size_check_index(location.name)}\ +{size.size_check_index(location.name, packet)}\ #ifdef FREECIV_JSON_CONNECTION /* Enter object (start at size address). */ {location.json_subloc} = plocation_field_new("size"); @@ -1941,7 +1930,7 @@ if (!real_packet->{location} || 0 == {size.real}) {{ if (!DIO_GET(uint16, &din, &field_addr, &readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - strvec_reserve(real_packet->{location}, readin); + strvec_reserve({packet}->{location}, readin); }} #ifdef FREECIV_JSON_CONNECTION @@ -1953,18 +1942,18 @@ if (!real_packet->{location} || 0 == {size.real}) {{ {location.json_subloc}->sub_location->sub_location = plocation_field_new("index"); #endif /* FREECIV_JSON_CONNECTION */ -/* if ({size.real} > 0) while (TRUE) */ -while ({size.real} > 0) {{ +/* if ({size.actual_for(packet)} > 0) while (TRUE) */ +while ({size.actual_for(packet)} > 0) {{ int {location.index}; char readin[MAX_LEN_PACKET]; /* Read next index */ {index_get}\ - if ({location.index} == {size.real}) {{ + if ({location.index} == {size.actual_for(packet)}) {{ break; }} - if ({location.index} > {size.real}) {{ + if ({location.index} > {size.actual_for(packet)}) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}, ": unexpected value %d " "(> vector length) in array diff", @@ -1977,7 +1966,7 @@ while ({size.real} > 0) {{ #endif /* FREECIV_JSON_CONNECTION */ if (!DIO_GET({self.dataio_type}, &din, &field_addr, readin, sizeof(readin)) - || !strvec_set(real_packet->{location}, {location.index}, readin)) {{ + || !strvec_set({packet}->{location}, {location.index}, readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} @@ -1997,11 +1986,11 @@ FC_FREE({location.json_subloc}); #endif /* FREECIV_JSON_CONNECTION */ """ - def get_code_get(self, location: Location, deep_diff: bool = False) -> str: + def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: if deep_diff: - return self._get_code_get_diff(location) + return self._get_code_get_diff(location, packet) else: - return self._get_code_get_full(location) + return self._get_code_get_full(location, packet) def __str__(self) -> str: return f"{self.dataio_type}({self.public_type})" @@ -2168,20 +2157,17 @@ class Field: See also self.get_handle_arg()""" return self.type_info.get_code_param(Location(self.name)) - def get_handle_arg(self, packet_arrow: str) -> str: + def get_handle_arg(self, packet: str) -> str: """Generate the way this field is passed as an argument to a handle function. See also self.get_handle_param()""" - return self.type_info.get_code_handle_arg(Location( - self.name, - packet_arrow + self.name, - )) + return self.type_info.get_code_handle_arg(Location(self.name), packet) def get_init(self) -> str: """Generate code initializing this field in the packet struct, after the struct has already been zeroed.""" - return self.type_info.get_code_init(Location(self.name, f"packet->{self.name}")) + return self.type_info.get_code_init(Location(self.name), "packet") def get_copy(self, dest: str, src: str) -> str: """Generate code deep-copying this field from *src to *dest.""" @@ -2190,17 +2176,17 @@ class Field: def get_fill(self) -> str: """Generate code shallow-copying this field from the dsend arguments into the packet struct.""" - return self.type_info.get_code_fill(Location(self.name)) + return self.type_info.get_code_fill(Location(self.name), "real_packet") def get_free(self) -> str: """Generate code deinitializing this field in the packet struct before destroying the packet.""" - return self.type_info.get_code_free(Location(self.name, f"packet->{self.name}")) + return self.type_info.get_code_free(Location(self.name), "packet") def get_hash(self) -> str: """Generate code factoring this field into a hash computation.""" assert self.is_key - return self.type_info.get_code_hash(Location(self.name)) + return self.type_info.get_code_hash(Location(self.name), "key") @property def folded_into_head(self) -> bool: @@ -2235,7 +2221,7 @@ if (differ) {{ return f"""\ {info_part}\ /* folded into head */ -if (packet->{self.name}) {{ +if (real_packet->{self.name}) {{ BV_SET(fields, {i:d}); }} """ @@ -2256,7 +2242,7 @@ if (differ) {{ """Generate code checking whether this field changed. This code is primarily used by self.get_cmp_wrapper()""" - return self.type_info.get_code_cmp(Location(self.name)) + return self.type_info.get_code_cmp(Location(self.name), "real_packet", "old") def get_put_wrapper(self, packet: "Variant", index: int, deltafragment: bool) -> str: """Generate code conditionally putting this field iff its bit in the @@ -2314,7 +2300,7 @@ if (e) {{ yet wrapped for full delta and JSON protocol support. See self.get_put() for more info""" - return self.type_info.get_code_put(Location(self.name), deltafragment and self.diff) + return self.type_info.get_code_put(Location(self.name), "real_packet", "old" if deltafragment and self.diff else None) def get_get_wrapper(self, packet: "Variant", i: int, deltafragment: bool) -> str: """Generate code conditionally getting this field iff its bit in the @@ -2361,7 +2347,7 @@ field_addr.name = "{self.name}"; yet wrapped for full delta and JSON protocol support. See self.get_get() for more info""" - return self.type_info.get_code_get(Location(self.name), deltafragment and self.diff) + return self.type_info.get_code_get(Location(self.name), "real_packet", deltafragment and self.diff) class Variant: @@ -4383,9 +4369,9 @@ bool server_handle_packet(enum packet_type type, const void *packet, if p.handle_via_packet: args += ", packet" else: - packet_arrow = f"((const struct {p.name} *)packet)->" + packet = f"((const struct {p.name} *)packet)" args += "".join( - ",\n " + field.get_handle_arg(packet_arrow) + ",\n " + field.get_handle_arg(packet) for field in p.fields ) @@ -4429,9 +4415,9 @@ bool client_handle_packet(enum packet_type type, const void *packet) if p.handle_via_packet: args = "packet" else: - packet_arrow = f"((const struct {p.name} *)packet)->" + packet = f"((const struct {p.name} *)packet)" args = ",".join( - "\n " + field.get_handle_arg(packet_arrow) + "\n " + field.get_handle_arg(packet) for field in p.fields ) -- 2.34.1