From cf4b25289a0419d26161eec111972ee439009e33 Mon Sep 17 00:00:00 2001 From: Alina Lenk Date: Fri, 3 May 2024 15:07:27 +0200 Subject: [PATCH 2/2] generate_packets.py: make Location class insert packet names See RM #528 Signed-off-by: Alina Lenk --- common/generate_packets.py | 170 ++++++++++++++++++++++--------------- 1 file changed, 101 insertions(+), 69 deletions(-) diff --git a/common/generate_packets.py b/common/generate_packets.py index 507c68b39c..c53fb19aad 100755 --- a/common/generate_packets.py +++ b/common/generate_packets.py @@ -347,12 +347,16 @@ class Location: outside of recursive field types like arrays, this will usually just be a field of a packet, but it serves to concisely handle the recursion.""" + # placeholder that will clearly be an error if it accidentally + # shows up in generated code + _PACKET = "#error gen_packet$" _INDICES = "ijk" name: str """The name associated with this location; used in log messages.""" - location: str - """The actual location as used in code""" + _location: str + """The actual location as used in code, including placeholders for + where the packet name goes""" depth: int """The array nesting depth of this location; used to determine index variable names.""" @@ -360,18 +364,32 @@ class Location: """The total sub-location nesting depth of the JSON field address for this location""" - def __init__(self, name: str, location: "str | None" = None, + def __init__(self, name: str, *, location: "str | None" = None, depth: int = 0, json_depth: "int | None" = None): self.name = name - self.location = location if location is not None else name + self._location = location if location is not None else self._PACKET + name self.depth = depth self.json_depth = json_depth if json_depth is not None else depth + def replace(self, new_location: str) -> "Location": + """Return the given string as a new Location with the same metadata + as self""" + return type(self)( + name = self.name, + location = new_location, + depth = self.depth, + json_depth = self.json_depth, + ) + def deeper(self, new_location: str, json_step: int = 1) -> "Location": """Return the given string as a new Location with the same name as self and incremented depth""" - return type(self)(self.name, new_location, - self.depth + 1, self.json_depth + json_step) + return type(self)( + name = self.name, + location = new_location, + depth = self.depth + 1, + json_depth = self.json_depth + json_step, + ) def sub_full(self, json_step: int = 1) -> "Location": """Like self.sub, but with the option to step the JSON nesting @@ -399,11 +417,18 @@ class Location: of this location's corresponding field address""" return "field_addr.sub_location" + self.json_depth * "->sub_location" + def __matmul__(self, packet: str | None) -> str: + """self @ packet + Code fragment of this location in the given packet, or in local + variables if packet is None""" + packet = f"{packet}->" if packet is not None else "" + return self._location.replace(self._PACKET, packet) + def __str__(self) -> str: - return self.location + return self._location def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.name!r}, {self.location!r}, {self.depth!r}, {self.json_depth!r})" + return f"<{type(self).__name__} {self.name}(depth={self.depth}, json_depth={self.json_depth}) {self @ 'PACKET'}>" #################### Components of a packets definition #################### @@ -761,7 +786,7 @@ class FieldType(RawFieldType): handle function. See also self.get_code_param()""" - return f"{packet}->{location}" + return f"{location @ packet}" def get_code_init(self, location: Location, packet: str) -> str: """Generate a code snippet initializing a field of this type in the @@ -781,14 +806,14 @@ class FieldType(RawFieldType): if self.complex: raise ValueError(f"default get_code_copy implementation called for field {location.name} with complex type {self!r}") return f"""\ -{dest}->{location} = {src}->{location}; +{location @ dest} = {location @ src}; """ def get_code_fill(self, location: Location, packet: str) -> str: """Generate a code snippet shallow-copying a value of this type from dsend arguments into a packet struct.""" return f"""\ -{packet}->{location} = {location}; +{location @ packet} = {location @ None}; """ def get_code_free(self, location: Location, packet: str) -> str: @@ -869,28 +894,28 @@ class BasicType(FieldType): def get_code_declaration(self, location: Location) -> str: return f"""\ -{self.public_type} {location}; +{self.public_type} {location @ None}; """ def get_code_param(self, location: Location) -> str: - return f"{self.public_type} {location}" + return f"{self.public_type} {location @ None}" def get_code_hash(self, location: Location, packet: str) -> str: raise ValueError(f"hash not supported for type {self} in field {location.name}") def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = ({old}->{location} != {new}->{location}); +differ = ({location @ old} != {location @ new}); """ def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {packet}->{location}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {location @ packet}); """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{packet}->{location})) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{location @ packet})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -925,7 +950,7 @@ class IntType(BasicType): def get_code_hash(self, location: Location, packet: str) -> str: return f"""\ -result += {packet}->{location}; +result += {location @ packet}; """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: @@ -941,7 +966,7 @@ result += {packet}->{location}; if (!DIO_GET({self.dataio_type}, &din, &field_addr, &readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - {packet}->{location} = readin; + {location @ packet} = readin; }} """ @@ -1015,17 +1040,17 @@ class FloatType(BasicType): def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = ((int) ({old}->{location} * {self.float_factor}) != (int) ({new}->{location} * {self.float_factor})); +differ = ((int) ({location @ old} * {self.float_factor}) != (int) ({location @ new} * {self.float_factor})); """ def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {packet}->{location}, {self.float_factor:d}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, {location @ packet}, {self.float_factor:d}); """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{packet}->{location}, {self.float_factor:d})) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, &{location @ packet}, {self.float_factor:d})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1047,17 +1072,17 @@ class BitvectorType(BasicType): def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !BV_ARE_EQUAL({old}->{location}, {new}->{location}); +differ = !BV_ARE_EQUAL({location @ old}, {location @ new}); """ def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_BV_PUT(&dout, &field_addr, {packet}->{location}); +e |= DIO_BV_PUT(&dout, &field_addr, {location @ packet}); """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_BV_GET(&din, &field_addr, {packet}->{location})) {{ +if (!DIO_BV_GET(&din, &field_addr, {location @ packet})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1088,7 +1113,7 @@ class StructType(BasicType): def get_code_param(self, location: Location) -> str: if not location.depth: # top level: pass by-reference - return "const " + super().get_code_param(location.deeper(f"*{location}")) + return "const " + super().get_code_param(location.replace(f"*{location}")) return super().get_code_param(location) def get_code_handle_arg(self, location: Location, packet: str) -> str: @@ -1098,12 +1123,12 @@ class StructType(BasicType): def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !are_{self.dataio_type}s_equal(&{old}->{location}, &{new}->{location}); +differ = !are_{self.dataio_type}s_equal(&{location @ old}, &{location @ new}); """ def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{packet}->{location}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{location @ packet}); """ DEFAULT_REGISTRY.public_patterns[StructType.TYPE_PATTERN] = StructType @@ -1123,7 +1148,7 @@ class CmParameterType(StructType): def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = !cm_are_parameter_equal(&{old}->{location}, &{new}->{location}); +differ = !cm_are_parameter_equal(&{location @ old}, &{location @ new}); """ DEFAULT_REGISTRY.dataio_types["cm_parameter"] = CmParameterType @@ -1143,12 +1168,12 @@ class WorklistType(StructType): def get_code_copy(self, location: Location, dest: str, src: str) -> str: return f"""\ -worklist_copy(&{dest}->{location}, &{src}->{location}); +worklist_copy(&{location @ dest}, &{location @ src}); """ def get_code_fill(self, location: Location, packet: str) -> str: return f"""\ -worklist_copy(&{packet}->{location}, {location}); +worklist_copy(&{location @ packet}, {location @ None}); """ DEFAULT_REGISTRY.dataio_types["worklist"] = WorklistType @@ -1166,13 +1191,15 @@ class SizedType(BasicType): def get_code_declaration(self, location: Location) -> str: return super().get_code_declaration( - location.deeper(f"{location}[{self.size.declared}]") + location.replace(f"{location}[{self.size.declared}]") ) def get_code_param(self, location: Location) -> str: - # add "const" if top level - pre = "" if location.depth else "const " - return pre + super().get_code_param(location.deeper(f"*{location}")) + # see ArrayType.get_code_param() for explanation + if not location.depth: + return "const " + super().get_code_param(location.replace(f"*{location}")) + else: + return super().get_code_param(location.replace(f"*const {location}")) @abstractmethod def get_code_fill(self, location: Location, packet: str) -> str: @@ -1200,22 +1227,22 @@ class StringType(SizedType): def get_code_fill(self, location: Location, packet: str) -> str: return f"""\ -sz_strlcpy({packet}->{location}, {location}); +sz_strlcpy({location @ packet}, {location @ None}); """ def get_code_copy(self, location: Location, dest: str, src: str) -> str: return f"""\ -sz_strlcpy({dest}->{location}, {src}->{location}); +sz_strlcpy({location @ dest}, {location @ src}); """ def get_code_cmp(self, location: Location, new: str, old: str) -> str: return f"""\ -differ = (strcmp({old}->{location}, {new}->{location}) != 0); +differ = (strcmp({location @ old}, {location @ new}) != 0); """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, {packet}->{location}, sizeof({packet}->{location}))) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, {location @ packet}, sizeof({location @ packet}))) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1235,28 +1262,28 @@ class MemoryType(SizedType): def get_code_copy(self, location: Location, dest: str, src: str) -> str: return f"""\ -memcpy({dest}->{location}, {src}->{location}, {self.size.actual_for(src)}); +memcpy({location @ dest}, {location @ src}, {self.size.actual_for(src)}); """ def get_code_cmp(self, location: Location, new: str, old: str) -> str: if self.size.constant: return f"""\ -differ = (memcmp({old}->{location}, {new}->{location}, {self.size.declared}) != 0); +differ = (memcmp({location @ old}, {location @ new}, {self.size.declared}) != 0); """ return f"""\ differ = (({self.size.actual_for(old)} != {self.size.actual_for(new)}) - || (memcmp({old}->{location}, {new}->{location}, {self.size.actual_for(new)}) != 0)); + || (memcmp({location @ old}, {location @ new}, {self.size.actual_for(new)}) != 0)); """ def get_code_put(self, location: Location, packet: str, diff_packet: "str | None" = None) -> str: return f"""\ -e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{packet}->{location}, {self.size.actual_for(packet)}); +e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &{location @ packet}, {self.size.actual_for(packet)}); """ def get_code_get(self, location: Location, packet: str, deep_diff: bool = False) -> str: return f"""\ {self.size.size_check_get(location.name, packet)}\ -if (!DIO_GET({self.dataio_type}, &din, &field_addr, {packet}->{location}, {self.size.actual_for(packet)})) {{ +if (!DIO_GET({self.dataio_type}, &din, &field_addr, {location @ packet}, {self.size.actual_for(packet)})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} """ @@ -1289,6 +1316,8 @@ class ArrayType(FieldType): ) def get_code_param(self, location: Location) -> str: + # When changing this, update SizedType.get_code_param() accordingly + # Note: If we're fine with writing `foo_t const *fieldname`, # we'd only need one case, .deeper(f"const *{location}") if not location.depth: @@ -1296,7 +1325,7 @@ class ArrayType(FieldType): return "const " + self.elem.get_code_param(location.deeper(f"*{location}")) else: # const foo_t *fieldname ~> const foo_t *const *fieldname - # the final * is already part of {location} + # the final * is already part of the location return self.elem.get_code_param(location.deeper(f"*const {location}")) def get_code_init(self, location: Location, packet: str) -> str: @@ -1614,11 +1643,14 @@ class StrvecType(FieldType): class _VecSize(SizeInfo): """Helper class to make SizeInfo methods work with strvec sizes""" + _actual_loc: Location + def __init__(self, location: Location): super().__init__("GENERATE_PACKETS_ERROR", str(location)) + self._actual_loc = location.replace(f"strvec_size({location})") def actual_for(self, packet: str) -> str: - return f"strvec_size({packet}->{self._actual})" + return self._actual_loc @ packet def __str__(self) -> str: return "*" @@ -1643,22 +1675,22 @@ class StrvecType(FieldType): def get_code_declaration(self, location: Location) -> str: return f"""\ -{self.public_type} *{location}; +{self.public_type} *{location @ None}; """ def get_code_param(self, location: Location) -> str: if not location.depth: - return f"const {self.public_type} *{location}" + return f"const {self.public_type} *{location @ None}" else: # const struct strvec *const *fieldname - # the final * is already part of {location} + # the final * is already part of the location # initial const gets added from outside - return f"{self.public_type} *const {location}" + return f"{self.public_type} *const {location @ None}" def get_code_init(self, location: Location, packet: str) -> str: # we're always allocating our vectors, even if they're empty return f"""\ -{packet}->{location} = strvec_new(); +{location @ packet} = strvec_new(); """ def get_code_fill(self, location: Location, packet: str) -> str: @@ -1667,25 +1699,25 @@ class StrvecType(FieldType): # safety: the packet's contents will not be modified without cloning # it first, so discarding 'const' qualifier here is safe return f"""\ -{packet}->{location} = (struct strvec *) {location}; +{location @ packet} = (struct strvec *) {location @ None}; """ def get_code_copy(self, location: Location, dest: str, src: str) -> str: # dest is initialized by us ~> not null # src might be a packet passed in from outside ~> could be null return f"""\ -if ({src}->{location}) {{ - strvec_copy({dest}->{location}, {src}->{location}); +if ({location @ src}) {{ + strvec_copy({location @ dest}, {location @ src}); }} else {{ - strvec_clear({dest}->{location}); + strvec_clear({location @ dest}); }} """ def get_code_free(self, location: Location, packet: str) -> str: return f"""\ -if ({packet}->{location}) {{ - strvec_destroy({packet}->{location}); - {packet}->{location} = nullptr; +if ({location @ packet}) {{ + strvec_destroy({location @ packet}); + {location @ packet} = nullptr; }} """ @@ -1695,10 +1727,10 @@ if ({packet}->{location}) {{ def get_code_cmp(self, location: Location, new: str, old: str) -> str: # "new" packet passed in from outside might have null vector return f"""\ -if ({new}->{location}) {{ - differ = !are_strvecs_equal({old}->{location}, {new}->{location}); +if ({location @ new}) {{ + differ = !are_strvecs_equal({location @ old}, {location @ new}); }} else {{ - differ = (strvec_size({old}->{location}) > 0); + differ = (strvec_size({location @ old}) > 0); }} """ @@ -1710,7 +1742,7 @@ if ({new}->{location}) {{ # which we're a long way from size = __class__._VecSize(location) return f"""\ -if (!{packet}->{location}) {{ +if (!{location @ packet}) {{ /* Transmit null vector as empty vector */ e |= DIO_PUT(arraylen, &dout, &field_addr, 0); }} else {{ @@ -1725,7 +1757,7 @@ if (!{packet}->{location}) {{ #endif /* FREECIV_JSON_CONNECTION */ for ({location.index} = 0; {location.index} < {size.actual_for(packet)}; {location.index}++) {{ - const char *pstr = strvec_get({packet}->{location}, {location.index}); + const char *pstr = strvec_get({location @ packet}, {location.index}); if (!pstr) {{ /* Transmit null strings as empty strings */ @@ -1753,7 +1785,7 @@ if (!{packet}->{location}) {{ index_put = prefix(" ", size.index_put(packet, location.index)) index_put_sentinel = prefix(" ", size.index_put(packet, size.actual_for(packet))) return f"""\ -if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ +if (!{location @ packet} || 0 == {size.actual_for(packet)}) {{ /* Special case for empty vector. */ #ifdef FREECIV_JSON_CONNECTION @@ -1801,7 +1833,7 @@ if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ #endif /* FREECIV_JSON_CONNECTION */ for ({location.index} = 0; {location.index} < {size.actual_for(packet)}; {location.index}++) {{ - const char *pstr = strvec_get({packet}->{location}, {location.index}); + const char *pstr = strvec_get({location @ packet}, {location.index}); if (!pstr) {{ /* Transmit null strings as empty strings */ @@ -1809,7 +1841,7 @@ if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ }} if ({location.index} < {size.actual_for(diff_packet)}) {{ - const char *pstr_old = strvec_get({diff_packet}->{location}, {location.index}); + const char *pstr_old = strvec_get({location @ diff_packet}, {location.index}); differ = (strcmp(pstr_old ? pstr_old : "", pstr) != 0); }} else {{ @@ -1888,7 +1920,7 @@ if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ if (!DIO_GET(arraylen, &din, &field_addr, &{location.index})) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - strvec_reserve({packet}->{location}, {location.index}); + strvec_reserve({location @ packet}, {location.index}); #ifdef FREECIV_JSON_CONNECTION {location.json_subloc} = plocation_elem_new(0); @@ -1903,7 +1935,7 @@ if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ #endif /* FREECIV_JSON_CONNECTION */ if (!DIO_GET({self.dataio_type}, &din, &field_addr, readin, sizeof(readin)) - || !strvec_set({packet}->{location}, {location.index}, readin)) {{ + || !strvec_set({location @ packet}, {location.index}, readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} }} @@ -1930,7 +1962,7 @@ if (!{packet}->{location} || 0 == {size.actual_for(packet)}) {{ if (!DIO_GET(uint16, &din, &field_addr, &readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} - strvec_reserve({packet}->{location}, readin); + strvec_reserve({location @ packet}, readin); }} #ifdef FREECIV_JSON_CONNECTION @@ -1966,7 +1998,7 @@ while ({size.actual_for(packet)} > 0) {{ #endif /* FREECIV_JSON_CONNECTION */ if (!DIO_GET({self.dataio_type}, &din, &field_addr, readin, sizeof(readin)) - || !strvec_set({packet}->{location}, {location.index}, readin)) {{ + || !strvec_set({location @ packet}, {location.index}, readin)) {{ RECEIVE_PACKET_FIELD_ERROR({location.name}); }} -- 2.34.1