From 4d54cb9b316bf35b32d26332a2887ceae26fa30a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 May 2025 17:05:20 -0400 Subject: [PATCH] Refactor API frame helpers to enable buffer reuse (#8825) --- esphome/components/api/api_connection.cpp | 2 +- esphome/components/api/api_connection.h | 9 +- esphome/components/api/api_frame_helper.cpp | 128 ++++++++++++++------ esphome/components/api/api_frame_helper.h | 41 ++++++- esphome/components/api/proto.h | 28 +++++ 5 files changed, 162 insertions(+), 46 deletions(-) diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 847d7840dc..d71e5587a3 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1962,7 +1962,7 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) } } - APIError err = this->helper_->write_packet(message_type, buffer.get_buffer()->data(), buffer.get_buffer()->size()); + APIError err = this->helper_->write_protobuf_packet(message_type, buffer); if (err == APIError::WOULD_BLOCK) return false; if (err != APIError::OK) { diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 1e47418d90..b40e9602be 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -315,7 +315,14 @@ class APIConnection : public APIServerConnection { ProtoWriteBuffer create_buffer(uint32_t reserve_size) override { // FIXME: ensure no recursive writes can happen this->proto_write_buffer_.clear(); - this->proto_write_buffer_.reserve(reserve_size); + // Get header padding size - used for both reserve and insert + uint8_t header_padding = this->helper_->frame_header_padding(); + // Reserve space for header padding + message + footer + // - Header padding: space for protocol headers (7 bytes for Noise, 6 for Plaintext) + // - Footer: space for MAC (16 bytes for Noise, 0 for Plaintext) + this->proto_write_buffer_.reserve(reserve_size + header_padding + this->helper_->frame_footer_size()); + // Insert header padding bytes so message encoding starts at the correct position + this->proto_write_buffer_.insert(this->proto_write_buffer_.begin(), header_padding, 0); return {&this->proto_write_buffer_}; } bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override; diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index f251ceb6e4..f18f4104b6 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -493,9 +493,12 @@ void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &rea std::vector data; data.resize(reason.length() + 1); data[0] = 0x01; // failure - for (size_t i = 0; i < reason.length(); i++) { - data[i + 1] = (uint8_t) reason[i]; + + // Copy error message in bulk + if (!reason.empty()) { + std::memcpy(data.data() + 1, reason.c_str(), reason.length()); } + // temporarily remove failed state auto orig_state = state_; state_ = State::EXPLICIT_REJECT; @@ -557,7 +560,7 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { return APIError::OK; } bool APINoiseFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } -APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { +APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { int err; APIError aerr; aerr = state_action_(); @@ -569,31 +572,36 @@ APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload return APIError::WOULD_BLOCK; } + std::vector *raw_buffer = buffer.get_buffer(); + // Message data starts after padding + size_t payload_len = raw_buffer->size() - frame_header_padding_; size_t padding = 0; size_t msg_len = 4 + payload_len + padding; - size_t frame_len = 3 + msg_len + noise_cipherstate_get_mac_length(send_cipher_); - auto tmpbuf = std::unique_ptr{new (std::nothrow) uint8_t[frame_len]}; - if (tmpbuf == nullptr) { - HELPER_LOG("Could not allocate for writing packet"); - return APIError::OUT_OF_MEMORY; - } - tmpbuf[0] = 0x01; // indicator - // tmpbuf[1], tmpbuf[2] to be set later + // We need to resize to include MAC space, but we already reserved it in create_buffer + raw_buffer->resize(raw_buffer->size() + frame_footer_size_); + + // Write the noise header in the padded area + // Buffer layout: + // [0] - 0x01 indicator byte + // [1-2] - Size of encrypted payload (filled after encryption) + // [3-4] - Message type (encrypted) + // [5-6] - Payload length (encrypted) + // [7...] - Actual payload data (encrypted) + uint8_t *buf_start = raw_buffer->data(); + buf_start[0] = 0x01; // indicator + // buf_start[1], buf_start[2] to be set later after encryption const uint8_t msg_offset = 3; - const uint8_t payload_offset = msg_offset + 4; - tmpbuf[msg_offset + 0] = (uint8_t) (type >> 8); // type - tmpbuf[msg_offset + 1] = (uint8_t) type; - tmpbuf[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len - tmpbuf[msg_offset + 3] = (uint8_t) payload_len; - // copy data - std::copy(payload, payload + payload_len, &tmpbuf[payload_offset]); - // fill padding with zeros - std::fill(&tmpbuf[payload_offset + payload_len], &tmpbuf[frame_len], 0); + buf_start[msg_offset + 0] = (uint8_t) (type >> 8); // type high byte + buf_start[msg_offset + 1] = (uint8_t) type; // type low byte + buf_start[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len high byte + buf_start[msg_offset + 3] = (uint8_t) payload_len; // data_len low byte + // payload data is already in the buffer starting at position 7 NoiseBuffer mbuf; noise_buffer_init(mbuf); - noise_buffer_set_inout(mbuf, &tmpbuf[msg_offset], msg_len, frame_len - msg_offset); + // The capacity parameter should be msg_len + frame_footer_size_ (MAC length) to allow space for encryption + noise_buffer_set_inout(mbuf, buf_start + msg_offset, msg_len, msg_len + frame_footer_size_); err = noise_cipherstate_encrypt(send_cipher_, &mbuf); if (err != 0) { state_ = State::FAILED; @@ -602,11 +610,13 @@ APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload } size_t total_len = 3 + mbuf.size; - tmpbuf[1] = (uint8_t) (mbuf.size >> 8); - tmpbuf[2] = (uint8_t) mbuf.size; + buf_start[1] = (uint8_t) (mbuf.size >> 8); + buf_start[2] = (uint8_t) mbuf.size; struct iovec iov; - iov.iov_base = &tmpbuf[0]; + // Point iov_base to the beginning of the buffer (no unused padding in Noise) + // We send the entire frame: indicator + size + encrypted(type + data_len + payload + MAC) + iov.iov_base = buf_start; iov.iov_len = total_len; // write raw to not have two packets sent if NAGLE disabled @@ -718,6 +728,8 @@ APIError APINoiseFrameHelper::check_handshake_finished_() { return APIError::HANDSHAKESTATE_SPLIT_FAILED; } + frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_); + HELPER_LOG("Handshake complete!"); noise_handshakestate_free(handshake_); handshake_ = nullptr; @@ -990,28 +1002,66 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { return APIError::OK; } bool APIPlaintextFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } -APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { +APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { if (state_ != State::DATA) { return APIError::BAD_STATE; } - std::vector header; - header.reserve(1 + api::ProtoSize::varint(static_cast(payload_len)) + - api::ProtoSize::varint(static_cast(type))); - header.push_back(0x00); - ProtoVarInt(payload_len).encode(header); - ProtoVarInt(type).encode(header); + std::vector *raw_buffer = buffer.get_buffer(); + // Message data starts after padding (frame_header_padding_ = 6) + size_t payload_len = raw_buffer->size() - frame_header_padding_; - struct iovec iov[2]; - iov[0].iov_base = &header[0]; - iov[0].iov_len = header.size(); - if (payload_len == 0) { - return write_raw_(iov, 1); + // Calculate varint sizes for header components + size_t size_varint_len = api::ProtoSize::varint(static_cast(payload_len)); + size_t type_varint_len = api::ProtoSize::varint(static_cast(type)); + size_t total_header_len = 1 + size_varint_len + type_varint_len; + + if (total_header_len > frame_header_padding_) { + // Header is too large to fit in the padding + return APIError::BAD_ARG; } - iov[1].iov_base = const_cast(payload); - iov[1].iov_len = payload_len; - return write_raw_(iov, 2); + // Calculate where to start writing the header + // The header starts at the latest possible position to minimize unused padding + // + // Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3 + // [0-2] - Unused padding + // [3] - 0x00 indicator byte + // [4] - Payload size varint (1 byte, for sizes 0-127) + // [5] - Message type varint (1 byte, for types 0-127) + // [6...] - Actual payload data + // + // Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2 + // [0-1] - Unused padding + // [2] - 0x00 indicator byte + // [3-4] - Payload size varint (2 bytes, for sizes 128-16383) + // [5] - Message type varint (1 byte, for types 0-127) + // [6...] - Actual payload data + // + // Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0 + // [0] - 0x00 indicator byte + // [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151) + // [4-5] - Message type varint (2 bytes, for types 128-32767) + // [6...] - Actual payload data + uint8_t *buf_start = raw_buffer->data(); + size_t header_offset = frame_header_padding_ - total_header_len; + + // Write the plaintext header + buf_start[header_offset] = 0x00; // indicator + + // Encode size varint directly into buffer + ProtoVarInt(payload_len).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len); + + // Encode type varint directly into buffer + ProtoVarInt(type).encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len); + + struct iovec iov; + // Point iov_base to the beginning of our header (skip unused padding) + // This ensures we only send the actual header and payload, not the empty padding bytes + iov.iov_base = buf_start + header_offset; + iov.iov_len = total_header_len + payload_len; + + return write_raw_(&iov, 1); } APIError APIPlaintextFrameHelper::try_send_tx_buf_() { // try send from tx_buf diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index db506ea1ce..25bfd594ec 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -16,6 +16,8 @@ namespace esphome { namespace api { +class ProtoWriteBuffer; + struct ReadPacketBuffer { std::vector container; uint16_t type; @@ -65,32 +67,46 @@ class APIFrameHelper { virtual APIError loop() = 0; virtual APIError read_packet(ReadPacketBuffer *buffer) = 0; virtual bool can_write_without_blocking() = 0; - virtual APIError write_packet(uint16_t type, const uint8_t *data, size_t len) = 0; + virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0; virtual std::string getpeername() = 0; virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0; virtual APIError close() = 0; virtual APIError shutdown(int how) = 0; // Give this helper a name for logging virtual void set_log_info(std::string info) = 0; + // Get the frame header padding required by this protocol + virtual uint8_t frame_header_padding() = 0; + // Get the frame footer size required by this protocol + virtual uint8_t frame_footer_size() = 0; protected: // Common implementation for writing raw data to socket template APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf, const std::string &info, StateEnum &state, StateEnum failed_state); + + uint8_t frame_header_padding_{0}; + uint8_t frame_footer_size_{0}; }; #ifdef USE_API_NOISE class APINoiseFrameHelper : public APIFrameHelper { public: APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx) - : socket_(std::move(socket)), ctx_(std::move(std::move(ctx))) {} + : socket_(std::move(socket)), ctx_(std::move(ctx)) { + // Noise header structure: + // Pos 0: indicator (0x01) + // Pos 1-2: encrypted payload size (16-bit big-endian) + // Pos 3-6: encrypted type (16-bit) + data_len (16-bit) + // Pos 7+: actual payload data + frame_header_padding_ = 7; + } ~APINoiseFrameHelper() override; APIError init() override; APIError loop() override; APIError read_packet(ReadPacketBuffer *buffer) override; bool can_write_without_blocking() override; - APIError write_packet(uint16_t type, const uint8_t *payload, size_t len) override; + APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; std::string getpeername() override { return this->socket_->getpeername(); } int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return this->socket_->getpeername(addr, addrlen); @@ -99,6 +115,10 @@ class APINoiseFrameHelper : public APIFrameHelper { APIError shutdown(int how) override; // Give this helper a name for logging void set_log_info(std::string info) override { info_ = std::move(info); } + // Get the frame header padding required by this protocol + uint8_t frame_header_padding() override { return frame_header_padding_; } + // Get the frame footer size required by this protocol + uint8_t frame_footer_size() override { return frame_footer_size_; } protected: struct ParsedFrame { @@ -152,13 +172,20 @@ class APINoiseFrameHelper : public APIFrameHelper { #ifdef USE_API_PLAINTEXT class APIPlaintextFrameHelper : public APIFrameHelper { public: - APIPlaintextFrameHelper(std::unique_ptr socket) : socket_(std::move(socket)) {} + APIPlaintextFrameHelper(std::unique_ptr socket) : socket_(std::move(socket)) { + // Plaintext header structure (worst case): + // Pos 0: indicator (0x00) + // Pos 1-3: payload size varint (up to 3 bytes) + // Pos 4-5: message type varint (up to 2 bytes) + // Pos 6+: actual payload data + frame_header_padding_ = 6; + } ~APIPlaintextFrameHelper() override = default; APIError init() override; APIError loop() override; APIError read_packet(ReadPacketBuffer *buffer) override; bool can_write_without_blocking() override; - APIError write_packet(uint16_t type, const uint8_t *payload, size_t len) override; + APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; std::string getpeername() override { return this->socket_->getpeername(); } int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return this->socket_->getpeername(addr, addrlen); @@ -167,6 +194,10 @@ class APIPlaintextFrameHelper : public APIFrameHelper { APIError shutdown(int how) override; // Give this helper a name for logging void set_log_info(std::string info) override { info_ = std::move(info); } + // Get the frame header padding required by this protocol + uint8_t frame_header_padding() override { return frame_header_padding_; } + // Get the frame footer size required by this protocol + uint8_t frame_footer_size() override { return frame_footer_size_; } protected: struct ParsedFrame { diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index e110a58eda..65bef0b6f7 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -83,6 +83,34 @@ class ProtoVarInt { return static_cast(this->value_ >> 1); } } + /** + * Encode the varint value to a pre-allocated buffer without bounds checking. + * + * @param buffer The pre-allocated buffer to write the encoded varint to + * @param len The size of the buffer in bytes + * + * @note The caller is responsible for ensuring the buffer is large enough + * to hold the encoded value. Use ProtoSize::varint() to calculate + * the exact size needed before calling this method. + * @note No bounds checking is performed for performance reasons. + */ + void encode_to_buffer_unchecked(uint8_t *buffer, size_t len) { + uint64_t val = this->value_; + if (val <= 0x7F) { + buffer[0] = val; + return; + } + size_t i = 0; + while (val && i < len) { + uint8_t temp = val & 0x7F; + val >>= 7; + if (val) { + buffer[i++] = temp | 0x80; + } else { + buffer[i++] = temp; + } + } + } void encode(std::vector &out) { uint64_t val = this->value_; if (val <= 0x7F) {