diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index e9751ac8dc..31b0732275 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -73,6 +73,91 @@ const char *api_error_to_str(APIError err) { return "UNKNOWN"; } +// Common implementation for writing raw data to socket +template +APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, + std::vector &tx_buf, const std::string &info, StateEnum &state, + StateEnum failed_state) { + // This method writes data to socket or buffers it + // Returns APIError::OK if successful (or would block, but data has been buffered) + // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to failed_state + + if (iovcnt == 0) + return APIError::OK; // Nothing to do, success + + size_t total_write_len = 0; + for (int i = 0; i < iovcnt; i++) { +#ifdef HELPER_LOG_PACKETS + ESP_LOGVV(TAG, "Sending raw: %s", + format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); +#endif + total_write_len += iov[i].iov_len; + } + + if (!tx_buf.empty()) { + // try to empty tx_buf first + while (!tx_buf.empty()) { + ssize_t sent = socket->write(tx_buf.data(), tx_buf.size()); + if (is_would_block(sent)) { + break; + } else if (sent == -1) { + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); + state = failed_state; + return APIError::SOCKET_WRITE_FAILED; // Socket write failed + } + // TODO: inefficient if multiple packets in txbuf + // replace with deque of buffers + tx_buf.erase(tx_buf.begin(), tx_buf.begin() + sent); + } + } + + if (!tx_buf.empty()) { + // tx buf not empty, can't write now because then stream would be inconsistent + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + total_write_len); + for (int i = 0; i < iovcnt; i++) { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + } + return APIError::OK; // Success, data buffered + } + + ssize_t sent = socket->writev(iov, iovcnt); + if (is_would_block(sent)) { + // operation would block, add buffer to tx_buf + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + total_write_len); + for (int i = 0; i < iovcnt; i++) { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + } + return APIError::OK; // Success, data buffered + } else if (sent == -1) { + // an error occurred + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); + state = failed_state; + return APIError::SOCKET_WRITE_FAILED; // Socket write failed + } else if ((size_t) sent != total_write_len) { + // partially sent, add end to tx_buf + size_t remaining = total_write_len - sent; + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + remaining); + + size_t to_consume = sent; + for (int i = 0; i < iovcnt; i++) { + if (to_consume >= iov[i].iov_len) { + to_consume -= iov[i].iov_len; + } else { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base) + to_consume, + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + to_consume = 0; + } + } + return APIError::OK; // Success, data buffered + } + return APIError::OK; // Success, all data sent +} + #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) // uncomment to log raw packets //#define HELPER_LOG_PACKETS @@ -547,79 +632,6 @@ APIError APINoiseFrameHelper::try_send_tx_buf_() { return APIError::OK; } -/** Write the data to the socket, or buffer it a write would block - * - * @param data The data to write - * @param len The length of data - */ -APIError APINoiseFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { - if (iovcnt == 0) - return APIError::OK; - APIError aerr; - - size_t total_write_len = 0; - for (int i = 0; i < iovcnt; i++) { -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Sending raw: %s", - format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); -#endif - total_write_len += iov[i].iov_len; - } - - if (!tx_buf_.empty()) { - // try to empty tx_buf_ first - aerr = try_send_tx_buf_(); - if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) - return aerr; - } - - if (!tx_buf_.empty()) { - // tx buf not empty, can't write now because then stream would be inconsistent - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } - - ssize_t sent = socket_->writev(iov, iovcnt); - if (is_would_block(sent)) { - // operation would block, add buffer to tx_buf - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } else if (sent == -1) { - // an error occurred - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } else if ((size_t) sent != total_write_len) { - // partially sent, add end to tx_buf - size_t remaining = total_write_len - sent; - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + remaining); - - size_t to_consume = sent; - for (int i = 0; i < iovcnt; i++) { - if (to_consume >= iov[i].iov_len) { - to_consume -= iov[i].iov_len; - } else { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base) + to_consume, - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - to_consume = 0; - } - } - return APIError::OK; - } - // fully sent - return APIError::OK; -} APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { uint8_t header[3]; header[0] = 0x01; // indicator @@ -753,6 +765,11 @@ void noise_rand_bytes(void *output, size_t len) { } } } + +// Explicit template instantiation for Noise +template APIError APIFrameHelper::write_raw_( + const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, + APINoiseFrameHelper::State &state, APINoiseFrameHelper::State failed_state); #endif // USE_API_NOISE #ifdef USE_API_PLAINTEXT @@ -977,79 +994,6 @@ APIError APIPlaintextFrameHelper::try_send_tx_buf_() { return APIError::OK; } -/** Write the data to the socket, or buffer it a write would block - * - * @param data The data to write - * @param len The length of data - */ -APIError APIPlaintextFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { - if (iovcnt == 0) - return APIError::OK; - APIError aerr; - - size_t total_write_len = 0; - for (int i = 0; i < iovcnt; i++) { -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Sending raw: %s", - format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); -#endif - total_write_len += iov[i].iov_len; - } - - if (!tx_buf_.empty()) { - // try to empty tx_buf_ first - aerr = try_send_tx_buf_(); - if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) - return aerr; - } - - if (!tx_buf_.empty()) { - // tx buf not empty, can't write now because then stream would be inconsistent - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } - - ssize_t sent = socket_->writev(iov, iovcnt); - if (is_would_block(sent)) { - // operation would block, add buffer to tx_buf - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + total_write_len); - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } else if (sent == -1) { - // an error occurred - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } else if ((size_t) sent != total_write_len) { - // partially sent, add end to tx_buf - size_t remaining = total_write_len - sent; - // Reserve space upfront to avoid multiple reallocations - tx_buf_.reserve(tx_buf_.size() + remaining); - - size_t to_consume = sent; - for (int i = 0; i < iovcnt; i++) { - if (to_consume >= iov[i].iov_len) { - to_consume -= iov[i].iov_len; - } else { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base) + to_consume, - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - to_consume = 0; - } - } - return APIError::OK; - } - // fully sent - return APIError::OK; -} APIError APIPlaintextFrameHelper::close() { state_ = State::CLOSED; @@ -1067,6 +1011,11 @@ APIError APIPlaintextFrameHelper::shutdown(int how) { } return APIError::OK; } + +// Explicit template instantiation for Plaintext +template APIError APIFrameHelper::write_raw_( + const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, + APIPlaintextFrameHelper::State &state, APIPlaintextFrameHelper::State failed_state); #endif // USE_API_PLAINTEXT } // namespace api diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 56d8bf1973..59f3cf7471 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -72,6 +72,12 @@ class APIFrameHelper { virtual APIError shutdown(int how) = 0; // Give this helper a name for logging virtual void set_log_info(std::string info) = 0; + + protected: + // Common implementation for writing raw data to socket + template + APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf, + const std::string &info, StateEnum &state, StateEnum failed_state); }; #ifdef USE_API_NOISE @@ -103,7 +109,9 @@ class APINoiseFrameHelper : public APIFrameHelper { APIError try_read_frame_(ParsedFrame *frame); APIError try_send_tx_buf_(); APIError write_frame_(const uint8_t *data, size_t len); - APIError write_raw_(const struct iovec *iov, int iovcnt); + inline APIError write_raw_(const struct iovec *iov, int iovcnt) { + return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); + } APIError init_handshake_(); APIError check_handshake_finished_(); void send_explicit_handshake_reject_(const std::string &reason); @@ -164,7 +172,9 @@ class APIPlaintextFrameHelper : public APIFrameHelper { APIError try_read_frame_(ParsedFrame *frame); APIError try_send_tx_buf_(); - APIError write_raw_(const struct iovec *iov, int iovcnt); + inline APIError write_raw_(const struct iovec *iov, int iovcnt) { + return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); + } std::unique_ptr socket_;