Use message_source_map consistently in proto generation (#9542)

This commit is contained in:
J. Nick Koston 2025-07-18 02:28:08 -10:00 committed by GitHub
parent 0d422bd74f
commit 71cc298363
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1495,6 +1495,7 @@ def build_base_class(
base_class_name: str, base_class_name: str,
common_fields: list[descriptor.FieldDescriptorProto], common_fields: list[descriptor.FieldDescriptorProto],
messages: list[descriptor.DescriptorProto], messages: list[descriptor.DescriptorProto],
message_source_map: dict[str, int],
) -> tuple[str, str, str]: ) -> tuple[str, str, str]:
"""Build the base class definition and implementation.""" """Build the base class definition and implementation."""
public_content = [] public_content = []
@ -1511,7 +1512,7 @@ def build_base_class(
# Determine if any message using this base class needs decoding # Determine if any message using this base class needs decoding
needs_decode = any( needs_decode = any(
get_opt(msg, pb.source, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages for msg in messages
) )
@ -1543,6 +1544,7 @@ def build_base_class(
def generate_base_classes( def generate_base_classes(
base_class_groups: dict[str, list[descriptor.DescriptorProto]], base_class_groups: dict[str, list[descriptor.DescriptorProto]],
message_source_map: dict[str, int],
) -> tuple[str, str, str]: ) -> tuple[str, str, str]:
"""Generate all base classes.""" """Generate all base classes."""
all_headers = [] all_headers = []
@ -1556,7 +1558,7 @@ def generate_base_classes(
if common_fields: if common_fields:
# Generate base class # Generate base class
header, cpp, dump_cpp = build_base_class( header, cpp, dump_cpp = build_base_class(
base_class_name, common_fields, messages base_class_name, common_fields, messages, message_source_map
) )
all_headers.append(header) all_headers.append(header)
all_cpp.append(cpp) all_cpp.append(cpp)
@ -1567,6 +1569,7 @@ def generate_base_classes(
def build_service_message_type( def build_service_message_type(
mt: descriptor.DescriptorProto, mt: descriptor.DescriptorProto,
message_source_map: dict[str, int],
) -> tuple[str, str] | None: ) -> tuple[str, str] | None:
"""Builds the service message type.""" """Builds the service message type."""
snake = camel_to_snake(mt.name) snake = camel_to_snake(mt.name)
@ -1574,7 +1577,7 @@ def build_service_message_type(
if id_ is None: if id_ is None:
return None return None
source: int = get_opt(mt, pb.source, 0) source: int = message_source_map.get(mt.name, SOURCE_BOTH)
ifdef: str | None = get_opt(mt, pb.ifdef) ifdef: str | None = get_opt(mt, pb.ifdef)
log: bool = get_opt(mt, pb.log, True) log: bool = get_opt(mt, pb.log, True)
@ -1714,7 +1717,9 @@ namespace api {
# Generate base classes # Generate base classes
if base_class_fields: if base_class_fields:
base_headers, base_cpp, base_dump_cpp = generate_base_classes(base_class_groups) base_headers, base_cpp, base_dump_cpp = generate_base_classes(
base_class_groups, message_source_map
)
content += base_headers content += base_headers
cpp += base_cpp cpp += base_cpp
dump_cpp += base_dump_cpp dump_cpp += base_dump_cpp
@ -1832,7 +1837,7 @@ static const char *const TAG = "api.service";
cpp += "#endif\n\n" cpp += "#endif\n\n"
for mt in file.message_type: for mt in file.message_type:
obj = build_service_message_type(mt) obj = build_service_message_type(mt, message_source_map)
if obj is None: if obj is None:
continue continue
hout, cout = obj hout, cout = obj