mirror of
https://github.com/esphome/esphome.git
synced 2025-07-28 14:16:40 +00:00
Use message_source_map consistently in proto generation (#9542)
This commit is contained in:
parent
0d422bd74f
commit
71cc298363
@ -1495,6 +1495,7 @@ def build_base_class(
|
||||
base_class_name: str,
|
||||
common_fields: list[descriptor.FieldDescriptorProto],
|
||||
messages: list[descriptor.DescriptorProto],
|
||||
message_source_map: dict[str, int],
|
||||
) -> tuple[str, str, str]:
|
||||
"""Build the base class definition and implementation."""
|
||||
public_content = []
|
||||
@ -1511,7 +1512,7 @@ def build_base_class(
|
||||
|
||||
# Determine if any message using this base class needs decoding
|
||||
needs_decode = any(
|
||||
get_opt(msg, pb.source, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
|
||||
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
|
||||
for msg in messages
|
||||
)
|
||||
|
||||
@ -1543,6 +1544,7 @@ def build_base_class(
|
||||
|
||||
def generate_base_classes(
|
||||
base_class_groups: dict[str, list[descriptor.DescriptorProto]],
|
||||
message_source_map: dict[str, int],
|
||||
) -> tuple[str, str, str]:
|
||||
"""Generate all base classes."""
|
||||
all_headers = []
|
||||
@ -1556,7 +1558,7 @@ def generate_base_classes(
|
||||
if common_fields:
|
||||
# Generate base class
|
||||
header, cpp, dump_cpp = build_base_class(
|
||||
base_class_name, common_fields, messages
|
||||
base_class_name, common_fields, messages, message_source_map
|
||||
)
|
||||
all_headers.append(header)
|
||||
all_cpp.append(cpp)
|
||||
@ -1567,6 +1569,7 @@ def generate_base_classes(
|
||||
|
||||
def build_service_message_type(
|
||||
mt: descriptor.DescriptorProto,
|
||||
message_source_map: dict[str, int],
|
||||
) -> tuple[str, str] | None:
|
||||
"""Builds the service message type."""
|
||||
snake = camel_to_snake(mt.name)
|
||||
@ -1574,7 +1577,7 @@ def build_service_message_type(
|
||||
if id_ is None:
|
||||
return None
|
||||
|
||||
source: int = get_opt(mt, pb.source, 0)
|
||||
source: int = message_source_map.get(mt.name, SOURCE_BOTH)
|
||||
|
||||
ifdef: str | None = get_opt(mt, pb.ifdef)
|
||||
log: bool = get_opt(mt, pb.log, True)
|
||||
@ -1714,7 +1717,9 @@ namespace api {
|
||||
|
||||
# Generate base classes
|
||||
if base_class_fields:
|
||||
base_headers, base_cpp, base_dump_cpp = generate_base_classes(base_class_groups)
|
||||
base_headers, base_cpp, base_dump_cpp = generate_base_classes(
|
||||
base_class_groups, message_source_map
|
||||
)
|
||||
content += base_headers
|
||||
cpp += base_cpp
|
||||
dump_cpp += base_dump_cpp
|
||||
@ -1832,7 +1837,7 @@ static const char *const TAG = "api.service";
|
||||
cpp += "#endif\n\n"
|
||||
|
||||
for mt in file.message_type:
|
||||
obj = build_service_message_type(mt)
|
||||
obj = build_service_message_type(mt, message_source_map)
|
||||
if obj is None:
|
||||
continue
|
||||
hout, cout = obj
|
||||
|
Loading…
x
Reference in New Issue
Block a user