Use message_source_map consistently in proto generation (#9542)

This commit is contained in:
J. Nick Koston 2025-07-18 02:28:08 -10:00 committed by GitHub
parent 0d422bd74f
commit 71cc298363
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1495,6 +1495,7 @@ def build_base_class(
base_class_name: str,
common_fields: list[descriptor.FieldDescriptorProto],
messages: list[descriptor.DescriptorProto],
message_source_map: dict[str, int],
) -> tuple[str, str, str]:
"""Build the base class definition and implementation."""
public_content = []
@ -1511,7 +1512,7 @@ def build_base_class(
# Determine if any message using this base class needs decoding
needs_decode = any(
get_opt(msg, pb.source, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages
)
@ -1543,6 +1544,7 @@ def build_base_class(
def generate_base_classes(
base_class_groups: dict[str, list[descriptor.DescriptorProto]],
message_source_map: dict[str, int],
) -> tuple[str, str, str]:
"""Generate all base classes."""
all_headers = []
@ -1556,7 +1558,7 @@ def generate_base_classes(
if common_fields:
# Generate base class
header, cpp, dump_cpp = build_base_class(
base_class_name, common_fields, messages
base_class_name, common_fields, messages, message_source_map
)
all_headers.append(header)
all_cpp.append(cpp)
@ -1567,6 +1569,7 @@ def generate_base_classes(
def build_service_message_type(
mt: descriptor.DescriptorProto,
message_source_map: dict[str, int],
) -> tuple[str, str] | None:
"""Builds the service message type."""
snake = camel_to_snake(mt.name)
@ -1574,7 +1577,7 @@ def build_service_message_type(
if id_ is None:
return None
source: int = get_opt(mt, pb.source, 0)
source: int = message_source_map.get(mt.name, SOURCE_BOTH)
ifdef: str | None = get_opt(mt, pb.ifdef)
log: bool = get_opt(mt, pb.log, True)
@ -1714,7 +1717,9 @@ namespace api {
# Generate base classes
if base_class_fields:
base_headers, base_cpp, base_dump_cpp = generate_base_classes(base_class_groups)
base_headers, base_cpp, base_dump_cpp = generate_base_classes(
base_class_groups, message_source_map
)
content += base_headers
cpp += base_cpp
dump_cpp += base_dump_cpp
@ -1832,7 +1837,7 @@ static const char *const TAG = "api.service";
cpp += "#endif\n\n"
for mt in file.message_type:
obj = build_service_message_type(mt)
obj = build_service_message_type(mt, message_source_map)
if obj is None:
continue
hout, cout = obj