From dd35038771bd1947b9e1c85a7c45ee697cc68dc1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 10 Jul 2025 11:58:47 -1000 Subject: [PATCH] preen --- script/api_protobuf/api_protobuf.py | 78 ++++++++++++++++------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index b44259237c..781c8d4521 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -30,31 +30,32 @@ class WireType(IntEnum): FIXED32 = 5 # fixed32, sfixed32, float -# Message type registry - maps message names to type IDs -MESSAGE_TYPE_REGISTRY = {} -NEXT_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing +class MessageTypeRegistry: + """Manages message type ID assignments.""" -# Repeated message type registry - separate from regular messages -REPEATED_MESSAGE_TYPE_REGISTRY = {} -NEXT_REPEATED_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing + def __init__(self): + self.message_registry = {} + self.next_message_id = 0 + self.repeated_registry = {} + self.next_repeated_id = 0 + + def get_message_type_id(self, message_name): + """Get or assign a type ID for a message type.""" + if message_name not in self.message_registry: + self.message_registry[message_name] = self.next_message_id + self.next_message_id += 1 + return self.message_registry[message_name] + + def get_repeated_message_type_id(self, message_name): + """Get or assign a type ID for a repeated message type.""" + if message_name not in self.repeated_registry: + self.repeated_registry[message_name] = self.next_repeated_id + self.next_repeated_id += 1 + return self.repeated_registry[message_name] -def get_message_type_id(message_name): - """Get or assign a type ID for a message type.""" - global NEXT_MESSAGE_TYPE_ID - if message_name not in MESSAGE_TYPE_REGISTRY: - MESSAGE_TYPE_REGISTRY[message_name] = NEXT_MESSAGE_TYPE_ID - NEXT_MESSAGE_TYPE_ID += 1 - return MESSAGE_TYPE_REGISTRY[message_name] - - -def get_repeated_message_type_id(message_name): - """Get or assign a type ID for a repeated message type.""" - global NEXT_REPEATED_MESSAGE_TYPE_ID - if message_name not in REPEATED_MESSAGE_TYPE_REGISTRY: - REPEATED_MESSAGE_TYPE_REGISTRY[message_name] = NEXT_REPEATED_MESSAGE_TYPE_ID - NEXT_REPEATED_MESSAGE_TYPE_ID += 1 - return REPEATED_MESSAGE_TYPE_REGISTRY[message_name] +# Create a global instance +type_registry = MessageTypeRegistry() # Mapping from protobuf types to our ProtoFieldType enum @@ -1345,7 +1346,9 @@ def build_message_type( if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE: # For messages, use offset_low and message_type_id with offset extension - message_type_id = get_repeated_message_type_id(ti._ti.type_name) + message_type_id = type_registry.get_repeated_message_type_id( + ti._ti.type_name + ) offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})" repeated_fields_v3.append( f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast({offset} & 0xFF), .message_type_id = static_cast({message_type_id} | ((({offset} >> 8) & 0x0F) << 4))}}}}" @@ -1367,7 +1370,9 @@ def build_message_type( if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE: # For messages, use offset_low and message_type_id - message_type_id = get_message_type_id(ti.type_name) + message_type_id = type_registry.get_message_type_id( + ti.type_name + ) offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})" # Since we have so few message types, we can use the upper bits of @@ -1393,7 +1398,7 @@ def build_message_type( field_tag_size = ti.calculate_field_id_size() # Messages are TYPE_MESSAGE (10) type_and_size = (10 & 0x1F) | ((field_tag_size - 1) << 5) - message_type_id = get_message_type_id(ti.type_name) + message_type_id = type_registry.get_message_type_id(ti.type_name) offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})" # Same encoding as above for large offsets regular_fields_v3.append( @@ -1877,9 +1882,9 @@ namespace api { mt = file.message_type - # First pass: Register all message types to populate MESSAGE_TYPE_REGISTRY + # First pass: Register all message types to populate the type registry for m in mt: - get_message_type_id(m.name) + type_registry.get_message_type_id(m.name) # Collect messages by base class base_class_groups = collect_messages_by_base_class(mt) @@ -1953,8 +1958,6 @@ namespace api { for meta in response_metadata: class_name = meta["class_name"] - regular_fields = meta["regular_fields"] - repeated_fields = meta["repeated_fields"] regular_fields_v3 = meta.get("regular_fields_v3", []) repeated_fields_v3 = meta.get("repeated_fields_v3", []) msg_ifdef = meta["ifdef"] @@ -2053,15 +2056,16 @@ namespace api { # Sort message types by their assigned IDs to match the metadata sorted_message_types_by_id = sorted( - MESSAGE_TYPE_REGISTRY.keys(), key=lambda x: MESSAGE_TYPE_REGISTRY[x] + type_registry.message_registry.keys(), + key=lambda x: type_registry.message_registry[x], ) sorted_repeated_message_types_by_id = sorted( - REPEATED_MESSAGE_TYPE_REGISTRY.keys(), - key=lambda x: REPEATED_MESSAGE_TYPE_REGISTRY[x], + type_registry.repeated_registry.keys(), + key=lambda x: type_registry.repeated_registry[x], ) # Generate MESSAGE_HANDLERS array with proper ifdefs - cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(MESSAGE_TYPE_REGISTRY) or 1}] = {{\n" + cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(type_registry.message_registry) or 1}] = {{\n" # Generate entries in ID order for msg_type in sorted_message_types_by_id: @@ -2089,10 +2093,12 @@ namespace api { if cpp.endswith(",\n"): cpp = cpp[:-2] + "\n" cpp += "};\n" - cpp += f"const size_t MESSAGE_HANDLER_COUNT = {len(MESSAGE_TYPE_REGISTRY)};\n" + cpp += ( + f"const size_t MESSAGE_HANDLER_COUNT = {len(type_registry.message_registry)};\n" + ) # Generate REPEATED_MESSAGE_HANDLERS array with same approach - cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(REPEATED_MESSAGE_TYPE_REGISTRY) or 1}] = {{\n" + cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(type_registry.repeated_registry) or 1}] = {{\n" # Generate entries in ID order for msg_type in sorted_repeated_message_types_by_id: @@ -2120,7 +2126,7 @@ namespace api { if cpp.endswith(",\n"): cpp = cpp[:-2] + "\n" cpp += "};\n" - cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(REPEATED_MESSAGE_TYPE_REGISTRY)};\n" + cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(type_registry.repeated_registry)};\n" cpp += """\