From 9cb86241b93d73adcadc2dd09a6c056df8109563 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jul 2025 19:40:21 -1000 Subject: [PATCH] cleanup --- esphome/components/api/api_pb2.cpp | 8 ++++ esphome/components/api/api_pb2.h | 4 ++ script/api_protobuf/api_protobuf.py | 68 ++++++++++++++++++++++------- 3 files changed, 65 insertions(+), 15 deletions(-) diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 3b1f8d201f..c3814e8089 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -831,6 +831,8 @@ bool NoiseEncryptionSetKeyRequest::decode_length(uint32_t field_id, ProtoLengthD switch (field_id) { case 1: this->key = value.as_string(); + this->key_ptr_ = reinterpret_cast(this->key.data()); + this->key_len_ = this->key.size(); break; default: return false; @@ -1999,6 +2001,8 @@ bool BluetoothGATTWriteRequest::decode_length(uint32_t field_id, ProtoLengthDeli switch (field_id) { case 4: this->data = value.as_string(); + this->data_ptr_ = reinterpret_cast(this->data.data()); + this->data_len_ = this->data.size(); break; default: return false; @@ -2035,6 +2039,8 @@ bool BluetoothGATTWriteDescriptorRequest::decode_length(uint32_t field_id, Proto switch (field_id) { case 3: this->data = value.as_string(); + this->data_ptr_ = reinterpret_cast(this->data.data()); + this->data_len_ = this->data.size(); break; default: return false; @@ -2257,6 +2263,8 @@ bool VoiceAssistantAudio::decode_length(uint32_t field_id, ProtoLengthDelimited switch (field_id) { case 1: this->data = value.as_string(); + this->data_ptr_ = reinterpret_cast(this->data.data()); + this->data_len_ = this->data.size(); break; default: return false; diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index f116cfe4f9..b357b4cf54 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -993,6 +993,7 @@ class NoiseEncryptionSetKeyRequest : public ProtoDecodableMessage { #endif const uint8_t *key_ptr_{nullptr}; size_t key_len_{0}; + std::string key{}; // Storage for decoded data void set_key(const uint8_t *data, size_t len) { this->key_ptr_ = data; this->key_len_ = len; @@ -1921,6 +1922,7 @@ class BluetoothGATTWriteRequest : public ProtoDecodableMessage { bool response{false}; const uint8_t *data_ptr_{nullptr}; size_t data_len_{0}; + std::string data{}; // Storage for decoded data void set_data(const uint8_t *data, size_t len) { this->data_ptr_ = data; this->data_len_ = len; @@ -1960,6 +1962,7 @@ class BluetoothGATTWriteDescriptorRequest : public ProtoDecodableMessage { uint32_t handle{0}; const uint8_t *data_ptr_{nullptr}; size_t data_len_{0}; + std::string data{}; // Storage for decoded data void set_data(const uint8_t *data, size_t len) { this->data_ptr_ = data; this->data_len_ = len; @@ -2298,6 +2301,7 @@ class VoiceAssistantAudio : public ProtoDecodableMessage { #endif const uint8_t *data_ptr_{nullptr}; size_t data_len_{0}; + std::string data{}; // Storage for decoded data void set_data(const uint8_t *data, size_t len) { this->data_ptr_ = data; this->data_len_ = len; diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 213b6fd332..a51fdc233b 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -313,7 +313,9 @@ def validate_field_type(field_type: int, field_name: str = "") -> None: ) -def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: +def create_field_type_info( + field: descriptor.FieldDescriptorProto, needs_decode: bool = True +) -> TypeInfo: """Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options.""" if field.label == 3: # repeated return RepeatedTypeInfo(field) @@ -325,6 +327,10 @@ def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: ): return FixedArrayBytesType(field, fixed_size) + # Special handling for bytes fields + if field.type == 12: + return BytesType(field, needs_decode) + validate_field_type(field.type, field.name) return TYPE_INFO[field.type](field) @@ -589,22 +595,54 @@ class BytesType(TypeInfo): default_value = "" reference_type = "std::string &" const_reference_type = "const std::string &" - decode_length = "value.as_string()" encode_func = "encode_bytes" wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2 + def __init__( + self, field: descriptor.FieldDescriptorProto, needs_decode: bool = True + ) -> None: + super().__init__(field) + self.needs_decode = needs_decode + @property def public_content(self) -> list[str]: # Store both pointer and length for zero-copy encoding, plus setter method - return [ + content = [ f"const uint8_t* {self.field_name}_ptr_{{nullptr}};", f"size_t {self.field_name}_len_{{0}};", - f"void set_{self.field_name}(const uint8_t* data, size_t len) {{", - f" this->{self.field_name}_ptr_ = data;", - f" this->{self.field_name}_len_ = len;", - "}", ] + # Only add storage if message needs decoding + if self.needs_decode: + content.append( + f"std::string {self.field_name}{{}}; // Storage for decoded data" + ) + + content.extend( + [ + f"void set_{self.field_name}(const uint8_t* data, size_t len) {{", + f" this->{self.field_name}_ptr_ = data;", + f" this->{self.field_name}_len_ = len;", + "}", + ] + ) + + return content + + @property + def decode_length_content(self) -> str: + if not self.needs_decode: + return "" # No decode needed for SOURCE_SERVER messages + + # Decode into storage and update pointer/length + return ( + f"case {self.number}:\n" + f" this->{self.field_name} = value.as_string();\n" + f" this->{self.field_name}_ptr_ = reinterpret_cast(this->{self.field_name}.data());\n" + f" this->{self.field_name}_len_ = this->{self.field_name}.size();\n" + f" break;" + ) + @property def encode_content(self) -> str: return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);" @@ -1268,7 +1306,7 @@ def build_message_type( if field.options.deprecated: continue - ti = create_field_type_info(field) + ti = create_field_type_info(field, needs_decode) # Skip field declarations for fields that are in the base class # but include their encode/decode logic @@ -1583,10 +1621,16 @@ def build_base_class( public_content = [] protected_content = [] + # Determine if any message using this base class needs decoding + needs_decode = any( + message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) + for msg in messages + ) + # For base classes, we only declare the fields but don't handle encode/decode # The derived classes will handle encoding/decoding with their specific field numbers for field in common_fields: - ti = create_field_type_info(field) + ti = create_field_type_info(field, needs_decode) # Get field_ifdef if it's consistent across all messages field_ifdef = get_common_field_ifdef(field.name, messages) @@ -1597,12 +1641,6 @@ def build_base_class( if ti.public_content: public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef)) - # Determine if any message using this base class needs decoding - needs_decode = any( - message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) - for msg in messages - ) - # Build header parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage" out = f"class {base_class_name} : public {parent_class} {{\n"