This commit is contained in:
J. Nick Koston 2025-07-10 11:58:47 -10:00
parent ae945c9a96
commit dd35038771
No known key found for this signature in database

View File

@ -30,31 +30,32 @@ class WireType(IntEnum):
FIXED32 = 5 # fixed32, sfixed32, float
# Message type registry - maps message names to type IDs.
# Populated lazily by get_message_type_id(): a name gets an entry the first
# time it is requested.
MESSAGE_TYPE_REGISTRY = {}
# Next ID that get_message_type_id() will hand out to an unseen message name.
NEXT_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing
class MessageTypeRegistry:
"""Manages message type ID assignments."""
# Repeated message type registry - separate from regular messages.
# Maps message names to IDs assigned by get_repeated_message_type_id(); the
# numbering is independent of MESSAGE_TYPE_REGISTRY.
REPEATED_MESSAGE_TYPE_REGISTRY = {}
# Next ID that get_repeated_message_type_id() will hand out to an unseen name.
NEXT_REPEATED_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing
def __init__(self):
self.message_registry = {}
self.next_message_id = 0
self.repeated_registry = {}
self.next_repeated_id = 0
def get_message_type_id(self, message_name):
"""Get or assign a type ID for a message type."""
if message_name not in self.message_registry:
self.message_registry[message_name] = self.next_message_id
self.next_message_id += 1
return self.message_registry[message_name]
def get_repeated_message_type_id(self, message_name):
"""Get or assign a type ID for a repeated message type."""
if message_name not in self.repeated_registry:
self.repeated_registry[message_name] = self.next_repeated_id
self.next_repeated_id += 1
return self.repeated_registry[message_name]
def get_message_type_id(message_name):
    """Return the type ID for ``message_name``, assigning one on first use.

    IDs are stored in the module-level ``MESSAGE_TYPE_REGISTRY``. An unseen
    name is given the current ``NEXT_MESSAGE_TYPE_ID``, and that counter is
    then incremented.
    """
    global NEXT_MESSAGE_TYPE_ID
    try:
        # Fast path: the name was registered by an earlier call.
        return MESSAGE_TYPE_REGISTRY[message_name]
    except KeyError:
        assigned = NEXT_MESSAGE_TYPE_ID
        MESSAGE_TYPE_REGISTRY[message_name] = assigned
        NEXT_MESSAGE_TYPE_ID = assigned + 1
        return assigned
def get_repeated_message_type_id(message_name):
    """Return the repeated-message type ID for ``message_name``.

    IDs are stored in the module-level ``REPEATED_MESSAGE_TYPE_REGISTRY``. An
    unseen name is given the current ``NEXT_REPEATED_MESSAGE_TYPE_ID``, and
    that counter is then incremented.
    """
    global NEXT_REPEATED_MESSAGE_TYPE_ID
    # Already registered: hand back the ID assigned earlier.
    if message_name in REPEATED_MESSAGE_TYPE_REGISTRY:
        return REPEATED_MESSAGE_TYPE_REGISTRY[message_name]
    # First sighting: take the next free ID and advance the counter.
    fresh_id = NEXT_REPEATED_MESSAGE_TYPE_ID
    REPEATED_MESSAGE_TYPE_REGISTRY[message_name] = fresh_id
    NEXT_REPEATED_MESSAGE_TYPE_ID = fresh_id + 1
    return fresh_id
# Create a global instance.
# Code elsewhere in this file requests IDs through it (for example
# type_registry.get_message_type_id(...)) and reads its message_registry /
# repeated_registry dicts when emitting the handler arrays.
type_registry = MessageTypeRegistry()
# Mapping from protobuf types to our ProtoFieldType enum
@ -1345,7 +1346,9 @@ def build_message_type(
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id with offset extension
message_type_id = get_repeated_message_type_id(ti._ti.type_name)
message_type_id = type_registry.get_repeated_message_type_id(
ti._ti.type_name
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
repeated_fields_v3.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>({message_type_id} | ((({offset} >> 8) & 0x0F) << 4))}}}}"
@ -1367,7 +1370,9 @@ def build_message_type(
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id
message_type_id = get_message_type_id(ti.type_name)
message_type_id = type_registry.get_message_type_id(
ti.type_name
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Since we have so few message types, we can use the upper bits of
@ -1393,7 +1398,7 @@ def build_message_type(
field_tag_size = ti.calculate_field_id_size()
# Messages are TYPE_MESSAGE (10)
type_and_size = (10 & 0x1F) | ((field_tag_size - 1) << 5)
message_type_id = get_message_type_id(ti.type_name)
message_type_id = type_registry.get_message_type_id(ti.type_name)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Same encoding as above for large offsets
regular_fields_v3.append(
@ -1877,9 +1882,9 @@ namespace api {
mt = file.message_type
# First pass: Register all message types to populate MESSAGE_TYPE_REGISTRY
# First pass: Register all message types to populate the type registry
for m in mt:
get_message_type_id(m.name)
type_registry.get_message_type_id(m.name)
# Collect messages by base class
base_class_groups = collect_messages_by_base_class(mt)
@ -1953,8 +1958,6 @@ namespace api {
for meta in response_metadata:
class_name = meta["class_name"]
regular_fields = meta["regular_fields"]
repeated_fields = meta["repeated_fields"]
regular_fields_v3 = meta.get("regular_fields_v3", [])
repeated_fields_v3 = meta.get("repeated_fields_v3", [])
msg_ifdef = meta["ifdef"]
@ -2053,15 +2056,16 @@ namespace api {
# Sort message types by their assigned IDs to match the metadata
sorted_message_types_by_id = sorted(
MESSAGE_TYPE_REGISTRY.keys(), key=lambda x: MESSAGE_TYPE_REGISTRY[x]
type_registry.message_registry.keys(),
key=lambda x: type_registry.message_registry[x],
)
sorted_repeated_message_types_by_id = sorted(
REPEATED_MESSAGE_TYPE_REGISTRY.keys(),
key=lambda x: REPEATED_MESSAGE_TYPE_REGISTRY[x],
type_registry.repeated_registry.keys(),
key=lambda x: type_registry.repeated_registry[x],
)
# Generate MESSAGE_HANDLERS array with proper ifdefs
cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(MESSAGE_TYPE_REGISTRY) or 1}] = {{\n"
cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(type_registry.message_registry) or 1}] = {{\n"
# Generate entries in ID order
for msg_type in sorted_message_types_by_id:
@ -2089,10 +2093,12 @@ namespace api {
if cpp.endswith(",\n"):
cpp = cpp[:-2] + "\n"
cpp += "};\n"
cpp += f"const size_t MESSAGE_HANDLER_COUNT = {len(MESSAGE_TYPE_REGISTRY)};\n"
cpp += (
f"const size_t MESSAGE_HANDLER_COUNT = {len(type_registry.message_registry)};\n"
)
# Generate REPEATED_MESSAGE_HANDLERS array with same approach
cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(REPEATED_MESSAGE_TYPE_REGISTRY) or 1}] = {{\n"
cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(type_registry.repeated_registry) or 1}] = {{\n"
# Generate entries in ID order
for msg_type in sorted_repeated_message_types_by_id:
@ -2120,7 +2126,7 @@ namespace api {
if cpp.endswith(",\n"):
cpp = cpp[:-2] + "\n"
cpp += "};\n"
cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(REPEATED_MESSAGE_TYPE_REGISTRY)};\n"
cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(type_registry.repeated_registry)};\n"
cpp += """\