diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index f18f4104b6..aa80c41597 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -7,20 +7,13 @@ #include "proto.h" #include "api_pb2_size.h" #include +#include namespace esphome { namespace api { static const char *const TAG = "api.socket"; -/// Is the given return value (from write syscalls) a wouldblock error? -bool is_would_block(ssize_t ret) { - if (ret == -1) { - return errno == EWOULDBLOCK || errno == EAGAIN; - } - return ret == 0; -} - const char *api_error_to_str(APIError err) { // not using switch to ensure compiler doesn't try to build a big table out of it if (err == APIError::OK) { @@ -73,92 +66,154 @@ const char *api_error_to_str(APIError err) { return "UNKNOWN"; } -// Common implementation for writing raw data to socket -template -APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, - std::vector &tx_buf, const std::string &info, StateEnum &state, - StateEnum failed_state) { - // This method writes data to socket or buffers it +// Helper method to buffer data from IOVs +void APIFrameHelper::buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len) { + SendBuffer buffer; + buffer.data.reserve(total_write_len); + for (int i = 0; i < iovcnt; i++) { + const uint8_t *data = reinterpret_cast(iov[i].iov_base); + buffer.data.insert(buffer.data.end(), data, data + iov[i].iov_len); + } + this->tx_buf_.push_back(std::move(buffer)); +} + +// This method writes data to socket or buffers it +APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { // Returns APIError::OK if successful (or would block, but data has been buffered) - // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to failed_state + // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to FAILED if (iovcnt == 0) return APIError::OK; // Nothing to do, success - size_t total_write_len = 0; + uint16_t total_write_len = 0; for (int i = 0; i < iovcnt; i++) { #ifdef HELPER_LOG_PACKETS ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); #endif - total_write_len += iov[i].iov_len; + total_write_len += static_cast(iov[i].iov_len); } - if (!tx_buf.empty()) { - // try to empty tx_buf first - while (!tx_buf.empty()) { - ssize_t sent = socket->write(tx_buf.data(), tx_buf.size()); - if (is_would_block(sent)) { - break; - } else if (sent == -1) { - ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); - state = failed_state; - return APIError::SOCKET_WRITE_FAILED; // Socket write failed - } - // TODO: inefficient if multiple packets in txbuf - // replace with deque of buffers - tx_buf.erase(tx_buf.begin(), tx_buf.begin() + sent); + // Try to send any existing buffered data first if there is any + if (!this->tx_buf_.empty()) { + APIError send_result = try_send_tx_buf_(); + // If real error occurred (not just WOULD_BLOCK), return it + if (send_result != APIError::OK && send_result != APIError::WOULD_BLOCK) { + return send_result; + } + + // If there is still data in the buffer, we can't send, buffer + // the new data and return + if (!this->tx_buf_.empty()) { + this->buffer_data_from_iov_(iov, iovcnt, total_write_len); + return APIError::OK; // Success, data buffered } } - if (!tx_buf.empty()) { - // tx buf not empty, can't write now because then stream would be inconsistent - // Reserve space upfront to avoid multiple reallocations - tx_buf.reserve(tx_buf.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; // Success, data buffered - } + // Try to send directly if no buffered data + ssize_t sent = this->socket_->writev(iov, iovcnt); - ssize_t sent = socket->writev(iov, iovcnt); - if (is_would_block(sent)) { - // operation would block, add buffer to tx_buf - // Reserve space upfront to avoid multiple reallocations - tx_buf.reserve(tx_buf.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + if (sent == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + // Socket would block, buffer the data + this->buffer_data_from_iov_(iov, iovcnt, total_write_len); + return APIError::OK; // Success, data buffered } - return APIError::OK; // Success, data buffered - } else if (sent == -1) { - // an error occurred - ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); - state = failed_state; + // Socket error + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno); + this->state_ = State::FAILED; return APIError::SOCKET_WRITE_FAILED; // Socket write failed - } else if ((size_t) sent != total_write_len) { - // partially sent, add end to tx_buf - size_t remaining = total_write_len - sent; - // Reserve space upfront to avoid multiple reallocations - tx_buf.reserve(tx_buf.size() + remaining); + } else if (static_cast(sent) < total_write_len) { + // Partially sent, buffer the remaining data + SendBuffer buffer; + uint16_t to_consume = static_cast(sent); + uint16_t remaining = total_write_len - static_cast(sent); + + buffer.data.reserve(remaining); - size_t to_consume = sent; for (int i = 0; i < iovcnt; i++) { if (to_consume >= iov[i].iov_len) { - to_consume -= iov[i].iov_len; + // This segment was fully sent + to_consume -= static_cast(iov[i].iov_len); } else { - tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base) + to_consume, - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + // This segment was partially sent or not sent at all + const uint8_t *data = reinterpret_cast(iov[i].iov_base) + to_consume; + uint16_t len = static_cast(iov[i].iov_len) - to_consume; + buffer.data.insert(buffer.data.end(), data, data + len); to_consume = 0; } } - return APIError::OK; // Success, data buffered + + this->tx_buf_.push_back(std::move(buffer)); } - return APIError::OK; // Success, all data sent + + return APIError::OK; // Success, all data sent or buffered } -#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) +// Common implementation for trying to send buffered data +// IMPORTANT: Caller MUST ensure tx_buf_ is not empty before calling this method +APIError APIFrameHelper::try_send_tx_buf_() { + // Try to send from tx_buf - we assume it's not empty as it's the caller's responsibility to check + bool tx_buf_empty = false; + while (!tx_buf_empty) { + // Get the first buffer in the queue + SendBuffer &front_buffer = this->tx_buf_.front(); + + // Try to send the remaining data in this buffer + ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining()); + + if (sent == -1) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + // Real socket error (not just would block) + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno); + this->state_ = State::FAILED; + return APIError::SOCKET_WRITE_FAILED; // Socket write failed + } + // Socket would block, we'll try again later + return APIError::WOULD_BLOCK; + } else if (sent == 0) { + // Nothing sent but not an error + return APIError::WOULD_BLOCK; + } else if (static_cast(sent) < front_buffer.remaining()) { + // Partially sent, update offset + // Cast to ensure no overflow issues with uint16_t + front_buffer.offset += static_cast(sent); + return APIError::WOULD_BLOCK; // Stop processing more buffers if we couldn't send a complete buffer + } else { + // Buffer completely sent, remove it from the queue + this->tx_buf_.pop_front(); + // Update empty status for the loop condition + tx_buf_empty = this->tx_buf_.empty(); + // Continue loop to try sending the next buffer + } + } + + return APIError::OK; // All buffers sent successfully +} + +APIError APIFrameHelper::init_common_() { + if (state_ != State::INITIALIZE || this->socket_ == nullptr) { + ESP_LOGVV(TAG, "%s: Bad state for init %d", this->info_.c_str(), (int) state_); + return APIError::BAD_STATE; + } + int err = this->socket_->setblocking(false); + if (err != 0) { + state_ = State::FAILED; + ESP_LOGVV(TAG, "%s: Setting nonblocking failed with errno %d", this->info_.c_str(), errno); + return APIError::TCP_NONBLOCKING_FAILED; + } + + int enable = 1; + err = this->socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); + if (err != 0) { + state_ = State::FAILED; + ESP_LOGVV(TAG, "%s: Setting nodelay failed with errno %d", this->info_.c_str(), errno); + return APIError::TCP_NODELAY_FAILED; + } + return APIError::OK; +} + +#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->info_.c_str(), ##__VA_ARGS__) // uncomment to log raw packets //#define HELPER_LOG_PACKETS @@ -206,23 +261,9 @@ std::string noise_err_to_str(int err) { /// Initialize the frame helper, returns OK if successful. APIError APINoiseFrameHelper::init() { - if (state_ != State::INITIALIZE || socket_ == nullptr) { - HELPER_LOG("Bad state for init %d", (int) state_); - return APIError::BAD_STATE; - } - int err = socket_->setblocking(false); - if (err != 0) { - state_ = State::FAILED; - HELPER_LOG("Setting nonblocking failed with errno %d", errno); - return APIError::TCP_NONBLOCKING_FAILED; - } - - int enable = 1; - err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); - if (err != 0) { - state_ = State::FAILED; - HELPER_LOG("Setting nodelay failed with errno %d", errno); - return APIError::TCP_NODELAY_FAILED; + APIError err = init_common_(); + if (err != APIError::OK) { + return err; } // init prologue @@ -234,17 +275,16 @@ APIError APINoiseFrameHelper::init() { /// Run through handshake messages (if in that phase) APIError APINoiseFrameHelper::loop() { APIError err = state_action_(); - if (err == APIError::WOULD_BLOCK) - return APIError::OK; - if (err != APIError::OK) + if (err != APIError::OK && err != APIError::WOULD_BLOCK) { return err; - if (!tx_buf_.empty()) { + } + if (!this->tx_buf_.empty()) { err = try_send_tx_buf_(); - if (err != APIError::OK) { + if (err != APIError::OK && err != APIError::WOULD_BLOCK) { return err; } } - return APIError::OK; + return APIError::OK; // Convert WOULD_BLOCK to OK to avoid connection termination } /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter @@ -270,8 +310,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { // read header if (rx_header_buf_len_ < 3) { // no header information yet - size_t to_read = 3 - rx_header_buf_len_; - ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); + uint8_t to_read = 3 - rx_header_buf_len_; + ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; @@ -284,8 +324,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } - rx_header_buf_len_ += received; - if ((size_t) received != to_read) { + rx_header_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { // not a full read return APIError::WOULD_BLOCK; } @@ -317,8 +357,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { if (rx_buf_len_ < msg_size) { // more data to read - size_t to_read = msg_size - rx_buf_len_; - ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); + uint16_t to_read = msg_size - rx_buf_len_; + ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; @@ -331,8 +371,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } - rx_buf_len_ += received; - if ((size_t) received != to_read) { + rx_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { // not all read return APIError::WOULD_BLOCK; } @@ -381,6 +421,8 @@ APIError APINoiseFrameHelper::state_action_() { if (aerr != APIError::OK) return aerr; // ignore contents, may be used in future for flags + // Reserve space for: existing prologue + 2 size bytes + frame data + prologue_.reserve(prologue_.size() + 2 + frame.msg.size()); prologue_.push_back((uint8_t) (frame.msg.size() >> 8)); prologue_.push_back((uint8_t) frame.msg.size()); prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end()); @@ -389,16 +431,20 @@ APIError APINoiseFrameHelper::state_action_() { } if (state_ == State::SERVER_HELLO) { // send server hello + const std::string &name = App.get_name(); + const std::string &mac = get_mac_address(); + std::vector msg; + // Reserve space for: 1 byte proto + name + null + mac + null + msg.reserve(1 + name.size() + 1 + mac.size() + 1); + // chosen proto msg.push_back(0x01); // node name, terminated by null byte - const std::string &name = App.get_name(); const uint8_t *name_ptr = reinterpret_cast(name.c_str()); msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1); // node mac, terminated by null byte - const std::string &mac = get_mac_address(); const uint8_t *mac_ptr = reinterpret_cast(mac.c_str()); msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1); @@ -505,7 +551,6 @@ void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &rea write_frame_(data.data(), data.size()); state_ = orig_state; } - APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { int err; APIError aerr; @@ -533,7 +578,7 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { return APIError::CIPHERSTATE_DECRYPT_FAILED; } - size_t msg_size = mbuf.size; + uint16_t msg_size = mbuf.size; uint8_t *msg_data = frame.msg.data(); if (msg_size < 4) { state_ = State::FAILED; @@ -559,7 +604,6 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { buffer->type = type; return APIError::OK; } -bool APINoiseFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { int err; APIError aerr; @@ -574,9 +618,9 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf std::vector *raw_buffer = buffer.get_buffer(); // Message data starts after padding - size_t payload_len = raw_buffer->size() - frame_header_padding_; - size_t padding = 0; - size_t msg_len = 4 + payload_len + padding; + uint16_t payload_len = raw_buffer->size() - frame_header_padding_; + uint16_t padding = 0; + uint16_t msg_len = 4 + payload_len + padding; // We need to resize to include MAC space, but we already reserved it in create_buffer raw_buffer->resize(raw_buffer->size() + frame_footer_size_); @@ -609,7 +653,7 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf return APIError::CIPHERSTATE_ENCRYPT_FAILED; } - size_t total_len = 3 + mbuf.size; + uint16_t total_len = 3 + mbuf.size; buf_start[1] = (uint8_t) (mbuf.size >> 8); buf_start[2] = (uint8_t) mbuf.size; @@ -620,29 +664,9 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf iov.iov_len = total_len; // write raw to not have two packets sent if NAGLE disabled - return write_raw_(&iov, 1); + return this->write_raw_(&iov, 1); } -APIError APINoiseFrameHelper::try_send_tx_buf_() { - // try send from tx_buf - while (state_ != State::CLOSED && !tx_buf_.empty()) { - ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); - if (sent == -1) { - if (errno == EWOULDBLOCK || errno == EAGAIN) - break; - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } else if (sent == 0) { - break; - } - // TODO: inefficient if multiple packets in txbuf - // replace with deque of buffers - tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); - } - - return APIError::OK; -} -APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { +APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) { uint8_t header[3]; header[0] = 0x01; // indicator header[1] = (uint8_t) (len >> 8); @@ -652,12 +676,12 @@ APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { iov[0].iov_base = header; iov[0].iov_len = 3; if (len == 0) { - return write_raw_(iov, 1); + return this->write_raw_(iov, 1); } iov[1].iov_base = const_cast(data); iov[1].iov_len = len; - return write_raw_(iov, 2); + return this->write_raw_(iov, 2); } /** Initiate the data structures for the handshake. @@ -752,22 +776,6 @@ APINoiseFrameHelper::~APINoiseFrameHelper() { } } -APIError APINoiseFrameHelper::close() { - state_ = State::CLOSED; - int err = socket_->close(); - if (err == -1) - return APIError::CLOSE_FAILED; - return APIError::OK; -} -APIError APINoiseFrameHelper::shutdown(int how) { - int err = socket_->shutdown(how); - if (err == -1) - return APIError::SHUTDOWN_FAILED; - if (how == SHUT_RDWR) { - state_ = State::CLOSED; - } - return APIError::OK; -} extern "C" { // declare how noise generates random bytes (here with a good HWRNG based on the RF system) void noise_rand_bytes(void *output, size_t len) { @@ -778,32 +786,15 @@ void noise_rand_bytes(void *output, size_t len) { } } -// Explicit template instantiation for Noise -template APIError APIFrameHelper::write_raw_( - const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, - APINoiseFrameHelper::State &state, APINoiseFrameHelper::State failed_state); #endif // USE_API_NOISE #ifdef USE_API_PLAINTEXT /// Initialize the frame helper, returns OK if successful. APIError APIPlaintextFrameHelper::init() { - if (state_ != State::INITIALIZE || socket_ == nullptr) { - HELPER_LOG("Bad state for init %d", (int) state_); - return APIError::BAD_STATE; - } - int err = socket_->setblocking(false); - if (err != 0) { - state_ = State::FAILED; - HELPER_LOG("Setting nonblocking failed with errno %d", errno); - return APIError::TCP_NONBLOCKING_FAILED; - } - int enable = 1; - err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); - if (err != 0) { - state_ = State::FAILED; - HELPER_LOG("Setting nodelay failed with errno %d", errno); - return APIError::TCP_NODELAY_FAILED; + APIError err = init_common_(); + if (err != APIError::OK) { + return err; } state_ = State::DATA; @@ -814,14 +805,13 @@ APIError APIPlaintextFrameHelper::loop() { if (state_ != State::DATA) { return APIError::BAD_STATE; } - // try send pending TX data - if (!tx_buf_.empty()) { + if (!this->tx_buf_.empty()) { APIError err = try_send_tx_buf_(); - if (err != APIError::OK) { + if (err != APIError::OK && err != APIError::WOULD_BLOCK) { return err; } } - return APIError::OK; + return APIError::OK; // Convert WOULD_BLOCK to OK to avoid connection termination } /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter @@ -846,7 +836,7 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { // there is no data on the wire (which is the common case). // This results in faster failure detection compared to // attempting to read multiple bytes at once. - ssize_t received = socket_->read(&data, 1); + ssize_t received = this->socket_->read(&data, 1); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; @@ -910,14 +900,24 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { continue; } - rx_header_parsed_len_ = msg_size_varint->as_uint32(); + if (msg_size_varint->as_uint32() > 65535) { + state_ = State::FAILED; + HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum 65535", msg_size_varint->as_uint32()); + return APIError::BAD_DATA_PACKET; + } + rx_header_parsed_len_ = msg_size_varint->as_uint16(); auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[consumed], rx_header_buf_pos_ - 1 - consumed, &consumed); if (!msg_type_varint.has_value()) { // not enough data there yet continue; } - rx_header_parsed_type_ = msg_type_varint->as_uint32(); + if (msg_type_varint->as_uint32() > 65535) { + state_ = State::FAILED; + HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum 65535", msg_type_varint->as_uint32()); + return APIError::BAD_DATA_PACKET; + } + rx_header_parsed_type_ = msg_type_varint->as_uint16(); rx_header_parsed_ = true; } // header reading done @@ -929,8 +929,8 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { if (rx_buf_len_ < rx_header_parsed_len_) { // more data to read - size_t to_read = rx_header_parsed_len_ - rx_buf_len_; - ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); + uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_; + ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); if (received == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { return APIError::WOULD_BLOCK; @@ -943,8 +943,8 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { HELPER_LOG("Connection closed"); return APIError::CONNECTION_CLOSED; } - rx_buf_len_ += received; - if ((size_t) received != to_read) { + rx_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { // not all read return APIError::WOULD_BLOCK; } @@ -962,7 +962,6 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { rx_header_parsed_ = false; return APIError::OK; } - APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { APIError aerr; @@ -990,7 +989,7 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { "Bad indicator byte"; iov[0].iov_base = (void *) msg; iov[0].iov_len = 19; - write_raw_(iov, 1); + this->write_raw_(iov, 1); } return aerr; } @@ -1001,7 +1000,6 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { buffer->type = rx_header_parsed_type_; return APIError::OK; } -bool APIPlaintextFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { if (state_ != State::DATA) { return APIError::BAD_STATE; @@ -1009,12 +1007,12 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit std::vector *raw_buffer = buffer.get_buffer(); // Message data starts after padding (frame_header_padding_ = 6) - size_t payload_len = raw_buffer->size() - frame_header_padding_; + uint16_t payload_len = static_cast(raw_buffer->size() - frame_header_padding_); // Calculate varint sizes for header components - size_t size_varint_len = api::ProtoSize::varint(static_cast(payload_len)); - size_t type_varint_len = api::ProtoSize::varint(static_cast(type)); - size_t total_header_len = 1 + size_varint_len + type_varint_len; + uint8_t size_varint_len = api::ProtoSize::varint(static_cast(payload_len)); + uint8_t type_varint_len = api::ProtoSize::varint(static_cast(type)); + uint8_t total_header_len = 1 + size_varint_len + type_varint_len; if (total_header_len > frame_header_padding_) { // Header is too large to fit in the padding @@ -1044,7 +1042,7 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit // [4-5] - Message type varint (2 bytes, for types 128-32767) // [6...] - Actual payload data uint8_t *buf_start = raw_buffer->data(); - size_t header_offset = frame_header_padding_ - total_header_len; + uint8_t header_offset = frame_header_padding_ - total_header_len; // Write the plaintext header buf_start[header_offset] = 0x00; // indicator @@ -1063,46 +1061,7 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit return write_raw_(&iov, 1); } -APIError APIPlaintextFrameHelper::try_send_tx_buf_() { - // try send from tx_buf - while (state_ != State::CLOSED && !tx_buf_.empty()) { - ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); - if (is_would_block(sent)) { - break; - } else if (sent == -1) { - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } - // TODO: inefficient if multiple packets in txbuf - // replace with deque of buffers - tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); - } - return APIError::OK; -} - -APIError APIPlaintextFrameHelper::close() { - state_ = State::CLOSED; - int err = socket_->close(); - if (err == -1) - return APIError::CLOSE_FAILED; - return APIError::OK; -} -APIError APIPlaintextFrameHelper::shutdown(int how) { - int err = socket_->shutdown(how); - if (err == -1) - return APIError::SHUTDOWN_FAILED; - if (how == SHUT_RDWR) { - state_ = State::CLOSED; - } - return APIError::OK; -} - -// Explicit template instantiation for Plaintext -template APIError APIFrameHelper::write_raw_( - const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, - APIPlaintextFrameHelper::State &state, APIPlaintextFrameHelper::State failed_state); #endif // USE_API_PLAINTEXT } // namespace api diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 25bfd594ec..ea91c3a7f9 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -21,15 +21,8 @@ class ProtoWriteBuffer; struct ReadPacketBuffer { std::vector container; uint16_t type; - size_t data_offset; - size_t data_len; -}; - -struct PacketBuffer { - const std::vector container; - uint16_t type; - uint8_t data_offset; - uint8_t data_len; + uint16_t data_offset; + uint16_t data_len; }; enum class APIError : int { @@ -62,38 +55,117 @@ const char *api_error_to_str(APIError err); class APIFrameHelper { public: + APIFrameHelper() = default; + explicit APIFrameHelper(std::unique_ptr socket) : socket_owned_(std::move(socket)) { + socket_ = socket_owned_.get(); + } virtual ~APIFrameHelper() = default; virtual APIError init() = 0; virtual APIError loop() = 0; virtual APIError read_packet(ReadPacketBuffer *buffer) = 0; - virtual bool can_write_without_blocking() = 0; - virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0; - virtual std::string getpeername() = 0; - virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0; - virtual APIError close() = 0; - virtual APIError shutdown(int how) = 0; + bool can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } + std::string getpeername() { return socket_->getpeername(); } + int getpeername(struct sockaddr *addr, socklen_t *addrlen) { return socket_->getpeername(addr, addrlen); } + APIError close() { + state_ = State::CLOSED; + int err = this->socket_->close(); + if (err == -1) + return APIError::CLOSE_FAILED; + return APIError::OK; + } + APIError shutdown(int how) { + int err = this->socket_->shutdown(how); + if (err == -1) + return APIError::SHUTDOWN_FAILED; + if (how == SHUT_RDWR) { + state_ = State::CLOSED; + } + return APIError::OK; + } // Give this helper a name for logging - virtual void set_log_info(std::string info) = 0; + void set_log_info(std::string info) { info_ = std::move(info); } + virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0; // Get the frame header padding required by this protocol virtual uint8_t frame_header_padding() = 0; // Get the frame footer size required by this protocol virtual uint8_t frame_footer_size() = 0; protected: + // Struct for holding parsed frame data + struct ParsedFrame { + std::vector msg; + }; + + // Buffer containing data to be sent + struct SendBuffer { + std::vector data; + uint16_t offset{0}; // Current offset within the buffer (uint16_t to reduce memory usage) + + // Using uint16_t reduces memory usage since ESPHome API messages are limited to 64KB max + uint16_t remaining() const { return static_cast(data.size()) - offset; } + const uint8_t *current_data() const { return data.data() + offset; } + }; + + // Queue of data buffers to be sent + std::deque tx_buf_; + + // Common state enum for all frame helpers + // Note: Not all states are used by all implementations + // - INITIALIZE: Used by both Noise and Plaintext + // - CLIENT_HELLO, SERVER_HELLO, HANDSHAKE: Only used by Noise protocol + // - DATA: Used by both Noise and Plaintext + // - CLOSED: Used by both Noise and Plaintext + // - FAILED: Used by both Noise and Plaintext + // - EXPLICIT_REJECT: Only used by Noise protocol + enum class State { + INITIALIZE = 1, + CLIENT_HELLO = 2, // Noise only + SERVER_HELLO = 3, // Noise only + HANDSHAKE = 4, // Noise only + DATA = 5, + CLOSED = 6, + FAILED = 7, + EXPLICIT_REJECT = 8, // Noise only + }; + + // Current state of the frame helper + State state_{State::INITIALIZE}; + + // Helper name for logging + std::string info_; + + // Socket for communication + socket::Socket *socket_{nullptr}; + std::unique_ptr socket_owned_; + // Common implementation for writing raw data to socket + APIError write_raw_(const struct iovec *iov, int iovcnt); + + // Try to send data from the tx buffer + APIError try_send_tx_buf_(); + + // Helper method to buffer data from IOVs + void buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len); template APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf, const std::string &info, StateEnum &state, StateEnum failed_state); uint8_t frame_header_padding_{0}; uint8_t frame_footer_size_{0}; + + // Receive buffer for reading frame data + std::vector rx_buf_; + uint16_t rx_buf_len_ = 0; + + // Common initialization for both plaintext and noise protocols + APIError init_common_(); }; #ifdef USE_API_NOISE class APINoiseFrameHelper : public APIFrameHelper { public: APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx) - : socket_(std::move(socket)), ctx_(std::move(ctx)) { + : APIFrameHelper(std::move(socket)), ctx_(std::move(ctx)) { // Noise header structure: // Pos 0: indicator (0x01) // Pos 1-2: encrypted payload size (16-bit big-endian) @@ -105,49 +177,25 @@ class APINoiseFrameHelper : public APIFrameHelper { APIError init() override; APIError loop() override; APIError read_packet(ReadPacketBuffer *buffer) override; - bool can_write_without_blocking() override; APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; - std::string getpeername() override { return this->socket_->getpeername(); } - int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { - return this->socket_->getpeername(addr, addrlen); - } - APIError close() override; - APIError shutdown(int how) override; - // Give this helper a name for logging - void set_log_info(std::string info) override { info_ = std::move(info); } // Get the frame header padding required by this protocol uint8_t frame_header_padding() override { return frame_header_padding_; } // Get the frame footer size required by this protocol uint8_t frame_footer_size() override { return frame_footer_size_; } protected: - struct ParsedFrame { - std::vector msg; - }; - APIError state_action_(); APIError try_read_frame_(ParsedFrame *frame); - APIError try_send_tx_buf_(); - APIError write_frame_(const uint8_t *data, size_t len); - inline APIError write_raw_(const struct iovec *iov, int iovcnt) { - return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); - } + APIError write_frame_(const uint8_t *data, uint16_t len); APIError init_handshake_(); APIError check_handshake_finished_(); void send_explicit_handshake_reject_(const std::string &reason); - - std::unique_ptr socket_; - - std::string info_; // Fixed-size header buffer for noise protocol: // 1 byte for indicator + 2 bytes for message size (16-bit value, not varint) // Note: Maximum message size is 65535, with a limit of 128 bytes during handshake phase uint8_t rx_header_buf_[3]; - size_t rx_header_buf_len_ = 0; - std::vector rx_buf_; - size_t rx_buf_len_ = 0; + uint8_t rx_header_buf_len_ = 0; - std::vector tx_buf_; std::vector prologue_; std::shared_ptr ctx_; @@ -155,24 +203,13 @@ class APINoiseFrameHelper : public APIFrameHelper { NoiseCipherState *send_cipher_{nullptr}; NoiseCipherState *recv_cipher_{nullptr}; NoiseProtocolId nid_; - - enum class State { - INITIALIZE = 1, - CLIENT_HELLO = 2, - SERVER_HELLO = 3, - HANDSHAKE = 4, - DATA = 5, - CLOSED = 6, - FAILED = 7, - EXPLICIT_REJECT = 8, - } state_ = State::INITIALIZE; }; #endif // USE_API_NOISE #ifdef USE_API_PLAINTEXT class APIPlaintextFrameHelper : public APIFrameHelper { public: - APIPlaintextFrameHelper(std::unique_ptr socket) : socket_(std::move(socket)) { + APIPlaintextFrameHelper(std::unique_ptr socket) : APIFrameHelper(std::move(socket)) { // Plaintext header structure (worst case): // Pos 0: indicator (0x00) // Pos 1-3: payload size varint (up to 3 bytes) @@ -184,35 +221,13 @@ class APIPlaintextFrameHelper : public APIFrameHelper { APIError init() override; APIError loop() override; APIError read_packet(ReadPacketBuffer *buffer) override; - bool can_write_without_blocking() override; APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; - std::string getpeername() override { return this->socket_->getpeername(); } - int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { - return this->socket_->getpeername(addr, addrlen); - } - APIError close() override; - APIError shutdown(int how) override; - // Give this helper a name for logging - void set_log_info(std::string info) override { info_ = std::move(info); } - // Get the frame header padding required by this protocol uint8_t frame_header_padding() override { return frame_header_padding_; } // Get the frame footer size required by this protocol uint8_t frame_footer_size() override { return frame_footer_size_; } protected: - struct ParsedFrame { - std::vector msg; - }; - APIError try_read_frame_(ParsedFrame *frame); - APIError try_send_tx_buf_(); - inline APIError write_raw_(const struct iovec *iov, int iovcnt) { - return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); - } - - std::unique_ptr socket_; - - std::string info_; // Fixed-size header buffer for plaintext protocol: // We only need space for the two varints since we validate the indicator byte separately. // To match noise protocol's maximum message size (65535), we need: @@ -224,20 +239,8 @@ class APIPlaintextFrameHelper : public APIFrameHelper { uint8_t rx_header_buf_[5]; // 5 bytes for varints (3 for size + 2 for type) uint8_t rx_header_buf_pos_ = 0; bool rx_header_parsed_ = false; - uint32_t rx_header_parsed_type_ = 0; - uint32_t rx_header_parsed_len_ = 0; - - std::vector rx_buf_; - size_t rx_buf_len_ = 0; - - std::vector tx_buf_; - - enum class State { - INITIALIZE = 1, - DATA = 2, - CLOSED = 3, - FAILED = 4, - } state_ = State::INITIALIZE; + uint16_t rx_header_parsed_type_ = 0; + uint16_t rx_header_parsed_len_ = 0; }; #endif diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index 65bef0b6f7..fae722f750 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -55,6 +55,7 @@ class ProtoVarInt { return {}; // Incomplete or invalid varint } + uint16_t as_uint16() const { return this->value_; } uint32_t as_uint32() const { return this->value_; } uint64_t as_uint64() const { return this->value_; } bool as_bool() const { return this->value_; }