From 421aab5febc8d1f02027a7c4d8e411a791c58259 Mon Sep 17 00:00:00 2001 From: Alina Lenk Date: Sun, 14 Apr 2024 18:10:20 +0200 Subject: [PATCH] Support complex field types in network packets Introduce code to initialize, copy and destroy packet structs in a way that allows specially handling fields that require it in the future See RM #446 Signed-off-by: Alina Lenk --- client/clinet.c | 4 +- common/generate_packets.py | 358 ++++++++++++++++++++++++++----- common/networking/packets.h | 4 + common/networking/packets_json.h | 3 + server/sernet.c | 2 +- 5 files changed, 310 insertions(+), 61 deletions(-) diff --git a/client/clinet.c b/client/clinet.c index c728c79a5e..62744e616b 100644 --- a/client/clinet.c +++ b/client/clinet.c @@ -423,7 +423,7 @@ void input_from_server(int fd) if (NULL != packet) { client_packet_input(packet, type); - free(packet); + packet_destroy(packet, type); } else { break; } @@ -467,7 +467,7 @@ void input_from_server_till_request_got_processed(int fd, } client_packet_input(packet, type); - free(packet); + packet_destroy(packet, type); if (type == PACKET_PROCESSING_FINISHED) { log_debug("ifstrgp: expect=%d, seen=%d", diff --git a/common/generate_packets.py b/common/generate_packets.py index 368fba8f90..b5ab9027c1 100755 --- a/common/generate_packets.py +++ b/common/generate_packets.py @@ -677,6 +677,10 @@ class FieldType(RawFieldType): foldable: bool = False """Whether a field of this type can be folded into the packet header""" + complex: bool = False + """Whether a field of this type needs special handling when initializing, + copying or destroying the packet struct""" + @cache def array(self, size: SizeInfo) -> "FieldType": """Construct a FieldType for an array with element type self and the @@ -704,11 +708,45 @@ class FieldType(RawFieldType): See also self.get_code_handle_param()""" return str(location) - @abstractmethod + def get_code_init(self, location: Location) -> str: + """Generate a code snippet initializing a field of this type in the + packet struct, after the struct has already been zeroed. + + Subclasses must override this if self.complex is True""" + if self.complex: + raise ValueError(f"default get_code_init implementation called for field {location.name} with complex type {self!r}") + return f"""\ +/* no work needed for {location} */ +""" + + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + """Generate a code snippet deep-copying a field of this type from + one packet struct to another that has already been initialized. + + Subclasses must override this if self.complex is True""" + if self.complex: + raise ValueError(f"default get_code_copy implementation called for field {location.name} with complex type {self!r}") + return f"""\ +{dest}->{location} = {src}->{location}; +""" + def get_code_fill(self, location: Location) -> str: - """Generate a code snippet moving a value of this type from dsend - arguments into a packet struct.""" - raise NotImplementedError + """Generate a code snippet shallow-copying a value of this type from + dsend arguments into a packet struct.""" + return f"""\ +real_packet->{location} = {location}; +""" + + def get_code_free(self, location: Location) -> str: + """Generate a code snippet deinitializing a field of this type in + the packet struct before it gets destroyed. + + Subclasses must override this if self.complex is True""" + if self.complex: + raise ValueError(f"default get_code_free implementation called for field {location.name} with complex type {self!r}") + return f"""\ +/* no work needed for {location} */ +""" @abstractmethod def get_code_hash(self, location: Location) -> str: @@ -756,11 +794,6 @@ class BasicType(FieldType): def get_code_handle_param(self, location: Location) -> str: return f"{self.public_type} {location}" - def get_code_fill(self, location: Location) -> str: - return f"""\ -real_packet->{location} = {location}; -""" - def get_code_hash(self, location: Location) -> str: raise ValueError(f"hash not supported for type {self} in field {location.name}") @@ -1028,6 +1061,11 @@ class WorklistType(StructType): super().__init__(dataio_type, public_type) + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + return f"""\ +worklist_copy(&{dest}->{location}, &{src}->{location}); +""" + def get_code_fill(self, location: Location) -> str: return f"""\ worklist_copy(&real_packet->{location}, {location}); @@ -1060,6 +1098,10 @@ class SizedType(BasicType): def get_code_fill(self, location: Location) -> str: return super().get_code_fill(location) + @abstractmethod + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + return super().get_code_copy(location, dest, src) + def __str__(self) -> str: return f"{super().__str__()}[{self.size}]" @@ -1079,6 +1121,11 @@ class StringType(SizedType): def get_code_fill(self, location: Location) -> str: return f"""\ sz_strlcpy(real_packet->{location}, {location}); +""" + + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + return f"""\ +sz_strlcpy({dest}->{location}, {src}->{location}); """ def get_code_cmp(self, location: Location) -> str: @@ -1108,6 +1155,11 @@ class MemoryType(SizedType): def get_code_fill(self, location: Location) -> str: raise NotImplementedError("fill not supported for memory-type fields") + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + return f"""\ +memcpy({dest}->{location}, {src}->{location}, {self.size.actual_for(src)}); +""" + def get_code_cmp(self, location: Location) -> str: if self.size.constant: return f"""\ @@ -1151,6 +1203,10 @@ class ArrayType(FieldType): self.elem = elem self.size = size + @property + def complex(self) -> bool: + return self.elem.complex + def get_code_declaration(self, location: Location) -> str: return self.elem.get_code_declaration( location.deeper(f"{location}[{self.size.declared}]") @@ -1161,6 +1217,38 @@ class ArrayType(FieldType): pre = "" if location.depth else "const " return pre + self.elem.get_code_handle_param(location.deeper(f"*{location}")) + def get_code_init(self, location: Location) -> str: + if not self.complex: + return super().get_code_init(location) + inner_init = prefix(" ", self.elem.get_code_init(location.sub)) + # Note: we're initializing and destroying *all* elements of the array, + # not just those up to the actual size; otherwise we'd have to + # dynamically initialize and destroy elements as the actual size changes + return f"""\ +{{ + int {location.index}; + + for ({location.index} = 0; {location.index} < {self.size.declared}; {location.index}++) {{ +{inner_init}\ + }} +}} +""" + + def get_code_copy(self, location: Location, dest: str, src: str) -> str: + # can't use direct assignment to bit-copy a raw array, + # even if our type is not complex + inner_copy = prefix(" ", self.elem.get_code_copy(location.sub, dest, src)) + # FIXME: can't use self.size.real; have to use actual_for(src context) + return f"""\ +{{ + int {location.index}; + + for ({location.index} = 0; {location.index} < {self.size.actual_for(src)}; {location.index}++) {{ +{inner_copy}\ + }} +}} +""" + def get_code_fill(self, location: Location) -> str: inner_fill = prefix(" ", self.elem.get_code_fill(location.sub)) return f"""\ @@ -1171,6 +1259,23 @@ class ArrayType(FieldType): {inner_fill}\ }} }} +""" + + def get_code_free(self, location: Location) -> str: + if not self.complex: + return super().get_code_free(location) + inner_free = prefix(" ", self.elem.get_code_free(location.sub)) + # Note: we're initializing and destroying *all* elements of the array, + # not just those up to the actual size; otherwise we'd have to + # dynamically initialize and destroy elements as the actual size changes + return f"""\ +{{ + int {location.index}; + + for ({location.index} = 0; {location.index} < {self.size.declared}; {location.index}++) {{ +{inner_free}\ + }} +}} """ def get_code_hash(self, location: Location) -> str: @@ -1498,6 +1603,12 @@ class Field: """Set of all capabilities affecting this field""" return self.flags.add_caps | self.flags.remove_caps + @property + def complex(self) -> bool: + """Whether this field's type requires special handling; + see FieldType.complex""" + return self.type_info.complex + def present_with_caps(self, caps: typing.Container[str]) -> bool: """Determine whether this field should be part of a variant with the given capabilities""" @@ -1528,11 +1639,25 @@ class Field: packet_arrow + self.name, )) + def get_init(self) -> str: + """Generate code initializing this field in the packet struct, after + the struct has already been zeroed.""" + return self.type_info.get_code_init(Location(self.name)) + + def get_copy(self, dest: str, src: str) -> str: + """Generate code deep-copying this field from *src to *dest.""" + return self.type_info.get_code_copy(Location(self.name), dest, src) + def get_fill(self) -> str: - """Generate code moving this field from the dsend arguments into - the packet struct.""" + """Generate code shallow-copying this field from the dsend arguments + into the packet struct.""" return self.type_info.get_code_fill(Location(self.name)) + def get_free(self) -> str: + """Generate code deinitializing this field in the packet struct + before destroying the packet.""" + return self.type_info.get_code_free(Location(self.name)) + def get_hash(self) -> str: """Generate code factoring this field into a hash computation.""" assert self.is_key @@ -1853,6 +1978,15 @@ class Variant: See Packet.cancel""" return self.packet.cancel + @property + def complex(self) -> bool: + """Whether this packet's struct requires special handling for + initialization, copying, and destruction. + + Note that this is still True even if the complex-typed fields + of the packet are excluded from this Variant.""" + return self.packet.complex + @property def differ_used(self) -> bool: """Whether the send function needs a `differ` boolean. @@ -1928,6 +2062,18 @@ phandlers->send[{self.type}].packet = (int(*)(struct connection *, const void *) phandlers->receive[{self.type}] = (void *(*)(struct connection *)) receive_{self.name}; """ + def get_copy(self, dest: str, src: str) -> str: + """Generate code deep-copying the fields relevant to this variant + from *src to *dest""" + if not self.complex: + return f"""\ +*{dest} = *{src}; +""" + return "".join( + field.get_copy(dest, src) + for field in self.fields + ) + def get_stats(self) -> str: """Generate the declaration of the delta stats counters associated with this packet variant""" @@ -2067,21 +2213,44 @@ static bool cmp_{self.name}(const void *vkey1, const void *vkey2) log="" if self.no_packet: + # empty packet, don't need anything main_header = "" - else: - if self.packet.want_pre_send: - main_header = f"""\ + after_header = "" + before_return = "" + elif not self.packet.want_pre_send: + # no pre-send, don't need to copy the packet + main_header = f"""\ + const struct {self.packet_name} *real_packet = packet; + int e; +""" + after_header = "" + before_return = "" + elif not self.complex: + # bit-copy the packet + main_header = f"""\ /* copy packet for pre-send */ struct {self.packet_name} packet_buf = *packet; const struct {self.packet_name} *real_packet = &packet_buf; + int e; """ - else: - main_header = f"""\ - const struct {self.packet_name} *real_packet = packet; -""" - main_header += """\ + after_header = "" + before_return = "" + else: + # deep-copy the packet for pre-send, have to destroy the copy + copy = prefix(" ", self.get_copy("(&packet_buf)", "packet")) + main_header = f"""\ + /* buffer to hold packet copy for pre-send */ + struct {self.packet_name} packet_buf; + const struct {self.packet_name} *real_packet = &packet_buf; int e; """ + after_header = f"""\ + init_{self.packet_name}(&packet_buf); +{copy}\ +""" + before_return = f"""\ + free_{self.packet_name}(&packet_buf); +""" if not self.packet.want_pre_send: pre = "" @@ -2121,7 +2290,7 @@ static bool cmp_{self.name}(const void *vkey1, const void *vkey2) delta_header += """\ #endif /* FREECIV_DELTA_PROTOCOL */ """ - body = prefix(" ", self.get_delta_send_body()) + """\ + body = prefix(" ", self.get_delta_send_body(before_return)) + """\ #ifndef FREECIV_DELTA_PROTOCOL """ else: @@ -2178,11 +2347,13 @@ static bool cmp_{self.name}(const void *vkey1, const void *vkey2) SEND_PACKET_START({self.type}); """, faddr, + after_header, log, report, pre, body, post, + before_return, f"""\ SEND_PACKET_END({self.type}); }} @@ -2199,15 +2370,16 @@ static bool cmp_{self.name}(const void *vkey1, const void *vkey2) #ifdef FREECIV_DELTA_PROTOCOL if (nullptr == *hash) {{ *hash = genhash_new_full(hash_{self.name}, cmp_{self.name}, - nullptr, nullptr, nullptr, free); + nullptr, nullptr, nullptr, destroy_{self.packet_name}); }} BV_CLR_ALL(fields); if (!genhash_lookup(*hash, real_packet, (void **) &old)) {{ old = fc_malloc(sizeof(*old)); + /* temporary bitcopy just to insert correctly */ *old = *real_packet; genhash_insert(*hash, old, old); - memset(old, 0, sizeof(*old)); + init_{self.packet_name}(old); """ if self.is_info != "no": intro += """\ @@ -2265,10 +2437,8 @@ if (e) { field.get_put_wrapper(self, i, True) for i, field in enumerate(self.other_fields) ) - body += """\ - -*old = *real_packet; -""" + body += "\n" + body += self.get_copy("old", "real_packet") # Cancel some is-info packets. for i in self.cancel: @@ -2361,6 +2531,7 @@ if (nullptr != *hash) {{ f"""\ {self.receive_prototype} {{ +#define FREE_PACKET_STRUCT(_packet) free_{self.packet_name}(_packet) """, delta_header, f"""\ @@ -2374,6 +2545,7 @@ if (nullptr != *hash) {{ post, """\ RECEIVE_PACKET_END(real_packet); +#undef FREE_PACKET_STRUCT } """, @@ -2383,48 +2555,27 @@ if (nullptr != *hash) {{ """Helper for get_receive(). Generate the part of the receive function responsible for recreating the full packet from the received delta and the last cached packet.""" - if self.key_fields: - # bit-copy the values, since we're moving (not cloning) - # the key fields - # FIXME: might not work for arrays - backup_key = "".join( - prefix(" ", field.get_declar()) - for field in self.key_fields - ) + "\n"+ "".join( - f"""\ - {field.name} = real_packet->{field.name}; -""" - for field in self.key_fields - ) + "\n" - restore_key = "\n" + "".join( - f"""\ - real_packet->{field.name} = {field.name}; -""" - for field in self.key_fields - ) - else: - backup_key = restore_key = "" if self.gen_log: fl = f"""\ {self.log_macro}(" no old info"); """ else: fl="" + + copy_from_old = prefix(" ", self.get_copy("real_packet", "old")) body = f"""\ #ifdef FREECIV_DELTA_PROTOCOL if (nullptr == *hash) {{ *hash = genhash_new_full(hash_{self.name}, cmp_{self.name}, - nullptr, nullptr, nullptr, free); + nullptr, nullptr, nullptr, destroy_{self.packet_name}); }} if (genhash_lookup(*hash, real_packet, (void **) &old)) {{ - *real_packet = *old; +{copy_from_old}\ }} else {{ -{backup_key}\ + /* packet is already initialized empty */ {fl}\ - memset(real_packet, 0, sizeof(*real_packet)); -{restore_key}\ }} """ @@ -2433,15 +2584,17 @@ if (genhash_lookup(*hash, real_packet, (void **) &old)) {{ for i, field in enumerate(self.other_fields) ) - extro = """\ + copy_to_old = prefix(" ", self.get_copy("old", "real_packet")) + extro = f"""\ -if (nullptr == old) { +if (nullptr == old) {{ old = fc_malloc(sizeof(*old)); - *old = *real_packet; + init_{self.packet_name}(old); +{copy_to_old}\ genhash_insert(*hash, old, old); -} else { - *old = *real_packet; -} +}} else {{ +{copy_to_old}\ +}} """ # Cancel some is-info packets. @@ -2726,6 +2879,11 @@ class Packet: """Set of all capabilities affecting this packet""" return {cap for field in self.fields for cap in field.all_caps} + @property + def complex(self) -> bool: + """Whether this packet's struct requires special handling for + initialization, copying, and destruction.""" + return any(field.complex for field in self.fields) def get_struct(self) -> str: """Generate the struct definition for this packet""" @@ -2785,6 +2943,58 @@ struct {self.name} {{ PacketsDefinition.code_delta_stats_reset""" return "\n".join(v.get_reset_part() for v in self.variants) + def get_init(self) -> str: + """Generate this packet's init function, which initializes the + packet struct so its complex-typed fields are useable, and sets + all fields to the empty default state used for computing deltas""" + if self.complex: + field_parts = "\n" + "".join( + prefix(" ", field.get_init()) + for field in self.fields + ) + else: + field_parts = "" + return f"""\ +static inline void init_{self.name}(struct {self.name} *packet) +{{ + memset(packet, 0, sizeof(*packet)); +{field_parts}\ +}} + +""" + + def get_free_destroy(self) -> str: + """Generate this packet's free and destroy functions, which free + memory associated with complex-typed fields of this packet, and + optionally the allocation of the packet itself (destroy).""" + if not self.complex: + return f"""\ +#define free_{self.name}(_packet) (void) 0 +#define destroy_{self.name} free + +""" + + # drop fields in reverse order, in case later fields depend on + # earlier fields (e.g. for actual array sizes) + field_parts = "".join( + prefix(" ", field.get_free()) + for field in reversed(self.fields) + ) + # NB: destroy_*() takes void* to avoid casts + return f"""\ +static inline void free_{self.name}(struct {self.name} *packet) +{{ +{field_parts}\ +}} + +static inline void destroy_{self.name}(void *packet) +{{ + free_{self.name}((struct {self.name} *) packet); + free(packet); +}} + +""" + def get_send(self) -> str: """Generate the implementation of the send function, which sends a given packet to a given connection.""" @@ -2851,6 +3061,7 @@ struct {self.name} {{ """Generate the implementation of the dsend function, which directly takes packet fields instead of a packet struct.""" if not self.want_dsend: return "" + # safety: fill just borrows the given values; no init/free necessary fill = "".join( prefix(" ", field.get_fill()) for field in self.fields @@ -2873,6 +3084,7 @@ struct {self.name} {{ See self.get_dsend() and self.get_lsend()""" if not (self.want_lsend and self.want_dsend): return "" + # safety: fill just borrows the given values; no init/free necessary fill = "".join( prefix(" ", field.get_fill()) for field in self.fields @@ -3380,6 +3592,33 @@ void packet_handlers_fill_capability(struct packet_handlers *phandlers, """ return intro + body + extro + @property + def code_packet_destroy(self) -> str: + """Code fragment implementing the packet_destroy() function""" + # NB: missing packet IDs are empty-initialized, i.e. set to nullptr by default + handlers = "".join( + f"""\ + [{packet.type}] = destroy_{packet.name}, +""" + for packet in self + ) + + return f"""\ + +void packet_destroy(void *packet, enum packet_type type) +{{ + static void (*const destroy_handlers[PACKET_LAST])(void *packet) = {{ +{handlers}\ + }}; + void (*handler)(void *packet) = (type < PACKET_LAST ? destroy_handlers[type] : nullptr); + + fc_assert_action_msg(handler != nullptr, handler = free, + "packet_destroy(): invalid packet type %d", type); + + handler(packet); +}} +""" + @property def code_enum_packet(self) -> str: """Code fragment declaring the packet_type enum""" @@ -3505,6 +3744,8 @@ static int stats_total_sent; # write hash, cmp, send, receive for p in packets: + output_c.write(p.get_init()) + output_c.write(p.get_free_destroy()) output_c.write(p.get_variants()) output_c.write(p.get_send()) output_c.write(p.get_lsend()) @@ -3513,6 +3754,7 @@ static int stats_total_sent; output_c.write(packets.code_packet_handlers_fill_initial) output_c.write(packets.code_packet_handlers_fill_capability) + output_c.write(packets.code_packet_destroy) def write_server_header(path: "str | Path | None", packets: PacketsDefinition): """Write contents for server/hand_gen.h to the given path""" diff --git a/common/networking/packets.h b/common/networking/packets.h index d756701cbd..17c5faac41 100644 --- a/common/networking/packets.h +++ b/common/networking/packets.h @@ -103,6 +103,7 @@ void packet_handlers_fill_capability(struct packet_handlers *phandlers, const char *capability); const char *packet_name(enum packet_type type); bool packet_has_game_info_flag(enum packet_type type); +void packet_destroy(void *packet, enum packet_type type); void packet_header_init(struct packet_header *packet_header); void post_send_packet_server_join_reply(struct connection *pconn, @@ -147,6 +148,7 @@ void packets_deinit(void); struct data_in din; \ struct packet_type packet_buf, *result = &packet_buf; \ \ + init_ ##packet_type (&packet_buf); \ dio_input_init(&din, pc->buffer->data, \ data_type_size(pc->packet_header.length)); \ { \ @@ -160,6 +162,7 @@ void packets_deinit(void); #define RECEIVE_PACKET_END(result) \ if (!packet_check(&din, pc)) { \ + FREE_PACKET_STRUCT(&packet_buf); \ return NULL; \ } \ remove_packet_from_buffer(pc->buffer); \ @@ -169,6 +172,7 @@ void packets_deinit(void); #define RECEIVE_PACKET_FIELD_ERROR(field, ...) \ log_packet("Error on field '" #field "'" __VA_ARGS__); \ + FREE_PACKET_STRUCT(&packet_buf); \ return NULL #endif /* FREECIV_JSON_PROTOCOL */ diff --git a/common/networking/packets_json.h b/common/networking/packets_json.h index c3fcd6ae39..1542b20475 100644 --- a/common/networking/packets_json.h +++ b/common/networking/packets_json.h @@ -70,6 +70,7 @@ void *get_packet_from_connection_json(struct connection *pc, #define RECEIVE_PACKET_START(packet_type, result) \ struct packet_type packet_buf, *result = &packet_buf; \ struct data_in din; \ + init_ ##packet_type (&packet_buf); \ if (!pc->json_mode) { \ dio_input_init(&din, pc->buffer->data, \ data_type_size(pc->packet_header.length)); \ @@ -91,6 +92,7 @@ void *get_packet_from_connection_json(struct connection *pc, return result; \ } else { \ if (!packet_check(&din, pc)) { \ + FREE_PACKET_STRUCT(&packet_buf); \ return NULL; \ } \ remove_packet_from_buffer(pc->buffer); \ @@ -101,6 +103,7 @@ void *get_packet_from_connection_json(struct connection *pc, #define RECEIVE_PACKET_FIELD_ERROR(field, ...) \ log_packet("Error on field '" #field "'" __VA_ARGS__); \ + FREE_PACKET_STRUCT(&packet_buf); \ return NULL; /* Utilities to exchange strings and string vectors. */ diff --git a/server/sernet.c b/server/sernet.c index 00dd808fed..a24f1df746 100644 --- a/server/sernet.c +++ b/server/sernet.c @@ -490,7 +490,7 @@ static void incoming_client_packets(struct connection *pconn) start_processing_request(pconn, pconn->server.last_request_id_seen); command_ok = server_packet_input(pconn, packet.data, packet.type); - free(packet.data); + packet_destroy(packet.data, packet.type); finish_processing_request(pconn); connection_do_unbuffer(pconn); -- 2.34.1