This commit is contained in:
J. Nick Koston 2025-07-10 16:30:43 -10:00
parent c745140835
commit 0d5b353cdf
No known key found for this signature in database
3 changed files with 30 additions and 177 deletions

View File

@ -38,15 +38,11 @@ static size_t get_vector_size(ProtoFieldType type, const void *field_addr) {
return static_cast<const std::vector<uint32_t> *>(field_addr)->size();
case ProtoFieldType::TYPE_INT64:
case ProtoFieldType::TYPE_SINT64:
case ProtoFieldType::TYPE_SFIXED64:
return static_cast<const std::vector<int64_t> *>(field_addr)->size();
case ProtoFieldType::TYPE_UINT64:
case ProtoFieldType::TYPE_FIXED64:
return static_cast<const std::vector<uint64_t> *>(field_addr)->size();
case ProtoFieldType::TYPE_FLOAT:
return static_cast<const std::vector<float> *>(field_addr)->size();
case ProtoFieldType::TYPE_DOUBLE:
return static_cast<const std::vector<double> *>(field_addr)->size();
case ProtoFieldType::TYPE_STRING:
case ProtoFieldType::TYPE_BYTES:
return static_cast<const std::vector<std::string> *>(field_addr)->size();
@ -74,15 +70,11 @@ static const void *get_vector_element(ProtoFieldType type, const void *field_add
return &(*static_cast<const std::vector<uint32_t> *>(field_addr))[index];
case ProtoFieldType::TYPE_INT64:
case ProtoFieldType::TYPE_SINT64:
case ProtoFieldType::TYPE_SFIXED64:
return &(*static_cast<const std::vector<int64_t> *>(field_addr))[index];
case ProtoFieldType::TYPE_UINT64:
case ProtoFieldType::TYPE_FIXED64:
return &(*static_cast<const std::vector<uint64_t> *>(field_addr))[index];
case ProtoFieldType::TYPE_FLOAT:
return &(*static_cast<const std::vector<float> *>(field_addr))[index];
case ProtoFieldType::TYPE_DOUBLE:
return &(*static_cast<const std::vector<double> *>(field_addr))[index];
case ProtoFieldType::TYPE_STRING:
case ProtoFieldType::TYPE_BYTES:
return &(*static_cast<const std::vector<std::string> *>(field_addr))[index];
@ -134,15 +126,6 @@ static void encode_field(ProtoWriteBuffer &buffer, ProtoFieldType type, uint8_t
case ProtoFieldType::TYPE_SFIXED32:
buffer.encode_sfixed32(field_num, *static_cast<const int32_t *>(field_addr), force);
break;
case ProtoFieldType::TYPE_DOUBLE:
buffer.encode_double(field_num, *static_cast<const double *>(field_addr), force);
break;
case ProtoFieldType::TYPE_FIXED64:
buffer.encode_fixed64(field_num, *static_cast<const uint64_t *>(field_addr), force);
break;
case ProtoFieldType::TYPE_SFIXED64:
buffer.encode_sfixed64(field_num, *static_cast<const int64_t *>(field_addr), force);
break;
default:
break;
}
@ -195,21 +178,6 @@ static void calculate_field_size(uint32_t &total_size, ProtoFieldType type, uint
ProtoSize::add_fixed_field<4>(total_size, precalc_size, val != 0, force);
break;
}
case ProtoFieldType::TYPE_DOUBLE: {
double val = *static_cast<const double *>(field_addr);
ProtoSize::add_fixed_field<8>(total_size, precalc_size, val != 0.0, force);
break;
}
case ProtoFieldType::TYPE_FIXED64: {
uint64_t val = *static_cast<const uint64_t *>(field_addr);
ProtoSize::add_fixed_field<8>(total_size, precalc_size, val != 0, force);
break;
}
case ProtoFieldType::TYPE_SFIXED64: {
int64_t val = *static_cast<const int64_t *>(field_addr);
ProtoSize::add_fixed_field<8>(total_size, precalc_size, val != 0, force);
break;
}
default:
break;
}
@ -309,40 +277,6 @@ static bool decode_repeated_32bit_field(ProtoFieldType type, void *field_addr, c
}
}
// Decode 64-bit for single fields
static bool decode_64bit_field(ProtoFieldType type, void *field_addr, const Proto64Bit &value) {
switch (type) {
case ProtoFieldType::TYPE_DOUBLE:
*static_cast<double *>(field_addr) = value.as_double();
return true;
case ProtoFieldType::TYPE_FIXED64:
*static_cast<uint64_t *>(field_addr) = value.as_fixed64();
return true;
case ProtoFieldType::TYPE_SFIXED64:
*static_cast<int64_t *>(field_addr) = value.as_sfixed64();
return true;
default:
return false;
}
}
// Decode 64-bit for repeated fields
static bool decode_repeated_64bit_field(ProtoFieldType type, void *field_addr, const Proto64Bit &value) {
switch (type) {
case ProtoFieldType::TYPE_DOUBLE:
static_cast<std::vector<double> *>(field_addr)->push_back(value.as_double());
return true;
case ProtoFieldType::TYPE_FIXED64:
static_cast<std::vector<uint64_t> *>(field_addr)->push_back(value.as_fixed64());
return true;
case ProtoFieldType::TYPE_SFIXED64:
static_cast<std::vector<int64_t> *>(field_addr)->push_back(value.as_sfixed64());
return true;
default:
return false;
}
}
// Decode length-delimited for single fields
static bool decode_length_field(ProtoFieldType type, void *field_addr, const ProtoLengthDelimited &value,
uint8_t message_type_id) {
@ -525,47 +459,6 @@ void ProtoMessage::decode(const uint8_t *buffer, size_t length) {
break;
}
case 1: { // 64-bit
if (i + 8 > length) {
ESP_LOGV(TAG, "64-bit field exceeds buffer at position %u", i);
return;
}
uint64_t raw = 0;
raw |= uint64_t(buffer[i]) << 0;
raw |= uint64_t(buffer[i + 1]) << 8;
raw |= uint64_t(buffer[i + 2]) << 16;
raw |= uint64_t(buffer[i + 3]) << 24;
raw |= uint64_t(buffer[i + 4]) << 32;
raw |= uint64_t(buffer[i + 5]) << 40;
raw |= uint64_t(buffer[i + 6]) << 48;
raw |= uint64_t(buffer[i + 7]) << 56;
Proto64Bit value(raw);
// Try regular fields first
for (size_t j = 0; j < field_count; j++) {
if (fields[j].field_num == field_id && get_wire_type(fields[j].get_type()) == 1) {
void *field_addr = base + fields[j].get_offset();
decoded = decode_64bit_field(fields[j].get_type(), field_addr, value);
break;
}
}
// If not found, try repeated fields
if (!decoded) {
for (size_t j = 0; j < repeated_count; j++) {
if (repeated_fields[j].field_num == field_id && get_wire_type(repeated_fields[j].get_type()) == 1) {
void *field_addr = base + repeated_fields[j].get_offset();
decoded = decode_repeated_64bit_field(repeated_fields[j].get_type(), field_addr, value);
break;
}
}
}
i += 8;
break;
}
default:
ESP_LOGV(TAG, "Unknown wire type %u at position %u", wire_type, i);
return;
@ -663,9 +556,6 @@ void ProtoMessage::calculate_size(uint32_t &total_size) const {
if (type == ProtoFieldType::TYPE_FIXED32 || type == ProtoFieldType::TYPE_SFIXED32 ||
type == ProtoFieldType::TYPE_FLOAT) {
total_size += count * (repeated_fields[i].get_precalced_size() + 4);
} else if (type == ProtoFieldType::TYPE_FIXED64 || type == ProtoFieldType::TYPE_SFIXED64 ||
type == ProtoFieldType::TYPE_DOUBLE) {
total_size += count * (repeated_fields[i].get_precalced_size() + 8);
} else {
// For variable-size types, calculate each element
for (size_t j = 0; j < count; j++) {

View File

@ -37,11 +37,6 @@ enum class ProtoFieldType : uint8_t {
TYPE_FLOAT = 11,
TYPE_FIXED32 = 12,
TYPE_SFIXED32 = 13,
// 64-bit types (wire type 1)
TYPE_DOUBLE = 14,
TYPE_FIXED64 = 15,
TYPE_SFIXED64 = 16,
};
// Helper to get wire type from field type
@ -66,11 +61,6 @@ constexpr uint8_t get_wire_type(ProtoFieldType type) {
case ProtoFieldType::TYPE_FIXED32:
case ProtoFieldType::TYPE_SFIXED32:
return 5; // 32-bit
case ProtoFieldType::TYPE_DOUBLE:
case ProtoFieldType::TYPE_FIXED64:
case ProtoFieldType::TYPE_SFIXED64:
return 1; // 64-bit
}
return 0;
}
@ -234,24 +224,6 @@ class Proto32Bit {
const uint32_t value_;
};
class Proto64Bit {
public:
explicit Proto64Bit(uint64_t value) : value_(value) {}
uint64_t as_fixed64() const { return this->value_; }
int64_t as_sfixed64() const { return static_cast<int64_t>(this->value_); }
double as_double() const {
union {
uint64_t raw;
double value;
} s{};
s.raw = this->value_;
return s.value;
}
protected:
const uint64_t value_;
};
// Function pointer types used by V2 structures
using EncodeFunc = void (*)(ProtoWriteBuffer &, const void *field_ptr, uint8_t field_num);
using SizeFunc = void (*)(uint32_t &total_size, const void *field_ptr, uint8_t precalced_field_id_size, bool force);
@ -308,8 +280,6 @@ struct FieldMeta {
uint8_t get_message_type_id() const { return message_type_id >> 2; } // Upper 6 bits for type ID (0-63)
};
// V2 structures removed - we only use V3 now
class ProtoWriteBuffer {
public:
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {}
@ -375,20 +345,6 @@ class ProtoWriteBuffer {
this->write((value >> 16) & 0xFF);
this->write((value >> 24) & 0xFF);
}
void encode_fixed64(uint32_t field_id, uint64_t value, bool force = false) {
if (value == 0 && !force)
return;
this->encode_field_raw(field_id, 1); // type 1: 64-bit fixed64
this->write((value >> 0) & 0xFF);
this->write((value >> 8) & 0xFF);
this->write((value >> 16) & 0xFF);
this->write((value >> 24) & 0xFF);
this->write((value >> 32) & 0xFF);
this->write((value >> 40) & 0xFF);
this->write((value >> 48) & 0xFF);
this->write((value >> 56) & 0xFF);
}
template<typename T> void encode_enum(uint32_t field_id, T value, bool force = false) {
this->encode_uint32(field_id, static_cast<uint32_t>(value), force);
}
@ -437,21 +393,6 @@ class ProtoWriteBuffer {
return;
this->encode_fixed32(field_id, static_cast<uint32_t>(value), force);
}
void encode_double(uint32_t field_id, double value, bool force = false) {
if (!force && value == 0.0)
return;
union {
double value;
uint64_t raw;
} val{};
val.value = value;
this->encode_fixed64(field_id, val.raw, force);
}
void encode_sfixed64(uint32_t field_id, int64_t value, bool force = false) {
if (!force && value == 0)
return;
this->encode_fixed64(field_id, static_cast<uint64_t>(value), force);
}
template<class C> void encode_message(uint32_t field_id, const C &value, bool force = false) {
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
size_t begin = this->buffer_->size();
@ -496,8 +437,6 @@ struct RepeatedFieldMeta {
uint8_t get_message_type_id() const { return message_type_id >> 2; } // Upper 6 bits for type ID (0-63)
};
// V2 structures removed - we only use V3 now
class ProtoMessage {
public:
virtual ~ProtoMessage() = default;

View File

@ -58,6 +58,13 @@ class MessageTypeRegistry:
type_registry = MessageTypeRegistry()
# Unsupported types that ESPHome doesn't use
UNSUPPORTED_TYPES = {
descriptor.FieldDescriptorProto.TYPE_DOUBLE,
descriptor.FieldDescriptorProto.TYPE_FIXED64,
descriptor.FieldDescriptorProto.TYPE_SFIXED64,
}
# Mapping from protobuf types to our ProtoFieldType enum
PROTO_TYPE_MAP = {
descriptor.FieldDescriptorProto.TYPE_BOOL: "ProtoFieldType::TYPE_BOOL",
@ -74,9 +81,6 @@ PROTO_TYPE_MAP = {
descriptor.FieldDescriptorProto.TYPE_FLOAT: "ProtoFieldType::TYPE_FLOAT",
descriptor.FieldDescriptorProto.TYPE_FIXED32: "ProtoFieldType::TYPE_FIXED32",
descriptor.FieldDescriptorProto.TYPE_SFIXED32: "ProtoFieldType::TYPE_SFIXED32",
descriptor.FieldDescriptorProto.TYPE_DOUBLE: "ProtoFieldType::TYPE_DOUBLE",
descriptor.FieldDescriptorProto.TYPE_FIXED64: "ProtoFieldType::TYPE_FIXED64",
descriptor.FieldDescriptorProto.TYPE_SFIXED64: "ProtoFieldType::TYPE_SFIXED64",
}
# Mapping from protobuf types to numeric values (must match proto.h enum)
@ -95,9 +99,6 @@ PROTO_TYPE_NUM_MAP = {
descriptor.FieldDescriptorProto.TYPE_FLOAT: 11,
descriptor.FieldDescriptorProto.TYPE_FIXED32: 12,
descriptor.FieldDescriptorProto.TYPE_SFIXED32: 13,
descriptor.FieldDescriptorProto.TYPE_DOUBLE: 14,
descriptor.FieldDescriptorProto.TYPE_FIXED64: 15,
descriptor.FieldDescriptorProto.TYPE_SFIXED64: 16,
}
@ -1290,6 +1291,20 @@ def build_message_type(
public_content.append("#endif")
for field in desc.field:
# Check for unsupported types
if field.type in UNSUPPORTED_TYPES:
raise ValueError(
f"Field '{field.name}' in message '{desc.name}' uses unsupported type {field.type}. "
f"ESPHome does not support double, fixed64, or sfixed64 types."
)
# Validate field number fits in uint8_t
if field.number > 255:
raise ValueError(
f"Field '{field.name}' in message '{desc.name}' has field number {field.number} "
f"which exceeds the maximum of 255 supported by FieldMeta."
)
if field.label == 3:
ti = RepeatedTypeInfo(field)
else:
@ -1370,6 +1385,15 @@ def build_message_type(
message_type_id = type_registry.get_message_type_id(
ti.type_name
)
# Validate message type ID fits in 6 bits (0-63)
if message_type_id > 63:
raise ValueError(
f"Message field '{field.name}' in '{desc.name}' references message type "
f"'{ti.type_name}' with type ID {message_type_id}, which exceeds the "
f"maximum of 63 supported by FieldMeta."
)
offset = f"PROTO_FIELD_OFFSET({desc.name}, {ti.field_name})"
# Since we have so few message types, we can use the upper bits of