diff --git a/esphome/components/api/proto.cpp b/esphome/components/api/proto.cpp index 7a11745c74..f51d872e09 100644 --- a/esphome/components/api/proto.cpp +++ b/esphome/components/api/proto.cpp @@ -3,13 +3,38 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "api_pb2_size.h" +#include "proto_templates.h" namespace esphome { namespace api { static const char *const TAG = "api.proto"; +// Message handler registry - populated by generated code +const MessageHandler MESSAGE_HANDLERS[] = { + // Will be populated with entries like: + // {encode_message_field, size_message_field, decode_message_field}, + // etc. +}; +const size_t MESSAGE_HANDLER_COUNT = 0; // Will be set by generated code + +const RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[] = { + // Will be populated with entries like: + // {encode_repeated_message_field, size_repeated_message_field, + // decode_repeated_message_field}, + // etc. +}; +const size_t REPEATED_MESSAGE_HANDLER_COUNT = 0; // Will be set by generated code + void ProtoMessage::decode(const uint8_t *buffer, size_t length) { + // Check if V3 metadata is available + const FieldMetaV3 *fields_v3 = get_field_metadata_v3(); + if (fields_v3 != nullptr) { + decode_v3(buffer, length); + return; + } + + // Fall back to V2 uint32_t i = 0; bool error = false; uint8_t *base = reinterpret_cast(this); @@ -898,6 +923,13 @@ bool decode_repeated_double_field(void *field_ptr, Proto64Bit value) { // ProtoMessage implementations using metadata void ProtoMessage::encode(ProtoWriteBuffer buffer) const { + // Check if V3 metadata is available + const FieldMetaV3 *fields_v3 = get_field_metadata_v3(); + if (fields_v3 != nullptr) { + encode_v3(buffer); + return; + } + const uint8_t *base = reinterpret_cast(this); // Get V2 metadata once at the start @@ -1140,6 +1172,13 @@ void ProtoMessage::encode(ProtoWriteBuffer buffer) const { } void ProtoMessage::calculate_size(uint32_t &total_size) const { + // Check if V3 metadata is available + const FieldMetaV3 *fields_v3 = get_field_metadata_v3(); + if (fields_v3 != nullptr) { + calculate_size_v3(total_size); + return; + } + const uint8_t *base = reinterpret_cast(this); // Get V2 metadata once at the start @@ -1381,6 +1420,885 @@ void ProtoMessage::calculate_size(uint32_t &total_size) const { } } +// V3 implementations +void ProtoMessage::decode_v3(const uint8_t *buffer, size_t length) { + uint32_t i = 0; + bool error = false; + uint8_t *base = reinterpret_cast(this); + + // Get V3 metadata + const FieldMetaV3 *fields = get_field_metadata_v3(); + size_t field_count = get_field_count_v3(); + const RepeatedFieldMetaV3 *repeated_fields = get_repeated_field_metadata_v3(); + size_t repeated_count = get_repeated_field_count_v3(); + + while (i < length) { + uint32_t consumed; + auto res = ProtoVarInt::parse(&buffer[i], length - i, &consumed); + if (!res.has_value()) { + ESP_LOGV(TAG, "Invalid field start at %" PRIu32, i); + break; + } + + uint32_t field_type = (res->as_uint32()) & 0b111; + uint32_t field_id = (res->as_uint32()) >> 3; + i += consumed; + + switch (field_type) { + case 0: { // VarInt + res = ProtoVarInt::parse(&buffer[i], length - i, &consumed); + if (!res.has_value()) { + ESP_LOGV(TAG, "Invalid VarInt at %" PRIu32, i); + error = true; + break; + } + ProtoVarInt value = *res; + bool decoded = false; + + // Check regular fields + for (size_t j = 0; j < field_count; j++) { + if (fields[j].field_num == field_id && get_wire_type(fields[j].get_type()) == 0) { + void *field_addr = base + fields[j].offset; + + switch (fields[j].get_type()) { + case ProtoFieldType::TYPE_BOOL: + *static_cast(field_addr) = value.as_bool(); + decoded = true; + break; + case ProtoFieldType::TYPE_INT32: + *static_cast(field_addr) = value.as_int32(); + decoded = true; + break; + case ProtoFieldType::TYPE_UINT32: + *static_cast(field_addr) = value.as_uint32(); + decoded = true; + break; + case ProtoFieldType::TYPE_INT64: + *static_cast(field_addr) = value.as_int64(); + decoded = true; + break; + case ProtoFieldType::TYPE_UINT64: + *static_cast(field_addr) = value.as_uint64(); + decoded = true; + break; + case ProtoFieldType::TYPE_SINT32: + *static_cast(field_addr) = value.as_sint32(); + decoded = true; + break; + case ProtoFieldType::TYPE_SINT64: + *static_cast(field_addr) = value.as_sint64(); + decoded = true; + break; + case ProtoFieldType::TYPE_ENUM: + *static_cast(field_addr) = value.as_uint32(); + decoded = true; + break; + default: + break; + } + break; + } + } + + // Check repeated fields if not found + if (!decoded && repeated_fields) { + for (size_t j = 0; j < repeated_count; j++) { + if (repeated_fields[j].field_num == field_id && get_wire_type(repeated_fields[j].get_type()) == 0) { + void *field_addr = base + repeated_fields[j].offset; + + switch (repeated_fields[j].get_type()) { + case ProtoFieldType::TYPE_BOOL: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_bool()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_INT32: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_int32()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_UINT32: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_uint32()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_INT64: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_int64()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_UINT64: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_uint64()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_SINT32: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_sint32()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_SINT64: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_sint64()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_ENUM: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_uint32()); + decoded = true; + break; + } + default: + break; + } + break; + } + } + } + + if (!decoded) { + ESP_LOGV(TAG, "Skipping VarInt field %" PRIu32 " at %" PRIu32, field_id, i); + } + i += consumed; + break; + } + + case 2: { // Length-delimited + res = ProtoVarInt::parse(&buffer[i], length - i, &consumed); + if (!res.has_value()) { + ESP_LOGV(TAG, "Invalid length delimited size at %" PRIu32, i); + error = true; + break; + } + uint32_t field_length = res->as_uint32(); + i += consumed; + + if (i + field_length > length) { + ESP_LOGV(TAG, "Length delimited field %" PRIu32 " exceeds buffer", field_id); + error = true; + break; + } + + ProtoLengthDelimited value(&buffer[i], field_length); + bool decoded = false; + + // Check regular fields + for (size_t j = 0; j < field_count; j++) { + if (fields[j].field_num == field_id && get_wire_type(fields[j].get_type()) == 2) { + void *field_addr = base + fields[j].offset; + + switch (fields[j].get_type()) { + case ProtoFieldType::TYPE_STRING: { + auto *str = static_cast(field_addr); + *str = value.as_string(); + decoded = true; + break; + } + case ProtoFieldType::TYPE_BYTES: { + auto *str = static_cast(field_addr); + *str = value.as_string(); + decoded = true; + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use message handler registry + if (fields[j].get_message_type_id() < MESSAGE_HANDLER_COUNT) { + decoded = MESSAGE_HANDLERS[fields[j].get_message_type_id()].decode(field_addr, value); + } + break; + } + default: + break; + } + break; + } + } + + // Check repeated fields if not found + if (!decoded && repeated_fields) { + for (size_t j = 0; j < repeated_count; j++) { + if (repeated_fields[j].field_num == field_id && get_wire_type(repeated_fields[j].get_type()) == 2) { + void *field_addr = base + repeated_fields[j].offset; + + switch (repeated_fields[j].get_type()) { + case ProtoFieldType::TYPE_STRING: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_string()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use repeated message handler registry + if (repeated_fields[j].get_message_type_id() < REPEATED_MESSAGE_HANDLER_COUNT) { + decoded = + REPEATED_MESSAGE_HANDLERS[repeated_fields[j].get_message_type_id()].decode(field_addr, value); + } + break; + } + default: + break; + } + break; + } + } + } + + if (!decoded) { + ESP_LOGV(TAG, "Skipping length delimited field %" PRIu32 " at %" PRIu32, field_id, i); + } + i += field_length; + break; + } + + case 5: { // 32-bit + if (i + 4 > length) { + ESP_LOGV(TAG, "32-bit field %" PRIu32 " exceeds buffer", field_id); + error = true; + break; + } + + uint32_t raw = 0; + raw |= uint32_t(buffer[i]) << 0; + raw |= uint32_t(buffer[i + 1]) << 8; + raw |= uint32_t(buffer[i + 2]) << 16; + raw |= uint32_t(buffer[i + 3]) << 24; + Proto32Bit value(raw); + bool decoded = false; + + // Check regular fields + for (size_t j = 0; j < field_count; j++) { + if (fields[j].field_num == field_id && get_wire_type(fields[j].get_type()) == 5) { + void *field_addr = base + fields[j].offset; + + switch (fields[j].get_type()) { + case ProtoFieldType::TYPE_FLOAT: + *static_cast(field_addr) = value.as_float(); + decoded = true; + break; + case ProtoFieldType::TYPE_FIXED32: + *static_cast(field_addr) = value.as_fixed32(); + decoded = true; + break; + case ProtoFieldType::TYPE_SFIXED32: + *static_cast(field_addr) = value.as_sfixed32(); + decoded = true; + break; + default: + break; + } + break; + } + } + + // Check repeated fields if not found + if (!decoded && repeated_fields) { + for (size_t j = 0; j < repeated_count; j++) { + if (repeated_fields[j].field_num == field_id && get_wire_type(repeated_fields[j].get_type()) == 5) { + void *field_addr = base + repeated_fields[j].offset; + + switch (repeated_fields[j].get_type()) { + case ProtoFieldType::TYPE_FLOAT: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_float()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_FIXED32: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_fixed32()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_SFIXED32: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_sfixed32()); + decoded = true; + break; + } + default: + break; + } + break; + } + } + } + + if (!decoded) { + ESP_LOGV(TAG, "Skipping 32-bit field %" PRIu32 " at %" PRIu32, field_id, i); + } + i += 4; + break; + } + + case 1: { // 64-bit + if (i + 8 > length) { + ESP_LOGV(TAG, "64-bit field %" PRIu32 " exceeds buffer", field_id); + error = true; + break; + } + + uint64_t raw = 0; + raw |= uint64_t(buffer[i]) << 0; + raw |= uint64_t(buffer[i + 1]) << 8; + raw |= uint64_t(buffer[i + 2]) << 16; + raw |= uint64_t(buffer[i + 3]) << 24; + raw |= uint64_t(buffer[i + 4]) << 32; + raw |= uint64_t(buffer[i + 5]) << 40; + raw |= uint64_t(buffer[i + 6]) << 48; + raw |= uint64_t(buffer[i + 7]) << 56; + Proto64Bit value(raw); + bool decoded = false; + + // Check regular fields + for (size_t j = 0; j < field_count; j++) { + if (fields[j].field_num == field_id && get_wire_type(fields[j].get_type()) == 1) { + void *field_addr = base + fields[j].offset; + + switch (fields[j].get_type()) { + case ProtoFieldType::TYPE_DOUBLE: + *static_cast(field_addr) = value.as_double(); + decoded = true; + break; + case ProtoFieldType::TYPE_FIXED64: + *static_cast(field_addr) = value.as_fixed64(); + decoded = true; + break; + case ProtoFieldType::TYPE_SFIXED64: + *static_cast(field_addr) = value.as_sfixed64(); + decoded = true; + break; + default: + break; + } + break; + } + } + + // Check repeated fields if not found + if (!decoded && repeated_fields) { + for (size_t j = 0; j < repeated_count; j++) { + if (repeated_fields[j].field_num == field_id && get_wire_type(repeated_fields[j].get_type()) == 1) { + void *field_addr = base + repeated_fields[j].offset; + + switch (repeated_fields[j].get_type()) { + case ProtoFieldType::TYPE_DOUBLE: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_double()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_FIXED64: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_fixed64()); + decoded = true; + break; + } + case ProtoFieldType::TYPE_SFIXED64: { + auto *vec = static_cast *>(field_addr); + vec->push_back(value.as_sfixed64()); + decoded = true; + break; + } + default: + break; + } + break; + } + } + } + + if (!decoded) { + ESP_LOGV(TAG, "Skipping 64-bit field %" PRIu32 " at %" PRIu32, field_id, i); + } + i += 8; + break; + } + + default: + ESP_LOGV(TAG, "Unknown field type %" PRIu32 " at %" PRIu32, field_type, i); + return; + } + + if (error) { + break; + } + } +} + +void ProtoMessage::encode_v3(ProtoWriteBuffer buffer) const { + const uint8_t *base = reinterpret_cast(this); + + // Get V3 metadata + const FieldMetaV3 *fields = get_field_metadata_v3(); + size_t field_count = get_field_count_v3(); + + // Regular fields + for (size_t i = 0; i < field_count; i++) { + const void *field_addr = base + fields[i].offset; + + switch (fields[i].get_type()) { + case ProtoFieldType::TYPE_BOOL: { + const auto *val = static_cast(field_addr); + buffer.encode_bool(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_INT32: { + const auto *val = static_cast(field_addr); + buffer.encode_int32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_UINT32: { + const auto *val = static_cast(field_addr); + buffer.encode_uint32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_INT64: { + const auto *val = static_cast(field_addr); + buffer.encode_int64(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_UINT64: { + const auto *val = static_cast(field_addr); + buffer.encode_uint64(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_SINT32: { + const auto *val = static_cast(field_addr); + buffer.encode_sint32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_SINT64: { + const auto *val = static_cast(field_addr); + buffer.encode_sint64(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_ENUM: { + const auto *val = static_cast(field_addr); + buffer.encode_uint32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_STRING: { + const auto *val = static_cast(field_addr); + buffer.encode_string(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_BYTES: { + const auto *str = static_cast(field_addr); + buffer.encode_bytes(fields[i].field_num, reinterpret_cast(str->data()), str->size(), false); + break; + } + case ProtoFieldType::TYPE_FLOAT: { + const auto *val = static_cast(field_addr); + buffer.encode_float(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_FIXED32: { + const auto *val = static_cast(field_addr); + buffer.encode_fixed32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_SFIXED32: { + const auto *val = static_cast(field_addr); + buffer.encode_sfixed32(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_DOUBLE: { + const auto *val = static_cast(field_addr); + buffer.encode_double(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_FIXED64: { + const auto *val = static_cast(field_addr); + buffer.encode_fixed64(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_SFIXED64: { + const auto *val = static_cast(field_addr); + buffer.encode_sfixed64(fields[i].field_num, *val, false); + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use message handler registry + if (fields[i].get_message_type_id() < MESSAGE_HANDLER_COUNT) { + MESSAGE_HANDLERS[fields[i].get_message_type_id()].encode(buffer, field_addr, fields[i].field_num); + } + break; + } + } + } + + // Repeated fields + const RepeatedFieldMetaV3 *repeated_fields = get_repeated_field_metadata_v3(); + size_t repeated_count = get_repeated_field_count_v3(); + + for (size_t i = 0; i < repeated_count; i++) { + const void *field_addr = base + repeated_fields[i].offset; + + switch (repeated_fields[i].get_type()) { + case ProtoFieldType::TYPE_BOOL: { + const auto *vec = static_cast *>(field_addr); + for (bool val : *vec) { + buffer.encode_bool(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_INT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_int32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_UINT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_uint32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_INT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_int64(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_UINT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_uint64(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_SINT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_sint32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_SINT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_sint64(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_ENUM: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_uint32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_STRING: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_string(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_BYTES: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_bytes(repeated_fields[i].field_num, reinterpret_cast(val.data()), val.size(), + true); + } + break; + } + case ProtoFieldType::TYPE_FLOAT: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_float(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_FIXED32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_fixed32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_SFIXED32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_sfixed32(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_DOUBLE: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_double(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_FIXED64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_fixed64(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_SFIXED64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + buffer.encode_sfixed64(repeated_fields[i].field_num, val, true); + } + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use repeated message handler registry + if (repeated_fields[i].get_message_type_id() < REPEATED_MESSAGE_HANDLER_COUNT) { + REPEATED_MESSAGE_HANDLERS[repeated_fields[i].get_message_type_id()].encode(buffer, field_addr, + repeated_fields[i].field_num); + } + break; + } + } + } +} + +void ProtoMessage::calculate_size_v3(uint32_t &total_size) const { + const uint8_t *base = reinterpret_cast(this); + + // Get V3 metadata + const FieldMetaV3 *fields = get_field_metadata_v3(); + size_t field_count = get_field_count_v3(); + + // Regular fields + for (size_t i = 0; i < field_count; i++) { + const void *field_addr = base + fields[i].offset; + + switch (fields[i].get_type()) { + case ProtoFieldType::TYPE_BOOL: { + const auto *val = static_cast(field_addr); + ProtoSize::add_bool_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_INT32: { + const auto *val = static_cast(field_addr); + ProtoSize::add_int32_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_UINT32: { + const auto *val = static_cast(field_addr); + ProtoSize::add_uint32_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_INT64: { + const auto *val = static_cast(field_addr); + ProtoSize::add_int64_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_UINT64: { + const auto *val = static_cast(field_addr); + ProtoSize::add_uint64_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_SINT32: { + const auto *val = static_cast(field_addr); + ProtoSize::add_sint32_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_SINT64: { + const auto *val = static_cast(field_addr); + ProtoSize::add_sint64_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_ENUM: { + const auto *val = static_cast(field_addr); + ProtoSize::add_enum_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_STRING: { + const auto *val = static_cast(field_addr); + ProtoSize::add_string_field(total_size, fields[i].get_precalced_size(), *val, false); + break; + } + case ProtoFieldType::TYPE_BYTES: { + const auto *str = static_cast(field_addr); + ProtoSize::add_string_field(total_size, fields[i].get_precalced_size(), *str, false); + break; + } + case ProtoFieldType::TYPE_FLOAT: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<4>(total_size, fields[i].get_precalced_size(), *val != 0.0f, false); + break; + } + case ProtoFieldType::TYPE_FIXED32: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<4>(total_size, fields[i].get_precalced_size(), *val != 0, false); + break; + } + case ProtoFieldType::TYPE_SFIXED32: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<4>(total_size, fields[i].get_precalced_size(), *val != 0, false); + break; + } + case ProtoFieldType::TYPE_DOUBLE: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<8>(total_size, fields[i].get_precalced_size(), *val != 0.0, false); + break; + } + case ProtoFieldType::TYPE_FIXED64: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<8>(total_size, fields[i].get_precalced_size(), *val != 0, false); + break; + } + case ProtoFieldType::TYPE_SFIXED64: { + const auto *val = static_cast(field_addr); + ProtoSize::add_fixed_field<8>(total_size, fields[i].get_precalced_size(), *val != 0, false); + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use message handler registry + if (fields[i].get_message_type_id() < MESSAGE_HANDLER_COUNT) { + MESSAGE_HANDLERS[fields[i].get_message_type_id()].size(total_size, field_addr, fields[i].get_precalced_size(), + false); + } + break; + } + } + } + + // Repeated fields + const RepeatedFieldMetaV3 *repeated_fields = get_repeated_field_metadata_v3(); + size_t repeated_count = get_repeated_field_count_v3(); + + for (size_t i = 0; i < repeated_count; i++) { + const void *field_addr = base + repeated_fields[i].offset; + + switch (repeated_fields[i].get_type()) { + case ProtoFieldType::TYPE_BOOL: { + const auto *vec = static_cast *>(field_addr); + for (bool val : *vec) { + ProtoSize::add_bool_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_INT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_int32_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_UINT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_uint32_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_INT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_int64_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_UINT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_uint64_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_SINT32: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_sint32_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_SINT64: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_sint64_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_ENUM: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_enum_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_STRING: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_string_field(total_size, repeated_fields[i].get_precalced_size(), val, true); + } + break; + } + case ProtoFieldType::TYPE_FLOAT: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_fixed_field<4>(total_size, repeated_fields[i].get_precalced_size(), val != 0.0f, true); + } + break; + } + case ProtoFieldType::TYPE_FIXED32: { + const auto *vec = static_cast *>(field_addr); + size_t count = vec->size(); + if (count > 0) { + total_size += count * (repeated_fields[i].get_precalced_size() + 4); + } + break; + } + case ProtoFieldType::TYPE_SFIXED32: { + const auto *vec = static_cast *>(field_addr); + size_t count = vec->size(); + if (count > 0) { + total_size += count * (repeated_fields[i].get_precalced_size() + 4); + } + break; + } + case ProtoFieldType::TYPE_DOUBLE: { + const auto *vec = static_cast *>(field_addr); + for (const auto &val : *vec) { + ProtoSize::add_fixed_field<8>(total_size, repeated_fields[i].get_precalced_size(), val != 0.0, true); + } + break; + } + case ProtoFieldType::TYPE_FIXED64: { + const auto *vec = static_cast *>(field_addr); + size_t count = vec->size(); + if (count > 0) { + total_size += count * (repeated_fields[i].get_precalced_size() + 8); + } + break; + } + case ProtoFieldType::TYPE_SFIXED64: { + const auto *vec = static_cast *>(field_addr); + size_t count = vec->size(); + if (count > 0) { + total_size += count * (repeated_fields[i].get_precalced_size() + 8); + } + break; + } + case ProtoFieldType::TYPE_MESSAGE: { + // Use repeated message handler registry + if (repeated_fields[i].get_message_type_id() < REPEATED_MESSAGE_HANDLER_COUNT) { + REPEATED_MESSAGE_HANDLERS[repeated_fields[i].get_message_type_id()].size( + total_size, field_addr, repeated_fields[i].get_precalced_size()); + } + break; + } + } + } +} + // Message type handler implementations moved to api_pb2.cpp (generated by Python script) } // namespace api diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index a3bd812f38..ee832bc5c3 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -20,28 +20,28 @@ class ProtoWriteBuffer; enum class ProtoFieldType : uint8_t { // Varint types (wire type 0) TYPE_BOOL = 0, - TYPE_INT32, - TYPE_UINT32, - TYPE_INT64, - TYPE_UINT64, - TYPE_SINT32, - TYPE_SINT64, - TYPE_ENUM, + TYPE_INT32 = 1, + TYPE_UINT32 = 2, + TYPE_INT64 = 3, + TYPE_UINT64 = 4, + TYPE_SINT32 = 5, + TYPE_SINT64 = 6, + TYPE_ENUM = 7, // Length-delimited types (wire type 2) - TYPE_STRING, - TYPE_BYTES, - TYPE_MESSAGE, + TYPE_STRING = 8, + TYPE_BYTES = 9, + TYPE_MESSAGE = 10, // 32-bit types (wire type 5) - TYPE_FLOAT, - TYPE_FIXED32, - TYPE_SFIXED32, + TYPE_FLOAT = 11, + TYPE_FIXED32 = 12, + TYPE_SFIXED32 = 13, // 64-bit types (wire type 1) - TYPE_DOUBLE, - TYPE_FIXED64, - TYPE_SFIXED64, + TYPE_DOUBLE = 14, + TYPE_FIXED64 = 15, + TYPE_SFIXED64 = 16, }; // Helper to get wire type from field type @@ -261,7 +261,51 @@ using RepeatedEncodeFunc = void (*)(ProtoWriteBuffer &, const void *field_ptr, u using RepeatedSizeFunc = void (*)(uint32_t &total_size, const void *field_ptr, uint8_t precalced_field_id_size); using RepeatedDecodeLengthFunc = bool (*)(void *field_ptr, ProtoLengthDelimited value); -// New type-based metadata structure (smaller and more efficient) +// Message handler registry entry +struct MessageHandler { + EncodeFunc encode; + SizeFunc size; + DecodeLengthFunc decode; +}; + +// Repeated message handler registry entry +struct RepeatedMessageHandler { + RepeatedEncodeFunc encode; + RepeatedSizeFunc size; + RepeatedDecodeLengthFunc decode; +}; + +// Global message handler registries (defined in proto.cpp) +extern const MessageHandler MESSAGE_HANDLERS[]; +extern const size_t MESSAGE_HANDLER_COUNT; +extern const RepeatedMessageHandler REPEATED_MESSAGE_HANDLERS[]; +extern const size_t REPEATED_MESSAGE_HANDLER_COUNT; + +// Optimized metadata structure (4 bytes - no padding on 32-bit architectures) +struct FieldMetaV3 { + uint8_t field_num; // Protobuf field number (1-255) + uint8_t type_and_size; // bits 0-4: ProtoFieldType, bits 5-6: precalced_field_id_size-1, bit 7: reserved + union { + uint16_t offset; // For non-message types: offset in class (0-65535) + struct { + uint8_t offset_low; // For TYPE_MESSAGE: low byte of offset + uint8_t message_type_id; // For TYPE_MESSAGE: index into MESSAGE_HANDLERS + }; + }; + + // Helper methods + ProtoFieldType get_type() const { return static_cast(type_and_size & 0x1F); } + uint8_t get_precalced_size() const { return ((type_and_size >> 5) & 0x03) + 1; } + uint16_t get_offset() const { + if (get_type() == ProtoFieldType::TYPE_MESSAGE) { + return offset_low; // Limited to 255 for messages + } + return offset; + } + uint8_t get_message_type_id() const { return message_type_id; } +}; + +// Keep V2 for now during transition struct FieldMetaV2 { uint8_t field_num; // Protobuf field number (1-255) uint16_t offset; // offset of field in class @@ -440,7 +484,31 @@ class ProtoWriteBuffer { std::vector *buffer_; }; -// New type-based repeated field metadata +// Optimized repeated field metadata (4 bytes - no padding on 32-bit architectures) +struct RepeatedFieldMetaV3 { + uint8_t field_num; // Protobuf field number (1-255) + uint8_t type_and_size; // bits 0-4: ProtoFieldType, bits 5-6: precalced_field_id_size-1, bit 7: reserved + union { + uint16_t offset; // For non-message types: offset in class (0-65535) + struct { + uint8_t offset_low; // For TYPE_MESSAGE: low byte of offset + uint8_t message_type_id; // For TYPE_MESSAGE: index into REPEATED_MESSAGE_HANDLERS + }; + }; + + // Helper methods + ProtoFieldType get_type() const { return static_cast(type_and_size & 0x1F); } + uint8_t get_precalced_size() const { return ((type_and_size >> 5) & 0x03) + 1; } + uint16_t get_offset() const { + if (get_type() == ProtoFieldType::TYPE_MESSAGE) { + return offset_low; // Limited to 255 for messages + } + return offset; + } + uint8_t get_message_type_id() const { return message_type_id; } +}; + +// Keep V2 for now during transition struct RepeatedFieldMetaV2 { uint8_t field_num; uint16_t offset; @@ -468,11 +536,23 @@ class ProtoMessage { virtual const RepeatedFieldMetaV2 *get_repeated_field_metadata_v2() const { return nullptr; } virtual size_t get_repeated_field_count_v2() const { return 0; } - // Encode/decode/calculate_size using V2 metadata + // V3 metadata getters - for optimized implementation + virtual const FieldMetaV3 *get_field_metadata_v3() const { return nullptr; } + virtual size_t get_field_count_v3() const { return 0; } + virtual const RepeatedFieldMetaV3 *get_repeated_field_metadata_v3() const { return nullptr; } + virtual size_t get_repeated_field_count_v3() const { return 0; } + + // Encode/decode/calculate_size using V2 metadata (will check for V3 first) void encode(ProtoWriteBuffer buffer) const; void decode(const uint8_t *buffer, size_t length); void calculate_size(uint32_t &total_size) const; + protected: + // V3 implementations + void encode_v3(ProtoWriteBuffer buffer) const; + void decode_v3(const uint8_t *buffer, size_t length); + void calculate_size_v3(uint32_t &total_size) const; + #ifdef HAS_PROTO_MESSAGE_DUMP std::string dump() const; virtual void dump_to(std::string &out) const = 0;