mirror of
https://github.com/esphome/esphome.git
synced 2025-08-03 08:57:47 +00:00
preen
This commit is contained in:
parent
ae945c9a96
commit
dd35038771
@ -30,31 +30,32 @@ class WireType(IntEnum):
|
||||
FIXED32 = 5 # fixed32, sfixed32, float
|
||||
|
||||
|
||||
# Message type registry - maps message names to type IDs
|
||||
MESSAGE_TYPE_REGISTRY = {}
|
||||
NEXT_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing
|
||||
class MessageTypeRegistry:
|
||||
"""Manages message type ID assignments."""
|
||||
|
||||
# Repeated message type registry - separate from regular messages
|
||||
REPEATED_MESSAGE_TYPE_REGISTRY = {}
|
||||
NEXT_REPEATED_MESSAGE_TYPE_ID = 0 # Start at 0 for array indexing
|
||||
def __init__(self):
|
||||
self.message_registry = {}
|
||||
self.next_message_id = 0
|
||||
self.repeated_registry = {}
|
||||
self.next_repeated_id = 0
|
||||
|
||||
def get_message_type_id(self, message_name):
|
||||
"""Get or assign a type ID for a message type."""
|
||||
if message_name not in self.message_registry:
|
||||
self.message_registry[message_name] = self.next_message_id
|
||||
self.next_message_id += 1
|
||||
return self.message_registry[message_name]
|
||||
|
||||
def get_repeated_message_type_id(self, message_name):
|
||||
"""Get or assign a type ID for a repeated message type."""
|
||||
if message_name not in self.repeated_registry:
|
||||
self.repeated_registry[message_name] = self.next_repeated_id
|
||||
self.next_repeated_id += 1
|
||||
return self.repeated_registry[message_name]
|
||||
|
||||
|
||||
def get_message_type_id(message_name):
|
||||
"""Get or assign a type ID for a message type."""
|
||||
global NEXT_MESSAGE_TYPE_ID
|
||||
if message_name not in MESSAGE_TYPE_REGISTRY:
|
||||
MESSAGE_TYPE_REGISTRY[message_name] = NEXT_MESSAGE_TYPE_ID
|
||||
NEXT_MESSAGE_TYPE_ID += 1
|
||||
return MESSAGE_TYPE_REGISTRY[message_name]
|
||||
|
||||
|
||||
def get_repeated_message_type_id(message_name):
|
||||
"""Get or assign a type ID for a repeated message type."""
|
||||
global NEXT_REPEATED_MESSAGE_TYPE_ID
|
||||
if message_name not in REPEATED_MESSAGE_TYPE_REGISTRY:
|
||||
REPEATED_MESSAGE_TYPE_REGISTRY[message_name] = NEXT_REPEATED_MESSAGE_TYPE_ID
|
||||
NEXT_REPEATED_MESSAGE_TYPE_ID += 1
|
||||
return REPEATED_MESSAGE_TYPE_REGISTRY[message_name]
|
||||
# Create a global instance
|
||||
type_registry = MessageTypeRegistry()
|
||||
|
||||
|
||||
# Mapping from protobuf types to our ProtoFieldType enum
|
||||
@ -1345,7 +1346,9 @@ def build_message_type(
|
||||
|
||||
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
|
||||
# For messages, use offset_low and message_type_id with offset extension
|
||||
message_type_id = get_repeated_message_type_id(ti._ti.type_name)
|
||||
message_type_id = type_registry.get_repeated_message_type_id(
|
||||
ti._ti.type_name
|
||||
)
|
||||
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
|
||||
repeated_fields_v3.append(
|
||||
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>({message_type_id} | ((({offset} >> 8) & 0x0F) << 4))}}}}"
|
||||
@ -1367,7 +1370,9 @@ def build_message_type(
|
||||
|
||||
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
|
||||
# For messages, use offset_low and message_type_id
|
||||
message_type_id = get_message_type_id(ti.type_name)
|
||||
message_type_id = type_registry.get_message_type_id(
|
||||
ti.type_name
|
||||
)
|
||||
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
|
||||
|
||||
# Since we have so few message types, we can use the upper bits of
|
||||
@ -1393,7 +1398,7 @@ def build_message_type(
|
||||
field_tag_size = ti.calculate_field_id_size()
|
||||
# Messages are TYPE_MESSAGE (10)
|
||||
type_and_size = (10 & 0x1F) | ((field_tag_size - 1) << 5)
|
||||
message_type_id = get_message_type_id(ti.type_name)
|
||||
message_type_id = type_registry.get_message_type_id(ti.type_name)
|
||||
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
|
||||
# Same encoding as above for large offsets
|
||||
regular_fields_v3.append(
|
||||
@ -1877,9 +1882,9 @@ namespace api {
|
||||
|
||||
mt = file.message_type
|
||||
|
||||
# First pass: Register all message types to populate MESSAGE_TYPE_REGISTRY
|
||||
# First pass: Register all message types to populate the type registry
|
||||
for m in mt:
|
||||
get_message_type_id(m.name)
|
||||
type_registry.get_message_type_id(m.name)
|
||||
|
||||
# Collect messages by base class
|
||||
base_class_groups = collect_messages_by_base_class(mt)
|
||||
@ -1953,8 +1958,6 @@ namespace api {
|
||||
|
||||
for meta in response_metadata:
|
||||
class_name = meta["class_name"]
|
||||
regular_fields = meta["regular_fields"]
|
||||
repeated_fields = meta["repeated_fields"]
|
||||
regular_fields_v3 = meta.get("regular_fields_v3", [])
|
||||
repeated_fields_v3 = meta.get("repeated_fields_v3", [])
|
||||
msg_ifdef = meta["ifdef"]
|
||||
@ -2053,15 +2056,16 @@ namespace api {
|
||||
|
||||
# Sort message types by their assigned IDs to match the metadata
|
||||
sorted_message_types_by_id = sorted(
|
||||
MESSAGE_TYPE_REGISTRY.keys(), key=lambda x: MESSAGE_TYPE_REGISTRY[x]
|
||||
type_registry.message_registry.keys(),
|
||||
key=lambda x: type_registry.message_registry[x],
|
||||
)
|
||||
sorted_repeated_message_types_by_id = sorted(
|
||||
REPEATED_MESSAGE_TYPE_REGISTRY.keys(),
|
||||
key=lambda x: REPEATED_MESSAGE_TYPE_REGISTRY[x],
|
||||
type_registry.repeated_registry.keys(),
|
||||
key=lambda x: type_registry.repeated_registry[x],
|
||||
)
|
||||
|
||||
# Generate MESSAGE_HANDLERS array with proper ifdefs
|
||||
cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(MESSAGE_TYPE_REGISTRY) or 1}] = {{\n"
|
||||
cpp += f"const MessageHandler MESSAGE_HANDLERS[{len(type_registry.message_registry) or 1}] = {{\n"
|
||||
|
||||
# Generate entries in ID order
|
||||
for msg_type in sorted_message_types_by_id:
|
||||
@ -2089,10 +2093,12 @@ namespace api {
|
||||
if cpp.endswith(",\n"):
|
||||
cpp = cpp[:-2] + "\n"
|
||||
cpp += "};\n"
|
||||
cpp += f"const size_t MESSAGE_HANDLER_COUNT = {len(MESSAGE_TYPE_REGISTRY)};\n"
|
||||
cpp += (
|
||||
f"const size_t MESSAGE_HANDLER_COUNT = {len(type_registry.message_registry)};\n"
|
||||
)
|
||||
|
||||
# Generate REPEATED_MESSAGE_HANDLERS array with same approach
|
||||
cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(REPEATED_MESSAGE_TYPE_REGISTRY) or 1}] = {{\n"
|
||||
cpp += f"\nconst RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[{len(type_registry.repeated_registry) or 1}] = {{\n"
|
||||
|
||||
# Generate entries in ID order
|
||||
for msg_type in sorted_repeated_message_types_by_id:
|
||||
@ -2120,7 +2126,7 @@ namespace api {
|
||||
if cpp.endswith(",\n"):
|
||||
cpp = cpp[:-2] + "\n"
|
||||
cpp += "};\n"
|
||||
cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(REPEATED_MESSAGE_TYPE_REGISTRY)};\n"
|
||||
cpp += f"const size_t REPEATED_MESSAGE_HANDLER_COUNT = {len(type_registry.repeated_registry)};\n"
|
||||
|
||||
cpp += """\
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user