This commit is contained in:
J. Nick Koston 2025-07-09 21:50:57 -10:00
parent 95afae4830
commit 921974ec23
No known key found for this signature in database
6 changed files with 3748 additions and 3300 deletions

File diff suppressed because it is too large Load Diff

View File

@ -285,7 +285,7 @@ enum UpdateCommand : uint32_t {
} // namespace enums
class InfoResponseProtoMessage : public ProtoMessage {
class InfoResponseProtoMessage : public ProtoMetadataMessage {
public:
~InfoResponseProtoMessage() override = default;
std::string object_id{};
@ -300,7 +300,7 @@ class InfoResponseProtoMessage : public ProtoMessage {
protected:
};
class StateResponseProtoMessage : public ProtoMessage {
class StateResponseProtoMessage : public ProtoMetadataMessage {
public:
~StateResponseProtoMessage() override = default;
uint32_t key{0};
@ -328,7 +328,7 @@ class HelloRequest : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class HelloResponse : public ProtoMessage {
class HelloResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 2;
static constexpr uint16_t ESTIMATED_SIZE = 26;
@ -369,7 +369,7 @@ class ConnectRequest : public ProtoMessage {
protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class ConnectResponse : public ProtoMessage {
class ConnectResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 4;
static constexpr uint16_t ESTIMATED_SIZE = 2;
@ -402,7 +402,7 @@ class DisconnectRequest : public ProtoMessage {
protected:
};
class DisconnectResponse : public ProtoMessage {
class DisconnectResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 6;
static constexpr uint16_t ESTIMATED_SIZE = 0;
@ -430,7 +430,7 @@ class PingRequest : public ProtoMessage {
protected:
};
class PingResponse : public ProtoMessage {
class PingResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 8;
static constexpr uint16_t ESTIMATED_SIZE = 0;
@ -487,7 +487,7 @@ class DeviceInfo : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class DeviceInfoResponse : public ProtoMessage {
class DeviceInfoResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 10;
static constexpr uint16_t ESTIMATED_SIZE = 219;
@ -543,7 +543,7 @@ class ListEntitiesRequest : public ProtoMessage {
protected:
};
class ListEntitiesDoneResponse : public ProtoMessage {
class ListEntitiesDoneResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 19;
static constexpr uint16_t ESTIMATED_SIZE = 0;
@ -1073,7 +1073,7 @@ class SubscribeLogsRequest : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class SubscribeLogsResponse : public ProtoMessage {
class SubscribeLogsResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 29;
static constexpr uint16_t ESTIMATED_SIZE = 13;
@ -1114,7 +1114,7 @@ class NoiseEncryptionSetKeyRequest : public ProtoMessage {
protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class NoiseEncryptionSetKeyResponse : public ProtoMessage {
class NoiseEncryptionSetKeyResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 125;
static constexpr uint16_t ESTIMATED_SIZE = 2;
@ -1161,7 +1161,7 @@ class HomeassistantServiceMap : public ProtoMessage {
protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class HomeassistantServiceResponse : public ProtoMessage {
class HomeassistantServiceResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 35;
static constexpr uint16_t ESTIMATED_SIZE = 113;
@ -1200,7 +1200,7 @@ class SubscribeHomeAssistantStatesRequest : public ProtoMessage {
protected:
};
class SubscribeHomeAssistantStateResponse : public ProtoMessage {
class SubscribeHomeAssistantStateResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 39;
static constexpr uint16_t ESTIMATED_SIZE = 20;
@ -1223,7 +1223,7 @@ class SubscribeHomeAssistantStateResponse : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class HomeAssistantStateResponse : public ProtoMessage {
class HomeAssistantStateResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 40;
static constexpr uint16_t ESTIMATED_SIZE = 27;
@ -1258,7 +1258,7 @@ class GetTimeRequest : public ProtoMessage {
protected:
};
class GetTimeResponse : public ProtoMessage {
class GetTimeResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 37;
static constexpr uint16_t ESTIMATED_SIZE = 5;
@ -1292,7 +1292,7 @@ class ListEntitiesServicesArgument : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class ListEntitiesServicesResponse : public ProtoMessage {
class ListEntitiesServicesResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 41;
static constexpr uint16_t ESTIMATED_SIZE = 48;
@ -1379,7 +1379,7 @@ class ListEntitiesCameraResponse : public InfoResponseProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class CameraImageResponse : public ProtoMessage {
class CameraImageResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 44;
static constexpr uint16_t ESTIMATED_SIZE = 16;
@ -1988,7 +1988,7 @@ class BluetoothServiceData : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothLEAdvertisementResponse : public ProtoMessage {
class BluetoothLEAdvertisementResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 67;
static constexpr uint16_t ESTIMATED_SIZE = 107;
@ -2032,7 +2032,7 @@ class BluetoothLERawAdvertisement : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothLERawAdvertisementsResponse : public ProtoMessage {
class BluetoothLERawAdvertisementsResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 93;
static constexpr uint16_t ESTIMATED_SIZE = 34;
@ -2072,7 +2072,7 @@ class BluetoothDeviceRequest : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothDeviceConnectionResponse : public ProtoMessage {
class BluetoothDeviceConnectionResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 69;
static constexpr uint16_t ESTIMATED_SIZE = 14;
@ -2156,7 +2156,7 @@ class BluetoothGATTService : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTGetServicesResponse : public ProtoMessage {
class BluetoothGATTGetServicesResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 71;
static constexpr uint16_t ESTIMATED_SIZE = 38;
@ -2179,7 +2179,7 @@ class BluetoothGATTGetServicesResponse : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTGetServicesDoneResponse : public ProtoMessage {
class BluetoothGATTGetServicesDoneResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 72;
static constexpr uint16_t ESTIMATED_SIZE = 4;
@ -2217,7 +2217,7 @@ class BluetoothGATTReadRequest : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTReadResponse : public ProtoMessage {
class BluetoothGATTReadResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 74;
static constexpr uint16_t ESTIMATED_SIZE = 17;
@ -2318,7 +2318,7 @@ class BluetoothGATTNotifyRequest : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTNotifyDataResponse : public ProtoMessage {
class BluetoothGATTNotifyDataResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 79;
static constexpr uint16_t ESTIMATED_SIZE = 17;
@ -2354,7 +2354,7 @@ class SubscribeBluetoothConnectionsFreeRequest : public ProtoMessage {
protected:
};
class BluetoothConnectionsFreeResponse : public ProtoMessage {
class BluetoothConnectionsFreeResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 81;
static constexpr uint16_t ESTIMATED_SIZE = 16;
@ -2377,7 +2377,7 @@ class BluetoothConnectionsFreeResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTErrorResponse : public ProtoMessage {
class BluetoothGATTErrorResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 82;
static constexpr uint16_t ESTIMATED_SIZE = 12;
@ -2399,7 +2399,7 @@ class BluetoothGATTErrorResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTWriteResponse : public ProtoMessage {
class BluetoothGATTWriteResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 83;
static constexpr uint16_t ESTIMATED_SIZE = 8;
@ -2420,7 +2420,7 @@ class BluetoothGATTWriteResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothGATTNotifyResponse : public ProtoMessage {
class BluetoothGATTNotifyResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 84;
static constexpr uint16_t ESTIMATED_SIZE = 8;
@ -2441,7 +2441,7 @@ class BluetoothGATTNotifyResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothDevicePairingResponse : public ProtoMessage {
class BluetoothDevicePairingResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 85;
static constexpr uint16_t ESTIMATED_SIZE = 10;
@ -2463,7 +2463,7 @@ class BluetoothDevicePairingResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothDeviceUnpairingResponse : public ProtoMessage {
class BluetoothDeviceUnpairingResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 86;
static constexpr uint16_t ESTIMATED_SIZE = 10;
@ -2498,7 +2498,7 @@ class UnsubscribeBluetoothLEAdvertisementsRequest : public ProtoMessage {
protected:
};
class BluetoothDeviceClearCacheResponse : public ProtoMessage {
class BluetoothDeviceClearCacheResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 88;
static constexpr uint16_t ESTIMATED_SIZE = 10;
@ -2520,7 +2520,7 @@ class BluetoothDeviceClearCacheResponse : public ProtoMessage {
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class BluetoothScannerStateResponse : public ProtoMessage {
class BluetoothScannerStateResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 126;
static constexpr uint16_t ESTIMATED_SIZE = 4;
@ -2615,7 +2615,7 @@ class VoiceAssistantRequest : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class VoiceAssistantResponse : public ProtoMessage {
class VoiceAssistantResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 91;
static constexpr uint16_t ESTIMATED_SIZE = 6;
@ -2649,7 +2649,7 @@ class VoiceAssistantEventData : public ProtoMessage {
protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class VoiceAssistantEventResponse : public ProtoMessage {
class VoiceAssistantEventResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 92;
static constexpr uint16_t ESTIMATED_SIZE = 36;
@ -2691,7 +2691,7 @@ class VoiceAssistantAudio : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class VoiceAssistantTimerEventResponse : public ProtoMessage {
class VoiceAssistantTimerEventResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 115;
static constexpr uint16_t ESTIMATED_SIZE = 30;
@ -2782,7 +2782,7 @@ class VoiceAssistantConfigurationRequest : public ProtoMessage {
protected:
};
class VoiceAssistantConfigurationResponse : public ProtoMessage {
class VoiceAssistantConfigurationResponse : public ProtoMetadataMessage {
public:
static constexpr uint16_t MESSAGE_TYPE = 122;
static constexpr uint16_t ESTIMATED_SIZE = 56;

View File

@ -227,6 +227,85 @@ void size_bytes_field(uint32_t &total_size, const void *field_ptr, uint8_t field
ProtoSize::add_string_field(total_size, 1, *str, force);
}
// Type-specific decode functions
bool decode_string_field(void *field_ptr, ProtoLengthDelimited value) {
auto *str = static_cast<std::string *>(field_ptr);
*str = value.as_string();
return true;
}
bool decode_fixed32_field(void *field_ptr, Proto32Bit value) {
auto *val = static_cast<uint32_t *>(field_ptr);
*val = value.as_fixed32();
return true;
}
bool decode_bool_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<bool *>(field_ptr);
*val = value.as_bool();
return true;
}
bool decode_float_field(void *field_ptr, Proto32Bit value) {
auto *val = static_cast<float *>(field_ptr);
*val = value.as_float();
return true;
}
bool decode_int32_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<int32_t *>(field_ptr);
*val = value.as_int32();
return true;
}
bool decode_uint32_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<uint32_t *>(field_ptr);
*val = value.as_uint32();
return true;
}
bool decode_int64_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<int64_t *>(field_ptr);
*val = value.as_int64();
return true;
}
bool decode_uint64_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<uint64_t *>(field_ptr);
*val = value.as_uint64();
return true;
}
bool decode_sint32_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<int32_t *>(field_ptr);
*val = value.as_sint32();
return true;
}
bool decode_sint64_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<int64_t *>(field_ptr);
*val = value.as_sint64();
return true;
}
bool decode_fixed64_field(void *field_ptr, Proto64Bit value) {
auto *val = static_cast<uint64_t *>(field_ptr);
*val = value.as_fixed64();
return true;
}
bool decode_double_field(void *field_ptr, Proto64Bit value) {
auto *val = static_cast<double *>(field_ptr);
*val = value.as_double();
return true;
}
bool decode_bytes_field(void *field_ptr, ProtoLengthDelimited value) {
auto *str = static_cast<std::string *>(field_ptr);
*str = value.as_string();
return true;
}
// Template functions are now in the header file for proper instantiation
// Repeated field encoding functions
@ -447,5 +526,58 @@ void calculate_size_from_metadata(uint32_t &total_size, const void *obj, const F
}
}
// Metadata-driven decode implementations
bool ProtoMetadataMessage::decode_varint_metadata(uint32_t field_id, ProtoVarInt value, const FieldMeta *fields,
size_t field_count) {
uint8_t *base = reinterpret_cast<uint8_t *>(this);
for (size_t i = 0; i < field_count; i++) {
if (fields[i].field_num == field_id && fields[i].wire_type == 0) { // varint
void *field_addr = base + fields[i].offset;
return fields[i].decoder.decode_varint(field_addr, value);
}
}
return false;
}
bool ProtoMetadataMessage::decode_length_metadata(uint32_t field_id, ProtoLengthDelimited value,
const FieldMeta *fields, size_t field_count) {
uint8_t *base = reinterpret_cast<uint8_t *>(this);
for (size_t i = 0; i < field_count; i++) {
if (fields[i].field_num == field_id && fields[i].wire_type == 2) { // length-delimited
void *field_addr = base + fields[i].offset;
return fields[i].decoder.decode_length(field_addr, value);
}
}
return false;
}
bool ProtoMetadataMessage::decode_32bit_metadata(uint32_t field_id, Proto32Bit value, const FieldMeta *fields,
size_t field_count) {
uint8_t *base = reinterpret_cast<uint8_t *>(this);
for (size_t i = 0; i < field_count; i++) {
if (fields[i].field_num == field_id && fields[i].wire_type == 5) { // 32-bit
void *field_addr = base + fields[i].offset;
return fields[i].decoder.decode_32bit(field_addr, value);
}
}
return false;
}
bool ProtoMetadataMessage::decode_64bit_metadata(uint32_t field_id, Proto64Bit value, const FieldMeta *fields,
size_t field_count) {
uint8_t *base = reinterpret_cast<uint8_t *>(this);
for (size_t i = 0; i < field_count; i++) {
if (fields[i].field_num == field_id && fields[i].wire_type == 1) { // 64-bit
void *field_addr = base + fields[i].offset;
return fields[i].decoder.decode_64bit(field_addr, value);
}
}
return false;
}
} // namespace api
} // namespace esphome

View File

@ -24,15 +24,6 @@ using SizeFunc = void (*)(uint32_t &total_size, const void *field_ptr, uint8_t f
// This uses the same approach as offsetof but with explicit reinterpret_cast
#define PROTO_FIELD_OFFSET(Type, Member) (reinterpret_cast<size_t>(&reinterpret_cast<Type *>(16)->Member) - 16)
// Metadata structure describing each field
struct FieldMeta {
uint8_t field_num; // Protobuf field number (1-255)
uint16_t offset; // offset of field in class
EncodeFunc encoder; // Function to encode this field type
SizeFunc sizer; // Function to calculate size for this field type
bool force_encode; // If true, encode even if value is default/empty
};
// Function pointer types for repeated fields
using RepeatedEncodeFunc = void (*)(ProtoWriteBuffer &, const void *field_ptr, uint8_t field_num);
using RepeatedSizeFunc = void (*)(uint32_t &total_size, const void *field_ptr, uint8_t field_num);
@ -216,6 +207,28 @@ class Proto64Bit {
const uint64_t value_;
};
// Function pointer types for decoding (now that Proto classes are defined)
using DecodeVarintFunc = bool (*)(void *field_ptr, ProtoVarInt value);
using DecodeLengthFunc = bool (*)(void *field_ptr, ProtoLengthDelimited value);
using Decode32BitFunc = bool (*)(void *field_ptr, Proto32Bit value);
using Decode64BitFunc = bool (*)(void *field_ptr, Proto64Bit value);
// Metadata structure describing each field
struct FieldMeta {
uint8_t field_num; // Protobuf field number (1-255)
uint16_t offset; // offset of field in class
EncodeFunc encoder; // Function to encode this field type
SizeFunc sizer; // Function to calculate size for this field type
bool force_encode; // If true, encode even if value is default/empty
uint8_t wire_type; // Wire type (0=varint, 2=length, 5=32bit, 1=64bit)
union {
DecodeVarintFunc decode_varint;
DecodeLengthFunc decode_length;
Decode32BitFunc decode_32bit;
Decode64BitFunc decode_64bit;
} decoder;
};
class ProtoWriteBuffer {
public:
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {}
@ -379,6 +392,17 @@ class ProtoMessage {
template<typename T> const char *proto_enum_to_string(T value);
// Base class for messages using metadata-driven encode/decode
class ProtoMetadataMessage : public ProtoMessage {
protected:
// Metadata-driven decode methods
bool decode_varint_metadata(uint32_t field_id, ProtoVarInt value, const FieldMeta *fields, size_t field_count);
bool decode_length_metadata(uint32_t field_id, ProtoLengthDelimited value, const FieldMeta *fields,
size_t field_count);
bool decode_32bit_metadata(uint32_t field_id, Proto32Bit value, const FieldMeta *fields, size_t field_count);
bool decode_64bit_metadata(uint32_t field_id, Proto64Bit value, const FieldMeta *fields, size_t field_count);
};
class ProtoService {
public:
protected:
@ -449,6 +473,24 @@ void encode_fixed64_field(ProtoWriteBuffer &buffer, const void *field_ptr, uint8
void encode_double_field(ProtoWriteBuffer &buffer, const void *field_ptr, uint8_t field_num);
void encode_bytes_field(ProtoWriteBuffer &buffer, const void *field_ptr, uint8_t field_num);
// Type-specific decode functions
bool decode_string_field(void *field_ptr, ProtoLengthDelimited value);
bool decode_fixed32_field(void *field_ptr, Proto32Bit value);
bool decode_bool_field(void *field_ptr, ProtoVarInt value);
bool decode_float_field(void *field_ptr, Proto32Bit value);
bool decode_int32_field(void *field_ptr, ProtoVarInt value);
bool decode_uint32_field(void *field_ptr, ProtoVarInt value);
bool decode_int64_field(void *field_ptr, ProtoVarInt value);
bool decode_uint64_field(void *field_ptr, ProtoVarInt value);
bool decode_sint32_field(void *field_ptr, ProtoVarInt value);
bool decode_sint64_field(void *field_ptr, ProtoVarInt value);
bool decode_fixed64_field(void *field_ptr, Proto64Bit value);
bool decode_double_field(void *field_ptr, Proto64Bit value);
bool decode_bytes_field(void *field_ptr, ProtoLengthDelimited value);
// Template enum decode function
template<typename EnumType> bool decode_enum_field(void *field_ptr, ProtoVarInt value);
// Type-specific size calculation functions
void size_string_field(uint32_t &total_size, const void *field_ptr, uint8_t field_num, bool force);
void size_fixed32_field(uint32_t &total_size, const void *field_ptr, uint8_t field_num, bool force);

View File

@ -19,6 +19,12 @@ inline void size_enum_field(uint32_t &total_size, const void *field_ptr, uint8_t
ProtoSize::add_enum_field(total_size, 1, static_cast<uint32_t>(*val), force);
}
template<typename EnumType> inline bool decode_enum_field(void *field_ptr, ProtoVarInt value) {
auto *val = static_cast<EnumType *>(field_ptr);
*val = value.as_enum<EnumType>();
return true;
}
// Template repeated field functions (must be in header for instantiation)
template<typename EnumType>
inline void encode_repeated_enum_field(ProtoWriteBuffer &buffer, const void *field_ptr, uint8_t field_num) {

View File

@ -1056,6 +1056,75 @@ def get_repeated_sizer_function(type_info: RepeatedTypeInfo) -> str | None:
return type_map.get(type_name, None)
def get_wire_type(type_info: TypeInfo) -> int:
"""Get the wire type for a given field type."""
# Map from TypeInfo class name to wire type
wire_type_map = {
"StringType": 2, # LENGTH_DELIMITED
"BytesType": 2, # LENGTH_DELIMITED
"MessageType": 2, # LENGTH_DELIMITED
"BoolType": 0, # VARINT
"Int32Type": 0, # VARINT
"UInt32Type": 0, # VARINT
"Int64Type": 0, # VARINT
"UInt64Type": 0, # VARINT
"SInt32Type": 0, # VARINT
"SInt64Type": 0, # VARINT
"EnumType": 0, # VARINT
"FloatType": 5, # FIXED32
"Fixed32Type": 5, # FIXED32
"SFixed32Type": 5, # FIXED32
"DoubleType": 1, # FIXED64
"Fixed64Type": 1, # FIXED64
"SFixed64Type": 1, # FIXED64
}
type_name = type_info.__class__.__name__
return wire_type_map.get(type_name, 0)
def get_decoder_function(type_info: TypeInfo, wire_type: int) -> str:
"""Get the decoder function for a given type."""
# Map based on both type and wire type
if wire_type == 0: # VARINT
type_map = {
"BoolType": "&decode_bool_field",
"Int32Type": "&decode_int32_field",
"UInt32Type": "&decode_uint32_field",
"Int64Type": "&decode_int64_field",
"UInt64Type": "&decode_uint64_field",
"SInt32Type": "&decode_sint32_field",
"SInt64Type": "&decode_sint64_field",
}
type_name = type_info.__class__.__name__
return type_map.get(type_name, None)
elif wire_type == 2: # LENGTH_DELIMITED
type_map = {
"StringType": "&decode_string_field",
"BytesType": "&decode_bytes_field",
}
type_name = type_info.__class__.__name__
return type_map.get(type_name, None)
elif wire_type == 5: # FIXED32
type_map = {
"FloatType": "&decode_float_field",
"Fixed32Type": "&decode_fixed32_field",
"SFixed32Type": "&decode_int32_field", # sfixed32 uses same as int32
}
type_name = type_info.__class__.__name__
return type_map.get(type_name, None)
elif wire_type == 1: # FIXED64
type_map = {
"DoubleType": "&decode_double_field",
"Fixed64Type": "&decode_fixed64_field",
"SFixed64Type": "&decode_int64_field", # sfixed64 uses same as int64
}
type_name = type_info.__class__.__name__
return type_map.get(type_name, None)
return None
def build_message_type(
desc: descriptor.DescriptorProto,
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None,
@ -1125,52 +1194,85 @@ def build_message_type(
if ti.dump_content:
dump.append(ti.dump_content)
cpp = ""
if decode_varint:
decode_varint.append("default:\n return false;")
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot)
if decode_length:
decode_length.append("default:\n return false;")
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
if decode_32bit:
decode_32bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
if decode_64bit:
decode_64bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
# Check if this is a Response message and use metadata approach
is_response = desc.name.endswith("Response")
metadata_info = None
cpp = ""
# Only generate decode methods for non-Response messages
if not is_response:
if decode_varint:
decode_varint.append("default:\n return false;")
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot)
if decode_length:
decode_length.append("default:\n return false;")
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
if decode_32bit:
decode_32bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
if decode_64bit:
decode_64bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n"
o += " }\n"
o += "}\n"
cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
else:
# For Response classes, add metadata-driven decode methods
if decode_varint:
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot)
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " return decode_varint_metadata(field_id, value, FIELDS, FIELD_COUNT);\n"
o += "}\n"
cpp += o
if decode_length:
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " return decode_length_metadata(field_id, value, FIELDS, FIELD_COUNT);\n"
o += "}\n"
cpp += o
if decode_32bit:
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " return decode_32bit_metadata(field_id, value, FIELDS, FIELD_COUNT);\n"
o += "}\n"
cpp += o
if decode_64bit:
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " return decode_64bit_metadata(field_id, value, FIELDS, FIELD_COUNT);\n"
o += "}\n"
cpp += o
# Generate metadata arrays for Response classes
if is_response:
regular_fields = []
@ -1204,18 +1306,30 @@ def build_message_type(
ti = TYPE_INFO[field.type](field)
encoder = get_encoder_function(ti)
sizer = get_sizer_function(ti)
wire_type = get_wire_type(ti)
decoder = get_decoder_function(ti, wire_type)
force = "true" if field.label == 2 else "false" # Required fields
if encoder and sizer:
if encoder and sizer and decoder:
# Format: {field_num, offset, encoder, sizer, force_encode, wire_type, {decoder}}
decoder_field = (
f".decode_varint = {decoder}"
if wire_type == 0
else f".decode_length = {decoder}"
if wire_type == 2
else f".decode_32bit = {decoder}"
if wire_type == 5
else f".decode_64bit = {decoder}"
)
regular_fields.append(
f"{{{field.number}, PROTO_FIELD_OFFSET({desc.name}, {ti.field_name}), {encoder}, {sizer}, {force}}}"
f"{{{field.number}, PROTO_FIELD_OFFSET({desc.name}, {ti.field_name}), {encoder}, {sizer}, {force}, {wire_type}, {{{decoder_field}}}}}"
)
elif isinstance(ti, EnumType):
# Handle enum fields with template
enum_type = ti.cpp_type
regular_fields.append(
f"{{{field.number}, PROTO_FIELD_OFFSET({desc.name}, {ti.field_name}), "
f"&encode_enum_field<{enum_type}>, &size_enum_field<{enum_type}>, {force}}}"
f"&encode_enum_field<{enum_type}>, &size_enum_field<{enum_type}>, {force}, 0, {{.decode_varint = &decode_enum_field<{enum_type}>}}}}"
)
elif isinstance(ti, MessageType):
# Skip nested messages for now - they need special handling
@ -1335,6 +1449,8 @@ def build_message_type(
if base_class:
out = f"class {desc.name} : public {base_class} {{\n"
elif is_response:
out = f"class {desc.name} : public ProtoMetadataMessage {{\n"
else:
out = f"class {desc.name} : public ProtoMessage {{\n"
out += " public:\n"
@ -1458,7 +1574,13 @@ def build_base_class(
public_content.extend(ti.public_content)
# Build header
out = f"class {base_class_name} : public ProtoMessage {{\n"
# Check if this is a Response base class
if base_class_name.endswith("Response") or base_class_name.endswith(
"ResponseProtoMessage"
):
out = f"class {base_class_name} : public ProtoMetadataMessage {{\n"
else:
out = f"class {base_class_name} : public ProtoMessage {{\n"
out += " public:\n"
# Add destructor with override
@ -1962,7 +2084,6 @@ static const char *const TAG = "api.service";
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
exec_clang_format(root / "api_pb2_dump.h")
exec_clang_format(root / "api_pb2_dump.cpp")
except ImportError:
pass