This commit is contained in:
J. Nick Koston 2025-07-10 20:25:02 -10:00
parent 0d5b353cdf
commit 5630720715
No known key found for this signature in database

View File

@ -269,21 +269,6 @@ class TypeInfo(ABC):
decode_32bit = None
@property
def decode_64bit_content(self) -> str:
content = self.decode_64bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
decode_64bit = None
@property
def encode_content(self) -> str:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
@ -353,28 +338,6 @@ def register_type(name: int):
return func
@register_type(1)
class DoubleType(TypeInfo):
cpp_type = "double"
default_value = "0.0"
decode_64bit = "value.as_double()"
encode_func = "encode_double"
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%g", {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0.0, {force_str(force)});"
return o
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes for double
@register_type(2)
class FloatType(TypeInfo):
cpp_type = "float"
@ -463,28 +426,6 @@ class Int32Type(TypeInfo):
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(6)
class Fixed64Type(TypeInfo):
cpp_type = "uint64_t"
default_value = "0"
decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%llu", {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes fixed
@register_type(7)
class Fixed32Type(TypeInfo):
cpp_type = "uint32_t"
@ -696,28 +637,6 @@ class SFixed32Type(TypeInfo):
return self.calculate_field_id_size() + 4 # field ID + 4 bytes fixed
@register_type(16)
class SFixed64Type(TypeInfo):
cpp_type = "int64_t"
default_value = "0"
decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%lld", {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes fixed
@register_type(17)
class SInt32Type(TypeInfo):
cpp_type = "int32_t"
@ -826,19 +745,6 @@ class RepeatedTypeInfo(TypeInfo):
}}"""
)
@property
def decode_64bit_content(self) -> str:
content = self._ti.decode_64bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
)
@property
def _ti_is_bool(self) -> bool:
# std::vector is specialized for bool, reference does not work
@ -1257,7 +1163,6 @@ def build_message_type(
decode_varint: list[str] = []
decode_length: list[str] = []
decode_32bit: list[str] = []
decode_64bit: list[str] = []
encode: list[str] = []
dump: list[str] = []
size_calc: list[str] = []
@ -1326,240 +1231,147 @@ def build_message_type(
decode_length.append(ti.decode_length_content)
if ti.decode_32bit_content:
decode_32bit.append(ti.decode_32bit_content)
if ti.decode_64bit_content:
decode_64bit.append(ti.decode_64bit_content)
if ti.dump_content:
dump.append(ti.dump_content)
# Use metadata approach for all message classes
use_metadata = True # Apply to all messages
metadata_info = None
cpp = ""
# Generate metadata arrays for all classes using metadata approach
# Generate metadata arrays for all messages
regular_fields = []
repeated_fields = []
metadata_info = None
if use_metadata:
# Generate metadata
for field in desc.field:
if field.label == 3: # Repeated field
ti = RepeatedTypeInfo(field)
field_type = PROTO_TYPE_MAP.get(field.type, None)
if field_type:
field_tag_size = ti.calculate_field_id_size()
# Pack type and size into type_and_size byte
type_num = PROTO_TYPE_NUM_MAP.get(field.type, 0)
type_and_size = (type_num & 0x1F) | ((field_tag_size - 1) << 5)
# Generate metadata
for field in desc.field:
if field.label == 3: # Repeated field
ti = RepeatedTypeInfo(field)
field_type = PROTO_TYPE_MAP.get(field.type, None)
if field_type:
field_tag_size = ti.calculate_field_id_size()
# Pack type and size into type_and_size byte
type_num = PROTO_TYPE_NUM_MAP.get(field.type, 0)
type_and_size = (type_num & 0x1F) | ((field_tag_size - 1) << 5)
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id with offset extension
message_type_id = type_registry.get_repeated_message_type_id(
ti._ti.type_name
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Bits 0-1: bits 8-9 of offset (extends offset to 10 bits = 1023)
# Bits 2-7: actual message type ID (supports 64 types)
repeated_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>((({offset} >> 8) & 0x03) | ({message_type_id} << 2))}}}}"
)
else:
# Non-message types use full offset
repeated_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset = PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})}}}}"
)
else:
ti = TYPE_INFO[field.type](field)
field_type = PROTO_TYPE_MAP.get(field.type, None)
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id with offset extension
message_type_id = type_registry.get_repeated_message_type_id(
ti._ti.type_name
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Bits 0-1: bits 8-9 of offset (extends offset to 10 bits = 1023)
# Bits 2-7: actual message type ID (supports 64 types)
repeated_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>((({offset} >> 8) & 0x03) | ({message_type_id} << 2))}}}}"
)
else:
# Non-message types use full offset
repeated_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset = PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})}}}}"
)
else:
ti = TYPE_INFO[field.type](field)
field_type = PROTO_TYPE_MAP.get(field.type, None)
if field_type:
field_tag_size = ti.calculate_field_id_size()
# Pack type and size into type_and_size byte
type_num = PROTO_TYPE_NUM_MAP.get(field.type, 0)
type_and_size = (type_num & 0x1F) | ((field_tag_size - 1) << 5)
if field_type:
field_tag_size = ti.calculate_field_id_size()
# Pack type and size into type_and_size byte
type_num = PROTO_TYPE_NUM_MAP.get(field.type, 0)
type_and_size = (type_num & 0x1F) | ((field_tag_size - 1) << 5)
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id
message_type_id = type_registry.get_message_type_id(
ti.type_name
if field.type == descriptor.FieldDescriptorProto.TYPE_MESSAGE:
# For messages, use offset_low and message_type_id
message_type_id = type_registry.get_message_type_id(ti.type_name)
# Validate message type ID fits in 6 bits (0-63)
if message_type_id > 63:
raise ValueError(
f"Message field '{field.name}' in '{desc.name}' references message type "
f"'{ti.type_name}' with type ID {message_type_id}, which exceeds the "
f"maximum of 63 supported by FieldMeta."
)
# Validate message type ID fits in 6 bits (0-63)
if message_type_id > 63:
raise ValueError(
f"Message field '{field.name}' in '{desc.name}' references message type "
f"'{ti.type_name}' with type ID {message_type_id}, which exceeds the "
f"maximum of 63 supported by FieldMeta."
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Since we have so few message types, we can use the upper bits of
# message_type_id to store the actual type ID
# Bits 0-1: bits 8-9 of offset (extends offset to 10 bits = 1023)
# Bits 2-7: actual message type ID (supports 64 types)
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>((({offset} >> 8) & 0x03) | ({message_type_id} << 2))}}}}"
)
else:
# Non-message types use full offset
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset = PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})}}}}"
)
elif isinstance(ti, EnumType):
field_tag_size = ti.calculate_field_id_size()
# Enums are TYPE_ENUM (7)
type_and_size = (7 & 0x1F) | ((field_tag_size - 1) << 5)
# Since we have so few message types, we can use the upper bits of
# message_type_id to store the actual type ID
# Bits 0-1: bits 8-9 of offset (extends offset to 10 bits = 1023)
# Bits 2-7: actual message type ID (supports 64 types)
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>((({offset} >> 8) & 0x03) | ({message_type_id} << 2))}}}}"
)
else:
# Non-message types use full offset
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset = PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})}}}}"
)
elif isinstance(ti, MessageType):
field_tag_size = ti.calculate_field_id_size()
# Messages are TYPE_MESSAGE (10)
type_and_size = (10 & 0x1F) | ((field_tag_size - 1) << 5)
message_type_id = type_registry.get_message_type_id(ti.type_name)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Same encoding as above for large offsets
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>({message_type_id} | ((({offset} >> 8) & 0x0F) << 4))}}}}"
)
elif isinstance(ti, EnumType):
field_tag_size = ti.calculate_field_id_size()
# Enums are TYPE_ENUM (7)
type_and_size = (7 & 0x1F) | ((field_tag_size - 1) << 5)
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset = PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})}}}}"
)
elif isinstance(ti, MessageType):
field_tag_size = ti.calculate_field_id_size()
# Messages are TYPE_MESSAGE (10)
type_and_size = (10 & 0x1F) | ((field_tag_size - 1) << 5)
message_type_id = type_registry.get_message_type_id(ti.type_name)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Same encoding as above for large offsets
regular_fields.append(
f"{{{field.number}, {type_and_size}, {{.offset_low = static_cast<uint8_t>({offset} & 0xFF), .message_type_id = static_cast<uint8_t>({message_type_id} | ((({offset} >> 8) & 0x0F) << 4))}}}}"
)
# Store metadata info for later generation outside the class
metadata_info = {
"regular_fields": regular_fields,
"repeated_fields": repeated_fields,
"class_name": desc.name,
}
# Store metadata info for later generation outside the class
metadata_info = {
"regular_fields": regular_fields,
"repeated_fields": repeated_fields,
"class_name": desc.name,
}
# Only generate decode methods for classes not using metadata approach
if not use_metadata:
if decode_varint:
decode_varint.append("default:\n return false;")
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot)
if decode_length:
decode_length.append("default:\n return false;")
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
if decode_32bit:
decode_32bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
if decode_64bit:
decode_64bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
# Add metadata declarations
if regular_fields:
public_content.append(f"static const FieldMeta FIELDS[{len(regular_fields)}];")
public_content.append(
f"static constexpr size_t FIELD_COUNT = {len(regular_fields)};"
)
else:
# For classes using metadata approach, no need to generate decode methods
# They're implemented in the base class ProtoMetadataMessage
pass
public_content.append("static constexpr size_t FIELD_COUNT = 0;")
# Metadata arrays for classes using metadata are already generated above
if use_metadata:
# Add metadata declarations
if regular_fields:
public_content.append(
f"static const FieldMeta FIELDS[{len(regular_fields)}];"
)
public_content.append(
f"static constexpr size_t FIELD_COUNT = {len(regular_fields)};"
)
else:
public_content.append("static constexpr size_t FIELD_COUNT = 0;")
if repeated_fields:
public_content.append(
f"static const RepeatedFieldMeta REPEATED_FIELDS[{len(repeated_fields)}];"
)
public_content.append(
f"static constexpr size_t REPEATED_COUNT = {len(repeated_fields)};"
)
else:
public_content.append("static constexpr size_t REPEATED_COUNT = 0;")
# Add virtual getter methods
public_content.append("// Metadata getters")
if regular_fields:
public_content.append(
"const FieldMeta *get_field_metadata() const override { return FIELDS; }"
)
else:
public_content.append(
"const FieldMeta *get_field_metadata() const override { return nullptr; }"
)
if repeated_fields:
public_content.append(
"size_t get_field_count() const override { return FIELD_COUNT; }"
f"static const RepeatedFieldMeta REPEATED_FIELDS[{len(repeated_fields)}];"
)
if repeated_fields:
public_content.append(
"const RepeatedFieldMeta *get_repeated_field_metadata() const override { return REPEATED_FIELDS; }"
)
else:
public_content.append(
"const RepeatedFieldMeta *get_repeated_field_metadata() const override { return nullptr; }"
)
public_content.append(
"size_t get_repeated_field_count() const override { return REPEATED_COUNT; }"
f"static constexpr size_t REPEATED_COUNT = {len(repeated_fields)};"
)
else:
public_content.append("static constexpr size_t REPEATED_COUNT = 0;")
# Only generate encode method if there are fields to encode
if encode and not use_metadata:
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} "
else:
o += "\n"
o += indent("\n".join(encode)) + "\n"
o += "}\n"
cpp += o
prot = "void encode(ProtoWriteBuffer buffer) const override;"
public_content.append(prot)
# If no fields to encode, the default implementation in ProtoMessage will be used
# For metadata classes, encode is implemented in base class ProtoMetadataMessage
# Add virtual getter methods
public_content.append("// Metadata getters")
if regular_fields:
public_content.append(
"const FieldMeta *get_field_metadata() const override { return FIELDS; }"
)
else:
public_content.append(
"const FieldMeta *get_field_metadata() const override { return nullptr; }"
)
public_content.append(
"size_t get_field_count() const override { return FIELD_COUNT; }"
)
# Add calculate_size method only if there are fields
if size_calc and not use_metadata:
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
# For a single field, just inline it for simplicity
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
o += f" {size_calc[0]} "
else:
# For multiple fields
o += "\n"
o += indent("\n".join(size_calc)) + "\n"
o += "}\n"
cpp += o
prot = "void calculate_size(uint32_t &total_size) const override;"
public_content.append(prot)
# If no fields to calculate size for, the default implementation in ProtoMessage will be used
# For metadata classes, calculate_size is implemented in base class ProtoMetadataMessage
if repeated_fields:
public_content.append(
"const RepeatedFieldMeta *get_repeated_field_metadata() const override { return REPEATED_FIELDS; }"
)
else:
public_content.append(
"const RepeatedFieldMeta *get_repeated_field_metadata() const override { return nullptr; }"
)
public_content.append(
"size_t get_repeated_field_count() const override { return REPEATED_COUNT; }"
)
# dump_to method declaration in header
prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
@ -1603,10 +1415,8 @@ def build_message_type(
# Build dump_cpp content with dump_to implementation
dump_cpp = dump_impl
# Return metadata info for classes using metadata
metadata_return = metadata_info if use_metadata else None
return out, cpp, dump_cpp, metadata_return
# Return metadata info for all classes
return out, cpp, dump_cpp, metadata_info
SOURCE_BOTH = 0