diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index 5b302760b1..9cbab8164f 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -323,9 +323,10 @@ async def api_connected_to_code(config, condition_id, template_arg, args): def FILTER_SOURCE_FILES() -> list[str]: - """Filter out api_pb2_dump.cpp when proto message dumping is not enabled - and user_services.cpp when no services are defined.""" - files_to_filter = [] + """Filter out api_pb2_dump.cpp when proto message dumping is not enabled, + user_services.cpp when no services are defined, and protocol-specific + implementations based on encryption configuration.""" + files_to_filter: list[str] = [] # api_pb2_dump.cpp is only needed when HAS_PROTO_MESSAGE_DUMP is defined # This is a particularly large file that still needs to be opened and read @@ -341,4 +342,16 @@ def FILTER_SOURCE_FILES() -> list[str]: if config and not config.get(CONF_ACTIONS) and not config[CONF_CUSTOM_SERVICES]: files_to_filter.append("user_services.cpp") + # Filter protocol-specific implementations based on encryption configuration + encryption_config = config.get(CONF_ENCRYPTION) if config else None + + # If encryption is not configured at all, we only need plaintext + if encryption_config is None: + files_to_filter.append("api_frame_helper_noise.cpp") + # If encryption is configured with a key, we only need noise + elif encryption_config.get(CONF_KEY): + files_to_filter.append("api_frame_helper_plaintext.cpp") + # If encryption is configured but no key is provided, we need both + # (this allows a plaintext client to provide a noise key) + return files_to_filter diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index c95992e172..602a0256cf 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1,5 +1,11 @@ #include "api_connection.h" #ifdef USE_API +#ifdef USE_API_NOISE 
+#include "api_frame_helper_noise.h" +#endif +#ifdef USE_API_PLAINTEXT +#include "api_frame_helper_plaintext.h" +#endif #include #include #include diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index 39c01c028c..b1c9478e59 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -12,18 +12,24 @@ namespace esphome { namespace api { -static const char *const TAG = "api.socket"; +static const char *const TAG = "api.frame_helper"; #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) +#ifdef HELPER_LOG_PACKETS +#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) +#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str()) +#else +#define LOG_PACKET_RECEIVED(buffer) ((void) 0) +#define LOG_PACKET_SENDING(data, len) ((void) 0) +#endif + const char *api_error_to_str(APIError err) { // not using switch to ensure compiler doesn't try to build a big table out of it if (err == APIError::OK) { return "OK"; } else if (err == APIError::WOULD_BLOCK) { return "WOULD_BLOCK"; - } else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { - return "BAD_HANDSHAKE_PACKET_LEN"; } else if (err == APIError::BAD_INDICATOR) { return "BAD_INDICATOR"; } else if (err == APIError::BAD_DATA_PACKET) { @@ -44,6 +50,14 @@ const char *api_error_to_str(APIError err) { return "SOCKET_READ_FAILED"; } else if (err == APIError::SOCKET_WRITE_FAILED) { return "SOCKET_WRITE_FAILED"; + } else if (err == APIError::OUT_OF_MEMORY) { + return "OUT_OF_MEMORY"; + } else if (err == APIError::CONNECTION_CLOSED) { + return "CONNECTION_CLOSED"; + } +#ifdef USE_API_NOISE + else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { + return "BAD_HANDSHAKE_PACKET_LEN"; } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) { return 
"HANDSHAKESTATE_READ_FAILED"; } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) { @@ -54,17 +68,14 @@ const char *api_error_to_str(APIError err) { return "CIPHERSTATE_DECRYPT_FAILED"; } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) { return "CIPHERSTATE_ENCRYPT_FAILED"; - } else if (err == APIError::OUT_OF_MEMORY) { - return "OUT_OF_MEMORY"; } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) { return "HANDSHAKESTATE_SETUP_FAILED"; } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) { return "HANDSHAKESTATE_SPLIT_FAILED"; } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) { return "BAD_HANDSHAKE_ERROR_BYTE"; - } else if (err == APIError::CONNECTION_CLOSED) { - return "CONNECTION_CLOSED"; } +#endif return "UNKNOWN"; } @@ -125,8 +136,7 @@ APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_ #ifdef HELPER_LOG_PACKETS for (int i = 0; i < iovcnt; i++) { - ESP_LOGVV(TAG, "Sending raw: %s", - format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); + LOG_PACKET_SENDING(reinterpret_cast(iov[i].iov_base), iov[i].iov_len); } #endif @@ -236,829 +246,6 @@ APIError APIFrameHelper::handle_socket_read_result_(ssize_t received) { } return APIError::OK; } -// uncomment to log raw packets -//#define HELPER_LOG_PACKETS - -#ifdef USE_API_NOISE -static const char *const PROLOGUE_INIT = "NoiseAPIInit"; - -/// Convert a noise error code to a readable error -std::string noise_err_to_str(int err) { - if (err == NOISE_ERROR_NO_MEMORY) - return "NO_MEMORY"; - if (err == NOISE_ERROR_UNKNOWN_ID) - return "UNKNOWN_ID"; - if (err == NOISE_ERROR_UNKNOWN_NAME) - return "UNKNOWN_NAME"; - if (err == NOISE_ERROR_MAC_FAILURE) - return "MAC_FAILURE"; - if (err == NOISE_ERROR_NOT_APPLICABLE) - return "NOT_APPLICABLE"; - if (err == NOISE_ERROR_SYSTEM) - return "SYSTEM"; - if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) - return "REMOTE_KEY_REQUIRED"; - if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) - return "LOCAL_KEY_REQUIRED"; - if 
(err == NOISE_ERROR_PSK_REQUIRED) - return "PSK_REQUIRED"; - if (err == NOISE_ERROR_INVALID_LENGTH) - return "INVALID_LENGTH"; - if (err == NOISE_ERROR_INVALID_PARAM) - return "INVALID_PARAM"; - if (err == NOISE_ERROR_INVALID_STATE) - return "INVALID_STATE"; - if (err == NOISE_ERROR_INVALID_NONCE) - return "INVALID_NONCE"; - if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) - return "INVALID_PRIVATE_KEY"; - if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) - return "INVALID_PUBLIC_KEY"; - if (err == NOISE_ERROR_INVALID_FORMAT) - return "INVALID_FORMAT"; - if (err == NOISE_ERROR_INVALID_SIGNATURE) - return "INVALID_SIGNATURE"; - return to_string(err); -} - -/// Initialize the frame helper, returns OK if successful. -APIError APINoiseFrameHelper::init() { - APIError err = init_common_(); - if (err != APIError::OK) { - return err; - } - - // init prologue - prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT)); - - state_ = State::CLIENT_HELLO; - return APIError::OK; -} -// Helper for handling handshake frame errors -APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { - if (aerr == APIError::BAD_INDICATOR) { - send_explicit_handshake_reject_("Bad indicator byte"); - } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { - send_explicit_handshake_reject_("Bad handshake packet len"); - } - return aerr; -} - -// Helper for handling noise library errors -APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) { - if (err != 0) { - state_ = State::FAILED; - HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str()); - return api_err; - } - return APIError::OK; -} - -/// Run through handshake messages (if in that phase) -APIError APINoiseFrameHelper::loop() { - // During handshake phase, process as many actions as possible until we can't progress - // socket_->ready() stays true until next main loop, but state_action() will return - // WOULD_BLOCK when no more data is 
available to read - while (state_ != State::DATA && this->socket_->ready()) { - APIError err = state_action_(); - if (err == APIError::WOULD_BLOCK) { - break; - } - if (err != APIError::OK) { - return err; - } - } - - // Use base class implementation for buffer sending - return APIFrameHelper::loop(); -} - -/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter - * - * @param frame: The struct to hold the frame information in. - * msg_start: points to the start of the payload - this pointer is only valid until the next - * try_receive_raw_ call - * - * @return 0 if a full packet is in rx_buf_ - * @return -1 if error, check errno. - * - * errno EWOULDBLOCK: Packet could not be read without blocking. Try again later. - * errno ENOMEM: Not enough memory for reading packet. - * errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. - * errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase. - */ -APIError APINoiseFrameHelper::try_read_frame_(std::vector *frame) { - if (frame == nullptr) { - HELPER_LOG("Bad argument for try_read_frame_"); - return APIError::BAD_ARG; - } - - // read header - if (rx_header_buf_len_ < 3) { - // no header information yet - uint8_t to_read = 3 - rx_header_buf_len_; - ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); - APIError err = handle_socket_read_result_(received); - if (err != APIError::OK) { - return err; - } - rx_header_buf_len_ += static_cast(received); - if (static_cast(received) != to_read) { - // not a full read - return APIError::WOULD_BLOCK; - } - - if (rx_header_buf_[0] != 0x01) { - state_ = State::FAILED; - HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]); - return APIError::BAD_INDICATOR; - } - // header reading done - } - - // read body - uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2]; - - if (state_ != State::DATA && msg_size > 128) { - // for handshake message only permit up to 128 bytes 
- state_ = State::FAILED; - HELPER_LOG("Bad packet len for handshake: %d", msg_size); - return APIError::BAD_HANDSHAKE_PACKET_LEN; - } - - // reserve space for body - if (rx_buf_.size() != msg_size) { - rx_buf_.resize(msg_size); - } - - if (rx_buf_len_ < msg_size) { - // more data to read - uint16_t to_read = msg_size - rx_buf_len_; - ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); - APIError err = handle_socket_read_result_(received); - if (err != APIError::OK) { - return err; - } - rx_buf_len_ += static_cast(received); - if (static_cast(received) != to_read) { - // not all read - return APIError::WOULD_BLOCK; - } - } - - // uncomment for even more debugging -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str()); -#endif - *frame = std::move(rx_buf_); - // consume msg - rx_buf_ = {}; - rx_buf_len_ = 0; - rx_header_buf_len_ = 0; - return APIError::OK; -} - -/** To be called from read/write methods. - * - * This method runs through the internal handshake methods, if in that state. - * - * If the handshake is still active when this method returns and a read/write can't take place at - * the moment, returns WOULD_BLOCK. - * If an error occurred, returns that error. Only returns OK if the transport is ready for data - * traffic. 
- */ -APIError APINoiseFrameHelper::state_action_() { - int err; - APIError aerr; - if (state_ == State::INITIALIZE) { - HELPER_LOG("Bad state for method: %d", (int) state_); - return APIError::BAD_STATE; - } - if (state_ == State::CLIENT_HELLO) { - // waiting for client hello - std::vector frame; - aerr = try_read_frame_(&frame); - if (aerr != APIError::OK) { - return handle_handshake_frame_error_(aerr); - } - // ignore contents, may be used in future for flags - // Reserve space for: existing prologue + 2 size bytes + frame data - prologue_.reserve(prologue_.size() + 2 + frame.size()); - prologue_.push_back((uint8_t) (frame.size() >> 8)); - prologue_.push_back((uint8_t) frame.size()); - prologue_.insert(prologue_.end(), frame.begin(), frame.end()); - - state_ = State::SERVER_HELLO; - } - if (state_ == State::SERVER_HELLO) { - // send server hello - const std::string &name = App.get_name(); - const std::string &mac = get_mac_address(); - - std::vector msg; - // Reserve space for: 1 byte proto + name + null + mac + null - msg.reserve(1 + name.size() + 1 + mac.size() + 1); - - // chosen proto - msg.push_back(0x01); - - // node name, terminated by null byte - const uint8_t *name_ptr = reinterpret_cast(name.c_str()); - msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1); - // node mac, terminated by null byte - const uint8_t *mac_ptr = reinterpret_cast(mac.c_str()); - msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1); - - aerr = write_frame_(msg.data(), msg.size()); - if (aerr != APIError::OK) - return aerr; - - // start handshake - aerr = init_handshake_(); - if (aerr != APIError::OK) - return aerr; - - state_ = State::HANDSHAKE; - } - if (state_ == State::HANDSHAKE) { - int action = noise_handshakestate_get_action(handshake_); - if (action == NOISE_ACTION_READ_MESSAGE) { - // waiting for handshake msg - std::vector frame; - aerr = try_read_frame_(&frame); - if (aerr != APIError::OK) { - return handle_handshake_frame_error_(aerr); - } - - if 
(frame.empty()) { - send_explicit_handshake_reject_("Empty handshake message"); - return APIError::BAD_HANDSHAKE_ERROR_BYTE; - } else if (frame[0] != 0x00) { - HELPER_LOG("Bad handshake error byte: %u", frame[0]); - send_explicit_handshake_reject_("Bad handshake error byte"); - return APIError::BAD_HANDSHAKE_ERROR_BYTE; - } - - NoiseBuffer mbuf; - noise_buffer_init(mbuf); - noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1); - err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); - if (err != 0) { - // Special handling for MAC failure - send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error"); - return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED); - } - - aerr = check_handshake_finished_(); - if (aerr != APIError::OK) - return aerr; - } else if (action == NOISE_ACTION_WRITE_MESSAGE) { - uint8_t buffer[65]; - NoiseBuffer mbuf; - noise_buffer_init(mbuf); - noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); - - err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); - APIError aerr_write = - handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED); - if (aerr_write != APIError::OK) - return aerr_write; - buffer[0] = 0x00; // success - - aerr = write_frame_(buffer, mbuf.size + 1); - if (aerr != APIError::OK) - return aerr; - aerr = check_handshake_finished_(); - if (aerr != APIError::OK) - return aerr; - } else { - // bad state for action - state_ = State::FAILED; - HELPER_LOG("Bad action for handshake: %d", action); - return APIError::HANDSHAKESTATE_BAD_STATE; - } - } - if (state_ == State::CLOSED || state_ == State::FAILED) { - return APIError::BAD_STATE; - } - return APIError::OK; -} -void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { - std::vector data; - data.resize(reason.length() + 1); - data[0] = 0x01; // failure - - // 
Copy error message in bulk - if (!reason.empty()) { - std::memcpy(data.data() + 1, reason.c_str(), reason.length()); - } - - // temporarily remove failed state - auto orig_state = state_; - state_ = State::EXPLICIT_REJECT; - write_frame_(data.data(), data.size()); - state_ = orig_state; -} -APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { - int err; - APIError aerr; - aerr = state_action_(); - if (aerr != APIError::OK) { - return aerr; - } - - if (state_ != State::DATA) { - return APIError::WOULD_BLOCK; - } - - std::vector frame; - aerr = try_read_frame_(&frame); - if (aerr != APIError::OK) - return aerr; - - NoiseBuffer mbuf; - noise_buffer_init(mbuf); - noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); - err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); - APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED); - if (decrypt_err != APIError::OK) - return decrypt_err; - - uint16_t msg_size = mbuf.size; - uint8_t *msg_data = frame.data(); - if (msg_size < 4) { - state_ = State::FAILED; - HELPER_LOG("Bad data packet: size %d too short", msg_size); - return APIError::BAD_DATA_PACKET; - } - - uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1]; - uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3]; - if (data_len > msg_size - 4) { - state_ = State::FAILED; - HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size); - return APIError::BAD_DATA_PACKET; - } - - buffer->container = std::move(frame); - buffer->data_offset = 4; - buffer->data_len = data_len; - buffer->type = type; - return APIError::OK; -} -APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) { - // Resize to include MAC space (required for Noise encryption) - buffer.get_buffer()->resize(buffer.get_buffer()->size() + frame_footer_size_); - PacketInfo packet{type, 0, - static_cast(buffer.get_buffer()->size() - 
frame_header_padding_ - frame_footer_size_)}; - return write_protobuf_packets(buffer, std::span(&packet, 1)); -} - -APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) { - APIError aerr = state_action_(); - if (aerr != APIError::OK) { - return aerr; - } - - if (state_ != State::DATA) { - return APIError::WOULD_BLOCK; - } - - if (packets.empty()) { - return APIError::OK; - } - - std::vector *raw_buffer = buffer.get_buffer(); - uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer - - this->reusable_iovs_.clear(); - this->reusable_iovs_.reserve(packets.size()); - uint16_t total_write_len = 0; - - // We need to encrypt each packet in place - for (const auto &packet : packets) { - // The buffer already has padding at offset - uint8_t *buf_start = buffer_data + packet.offset; - - // Write noise header - buf_start[0] = 0x01; // indicator - // buf_start[1], buf_start[2] to be set after encryption - - // Write message header (to be encrypted) - const uint8_t msg_offset = 3; - buf_start[msg_offset] = static_cast(packet.message_type >> 8); // type high byte - buf_start[msg_offset + 1] = static_cast(packet.message_type); // type low byte - buf_start[msg_offset + 2] = static_cast(packet.payload_size >> 8); // data_len high byte - buf_start[msg_offset + 3] = static_cast(packet.payload_size); // data_len low byte - // payload data is already in the buffer starting at offset + 7 - - // Make sure we have space for MAC - // The buffer should already have been sized appropriately - - // Encrypt the message in place - NoiseBuffer mbuf; - noise_buffer_init(mbuf); - noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + packet.payload_size, - 4 + packet.payload_size + frame_footer_size_); - - int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); - APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED); - if (aerr != APIError::OK) - return aerr; - - // Fill in the 
encrypted size - buf_start[1] = static_cast(mbuf.size >> 8); - buf_start[2] = static_cast(mbuf.size); - - // Add iovec for this encrypted packet - size_t packet_len = static_cast(3 + mbuf.size); // indicator + size + encrypted data - this->reusable_iovs_.push_back({buf_start, packet_len}); - total_write_len += packet_len; - } - - // Send all encrypted packets in one writev call - return this->write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len); -} - -APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) { - uint8_t header[3]; - header[0] = 0x01; // indicator - header[1] = (uint8_t) (len >> 8); - header[2] = (uint8_t) len; - - struct iovec iov[2]; - iov[0].iov_base = header; - iov[0].iov_len = 3; - if (len == 0) { - return this->write_raw_(iov, 1, 3); // Just header - } - iov[1].iov_base = const_cast(data); - iov[1].iov_len = len; - - return this->write_raw_(iov, 2, 3 + len); // Header + data -} - -/** Initiate the data structures for the handshake. 
- * - * @return 0 on success, -1 on error (check errno) - */ -APIError APINoiseFrameHelper::init_handshake_() { - int err; - memset(&nid_, 0, sizeof(nid_)); - // const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; - // err = noise_protocol_name_to_id(&nid_, proto, strlen(proto)); - nid_.pattern_id = NOISE_PATTERN_NN; - nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY; - nid_.dh_id = NOISE_DH_CURVE25519; - nid_.prefix_id = NOISE_PREFIX_STANDARD; - nid_.hybrid_id = NOISE_DH_NONE; - nid_.hash_id = NOISE_HASH_SHA256; - nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; - - err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); - APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED); - if (aerr != APIError::OK) - return aerr; - - const auto &psk = ctx_->get_psk(); - err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); - aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED); - if (aerr != APIError::OK) - return aerr; - - err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); - aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED); - if (aerr != APIError::OK) - return aerr; - // set_prologue copies it into handshakestate, so we can get rid of it now - prologue_ = {}; - - err = noise_handshakestate_start(handshake_); - aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED); - if (aerr != APIError::OK) - return aerr; - return APIError::OK; -} - -APIError APINoiseFrameHelper::check_handshake_finished_() { - assert(state_ == State::HANDSHAKE); - - int action = noise_handshakestate_get_action(handshake_); - if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE) - return APIError::OK; - if (action != NOISE_ACTION_SPLIT) { - state_ = State::FAILED; - 
HELPER_LOG("Bad action for handshake: %d", action); - return APIError::HANDSHAKESTATE_BAD_STATE; - } - int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); - APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED); - if (aerr != APIError::OK) - return aerr; - - frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_); - - HELPER_LOG("Handshake complete!"); - noise_handshakestate_free(handshake_); - handshake_ = nullptr; - state_ = State::DATA; - return APIError::OK; -} - -APINoiseFrameHelper::~APINoiseFrameHelper() { - if (handshake_ != nullptr) { - noise_handshakestate_free(handshake_); - handshake_ = nullptr; - } - if (send_cipher_ != nullptr) { - noise_cipherstate_free(send_cipher_); - send_cipher_ = nullptr; - } - if (recv_cipher_ != nullptr) { - noise_cipherstate_free(recv_cipher_); - recv_cipher_ = nullptr; - } -} - -extern "C" { -// declare how noise generates random bytes (here with a good HWRNG based on the RF system) -void noise_rand_bytes(void *output, size_t len) { - if (!esphome::random_bytes(reinterpret_cast(output), len)) { - ESP_LOGE(TAG, "Acquiring random bytes failed; rebooting"); - arch_restart(); - } -} -} - -#endif // USE_API_NOISE - -#ifdef USE_API_PLAINTEXT - -/// Initialize the frame helper, returns OK if successful. -APIError APIPlaintextFrameHelper::init() { - APIError err = init_common_(); - if (err != APIError::OK) { - return err; - } - - state_ = State::DATA; - return APIError::OK; -} -APIError APIPlaintextFrameHelper::loop() { - if (state_ != State::DATA) { - return APIError::BAD_STATE; - } - // Use base class implementation for buffer sending - return APIFrameHelper::loop(); -} - -/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter - * - * @param frame: The struct to hold the frame information in. 
- * msg: store the parsed frame in that struct - * - * @return See APIError - * - * error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. - */ -APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { - if (frame == nullptr) { - HELPER_LOG("Bad argument for try_read_frame_"); - return APIError::BAD_ARG; - } - - // read header - while (!rx_header_parsed_) { - // Now that we know when the socket is ready, we can read up to 3 bytes - // into the rx_header_buf_ before we have to switch back to reading - // one byte at a time to ensure we don't read past the message and - // into the next one. - - // Read directly into rx_header_buf_ at the current position - // Try to get to at least 3 bytes total (indicator + 2 varint bytes), then read one byte at a time - ssize_t received = - this->socket_->read(&rx_header_buf_[rx_header_buf_pos_], rx_header_buf_pos_ < 3 ? 3 - rx_header_buf_pos_ : 1); - APIError err = handle_socket_read_result_(received); - if (err != APIError::OK) { - return err; - } - - // If this was the first read, validate the indicator byte - if (rx_header_buf_pos_ == 0 && received > 0) { - if (rx_header_buf_[0] != 0x00) { - state_ = State::FAILED; - HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]); - return APIError::BAD_INDICATOR; - } - } - - rx_header_buf_pos_ += received; - - // Check for buffer overflow - if (rx_header_buf_pos_ >= sizeof(rx_header_buf_)) { - state_ = State::FAILED; - HELPER_LOG("Header buffer overflow"); - return APIError::BAD_DATA_PACKET; - } - - // Need at least 3 bytes total (indicator + 2 varint bytes) before trying to parse - if (rx_header_buf_pos_ < 3) { - continue; - } - - // At this point, we have at least 3 bytes total: - // - Validated indicator byte (0x00) stored at position 0 - // - At least 2 bytes in the buffer for the varints - // Buffer layout: - // [0]: indicator byte (0x00) - // [1-3]: Message size varint (variable length) - // - 2 bytes would only allow up to 16383, which is less than 
noise's UINT16_MAX (65535) - // - 3 bytes allows up to 2097151, ensuring we support at least as much as noise - // [2-5]: Message type varint (variable length) - // We now attempt to parse both varints. If either is incomplete, - // we'll continue reading more bytes. - - // Skip indicator byte at position 0 - uint8_t varint_pos = 1; - uint32_t consumed = 0; - - auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed); - if (!msg_size_varint.has_value()) { - // not enough data there yet - continue; - } - - if (msg_size_varint->as_uint32() > std::numeric_limits::max()) { - state_ = State::FAILED; - HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(), - std::numeric_limits::max()); - return APIError::BAD_DATA_PACKET; - } - rx_header_parsed_len_ = msg_size_varint->as_uint16(); - - // Move to next varint position - varint_pos += consumed; - - auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed); - if (!msg_type_varint.has_value()) { - // not enough data there yet - continue; - } - if (msg_type_varint->as_uint32() > std::numeric_limits::max()) { - state_ = State::FAILED; - HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum %u", msg_type_varint->as_uint32(), - std::numeric_limits::max()); - return APIError::BAD_DATA_PACKET; - } - rx_header_parsed_type_ = msg_type_varint->as_uint16(); - rx_header_parsed_ = true; - } - // header reading done - - // reserve space for body - if (rx_buf_.size() != rx_header_parsed_len_) { - rx_buf_.resize(rx_header_parsed_len_); - } - - if (rx_buf_len_ < rx_header_parsed_len_) { - // more data to read - uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_; - ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); - APIError err = handle_socket_read_result_(received); - if (err != APIError::OK) { - return err; - } - rx_buf_len_ += 
static_cast(received); - if (static_cast(received) != to_read) { - // not all read - return APIError::WOULD_BLOCK; - } - } - - // uncomment for even more debugging -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str()); -#endif - *frame = std::move(rx_buf_); - // consume msg - rx_buf_ = {}; - rx_buf_len_ = 0; - rx_header_buf_pos_ = 0; - rx_header_parsed_ = false; - return APIError::OK; -} -APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { - APIError aerr; - - if (state_ != State::DATA) { - return APIError::WOULD_BLOCK; - } - - std::vector frame; - aerr = try_read_frame_(&frame); - if (aerr != APIError::OK) { - if (aerr == APIError::BAD_INDICATOR) { - // Make sure to tell the remote that we don't - // understand the indicator byte so it knows - // we do not support it. - struct iovec iov[1]; - // The \x00 first byte is the marker for plaintext. - // - // The remote will know how to handle the indicator byte, - // but it likely won't understand the rest of the message. - // - // We must send at least 3 bytes to be read, so we add - // a message after the indicator byte to ensures its long - // enough and can aid in debugging. 
- const char msg[] = "\x00" - "Bad indicator byte"; - iov[0].iov_base = (void *) msg; - iov[0].iov_len = 19; - this->write_raw_(iov, 1, 19); - } - return aerr; - } - - buffer->container = std::move(frame); - buffer->data_offset = 0; - buffer->data_len = rx_header_parsed_len_; - buffer->type = rx_header_parsed_type_; - return APIError::OK; -} -APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) { - PacketInfo packet{type, 0, static_cast(buffer.get_buffer()->size() - frame_header_padding_)}; - return write_protobuf_packets(buffer, std::span(&packet, 1)); -} - -APIError APIPlaintextFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) { - if (state_ != State::DATA) { - return APIError::BAD_STATE; - } - - if (packets.empty()) { - return APIError::OK; - } - - std::vector *raw_buffer = buffer.get_buffer(); - uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer - - this->reusable_iovs_.clear(); - this->reusable_iovs_.reserve(packets.size()); - uint16_t total_write_len = 0; - - for (const auto &packet : packets) { - // Calculate varint sizes for header layout - uint8_t size_varint_len = api::ProtoSize::varint(static_cast(packet.payload_size)); - uint8_t type_varint_len = api::ProtoSize::varint(static_cast(packet.message_type)); - uint8_t total_header_len = 1 + size_varint_len + type_varint_len; - - // Calculate where to start writing the header - // The header starts at the latest possible position to minimize unused padding - // - // Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3 - // [0-2] - Unused padding - // [3] - 0x00 indicator byte - // [4] - Payload size varint (1 byte, for sizes 0-127) - // [5] - Message type varint (1 byte, for types 0-127) - // [6...] 
- Actual payload data - // - // Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2 - // [0-1] - Unused padding - // [2] - 0x00 indicator byte - // [3-4] - Payload size varint (2 bytes, for sizes 128-16383) - // [5] - Message type varint (1 byte, for types 0-127) - // [6...] - Actual payload data - // - // Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0 - // [0] - 0x00 indicator byte - // [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151) - // [4-5] - Message type varint (2 bytes, for types 128-32767) - // [6...] - Actual payload data - // - // The message starts at offset + frame_header_padding_ - // So we write the header starting at offset + frame_header_padding_ - total_header_len - uint8_t *buf_start = buffer_data + packet.offset; - uint32_t header_offset = frame_header_padding_ - total_header_len; - - // Write the plaintext header - buf_start[header_offset] = 0x00; // indicator - - // Encode varints directly into buffer - ProtoVarInt(packet.payload_size).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len); - ProtoVarInt(packet.message_type) - .encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len); - - // Add iovec for this packet (header + payload) - size_t packet_len = static_cast(total_header_len + packet.payload_size); - this->reusable_iovs_.push_back({buf_start + header_offset, packet_len}); - total_write_len += packet_len; - } - - // Send all packets in one writev call - return write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len); -} - -#endif // USE_API_PLAINTEXT } // namespace api } // namespace esphome diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 87a4b57c2f..231a3366ce 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -8,17 +8,16 @@ #include "esphome/core/defines.h" #ifdef USE_API 
-#ifdef USE_API_NOISE -#include "noise/protocol.h" -#endif - -#include "api_noise_context.h" #include "esphome/components/socket/socket.h" #include "esphome/core/application.h" +#include "esphome/core/log.h" namespace esphome { namespace api { +// uncomment to log raw packets +//#define HELPER_LOG_PACKETS + // Forward declaration struct ClientInfo; @@ -43,7 +42,6 @@ struct PacketInfo { enum class APIError : uint16_t { OK = 0, WOULD_BLOCK = 1001, - BAD_HANDSHAKE_PACKET_LEN = 1002, BAD_INDICATOR = 1003, BAD_DATA_PACKET = 1004, TCP_NODELAY_FAILED = 1005, @@ -54,16 +52,19 @@ enum class APIError : uint16_t { BAD_ARG = 1010, SOCKET_READ_FAILED = 1011, SOCKET_WRITE_FAILED = 1012, + OUT_OF_MEMORY = 1018, + CONNECTION_CLOSED = 1022, +#ifdef USE_API_NOISE + BAD_HANDSHAKE_PACKET_LEN = 1002, HANDSHAKESTATE_READ_FAILED = 1013, HANDSHAKESTATE_WRITE_FAILED = 1014, HANDSHAKESTATE_BAD_STATE = 1015, CIPHERSTATE_DECRYPT_FAILED = 1016, CIPHERSTATE_ENCRYPT_FAILED = 1017, - OUT_OF_MEMORY = 1018, HANDSHAKESTATE_SETUP_FAILED = 1019, HANDSHAKESTATE_SPLIT_FAILED = 1020, BAD_HANDSHAKE_ERROR_BYTE = 1021, - CONNECTION_CLOSED = 1022, +#endif }; const char *api_error_to_str(APIError err); @@ -183,109 +184,7 @@ class APIFrameHelper { APIError handle_socket_read_result_(ssize_t received); }; -#ifdef USE_API_NOISE -class APINoiseFrameHelper : public APIFrameHelper { - public: - APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx, - const ClientInfo *client_info) - : APIFrameHelper(std::move(socket), client_info), ctx_(std::move(ctx)) { - // Noise header structure: - // Pos 0: indicator (0x01) - // Pos 1-2: encrypted payload size (16-bit big-endian) - // Pos 3-6: encrypted type (16-bit) + data_len (16-bit) - // Pos 7+: actual payload data - frame_header_padding_ = 7; - } - ~APINoiseFrameHelper() override; - APIError init() override; - APIError loop() override; - APIError read_packet(ReadPacketBuffer *buffer) override; - APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer 
buffer) override; - APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) override; - // Get the frame header padding required by this protocol - uint8_t frame_header_padding() override { return frame_header_padding_; } - // Get the frame footer size required by this protocol - uint8_t frame_footer_size() override { return frame_footer_size_; } - - protected: - APIError state_action_(); - APIError try_read_frame_(std::vector *frame); - APIError write_frame_(const uint8_t *data, uint16_t len); - APIError init_handshake_(); - APIError check_handshake_finished_(); - void send_explicit_handshake_reject_(const std::string &reason); - APIError handle_handshake_frame_error_(APIError aerr); - APIError handle_noise_error_(int err, const char *func_name, APIError api_err); - - // Pointers first (4 bytes each) - NoiseHandshakeState *handshake_{nullptr}; - NoiseCipherState *send_cipher_{nullptr}; - NoiseCipherState *recv_cipher_{nullptr}; - - // Shared pointer (8 bytes on 32-bit = 4 bytes control block pointer + 4 bytes object pointer) - std::shared_ptr ctx_; - - // Vector (12 bytes on 32-bit) - std::vector prologue_; - - // NoiseProtocolId (size depends on implementation) - NoiseProtocolId nid_; - - // Group small types together - // Fixed-size header buffer for noise protocol: - // 1 byte for indicator + 2 bytes for message size (16-bit value, not varint) - // Note: Maximum message size is UINT16_MAX (65535), with a limit of 128 bytes during handshake phase - uint8_t rx_header_buf_[3]; - uint8_t rx_header_buf_len_ = 0; - // 4 bytes total, no padding -}; -#endif // USE_API_NOISE - -#ifdef USE_API_PLAINTEXT -class APIPlaintextFrameHelper : public APIFrameHelper { - public: - APIPlaintextFrameHelper(std::unique_ptr socket, const ClientInfo *client_info) - : APIFrameHelper(std::move(socket), client_info) { - // Plaintext header structure (worst case): - // Pos 0: indicator (0x00) - // Pos 1-3: payload size varint (up to 3 bytes) - // Pos 4-5: message type 
varint (up to 2 bytes) - // Pos 6+: actual payload data - frame_header_padding_ = 6; - } - ~APIPlaintextFrameHelper() override = default; - APIError init() override; - APIError loop() override; - APIError read_packet(ReadPacketBuffer *buffer) override; - APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override; - APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) override; - uint8_t frame_header_padding() override { return frame_header_padding_; } - // Get the frame footer size required by this protocol - uint8_t frame_footer_size() override { return frame_footer_size_; } - - protected: - APIError try_read_frame_(std::vector *frame); - - // Group 2-byte aligned types - uint16_t rx_header_parsed_type_ = 0; - uint16_t rx_header_parsed_len_ = 0; - - // Group 1-byte types together - // Fixed-size header buffer for plaintext protocol: - // We now store the indicator byte + the two varints. - // To match noise protocol's maximum message size (UINT16_MAX = 65535), we need: - // 1 byte for indicator + 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint - // - // While varints could theoretically be up to 10 bytes each for 64-bit values, - // attempting to process messages with headers that large would likely crash the - // ESP32 due to memory constraints. 
- uint8_t rx_header_buf_[6]; // 1 byte indicator + 5 bytes for varints (3 for size + 2 for type) - uint8_t rx_header_buf_pos_ = 0; - bool rx_header_parsed_ = false; - // 8 bytes total, no padding needed -}; -#endif - } // namespace api } // namespace esphome -#endif + +#endif // USE_API diff --git a/esphome/components/api/api_frame_helper_noise.cpp b/esphome/components/api/api_frame_helper_noise.cpp new file mode 100644 index 0000000000..3c2c9e059e --- /dev/null +++ b/esphome/components/api/api_frame_helper_noise.cpp @@ -0,0 +1,577 @@ +#include "api_frame_helper_noise.h" +#ifdef USE_API +#ifdef USE_API_NOISE +#include "api_connection.h" // For ClientInfo struct +#include "esphome/core/application.h" +#include "esphome/core/hal.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" +#include "proto.h" +#include +#include + +namespace esphome { +namespace api { + +static const char *const TAG = "api.noise"; +static const char *const PROLOGUE_INIT = "NoiseAPIInit"; + +#define HELPER_LOG(msg, ...) 
ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) + +#ifdef HELPER_LOG_PACKETS +#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) +#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str()) +#else +#define LOG_PACKET_RECEIVED(buffer) ((void) 0) +#define LOG_PACKET_SENDING(data, len) ((void) 0) +#endif + +/// Convert a noise error code to a readable error +std::string noise_err_to_str(int err) { + if (err == NOISE_ERROR_NO_MEMORY) + return "NO_MEMORY"; + if (err == NOISE_ERROR_UNKNOWN_ID) + return "UNKNOWN_ID"; + if (err == NOISE_ERROR_UNKNOWN_NAME) + return "UNKNOWN_NAME"; + if (err == NOISE_ERROR_MAC_FAILURE) + return "MAC_FAILURE"; + if (err == NOISE_ERROR_NOT_APPLICABLE) + return "NOT_APPLICABLE"; + if (err == NOISE_ERROR_SYSTEM) + return "SYSTEM"; + if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) + return "REMOTE_KEY_REQUIRED"; + if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) + return "LOCAL_KEY_REQUIRED"; + if (err == NOISE_ERROR_PSK_REQUIRED) + return "PSK_REQUIRED"; + if (err == NOISE_ERROR_INVALID_LENGTH) + return "INVALID_LENGTH"; + if (err == NOISE_ERROR_INVALID_PARAM) + return "INVALID_PARAM"; + if (err == NOISE_ERROR_INVALID_STATE) + return "INVALID_STATE"; + if (err == NOISE_ERROR_INVALID_NONCE) + return "INVALID_NONCE"; + if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) + return "INVALID_PRIVATE_KEY"; + if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) + return "INVALID_PUBLIC_KEY"; + if (err == NOISE_ERROR_INVALID_FORMAT) + return "INVALID_FORMAT"; + if (err == NOISE_ERROR_INVALID_SIGNATURE) + return "INVALID_SIGNATURE"; + return to_string(err); +} + +/// Initialize the frame helper, returns OK if successful. 
+APIError APINoiseFrameHelper::init() { + APIError err = init_common_(); + if (err != APIError::OK) { + return err; + } + + // init prologue + prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT)); + + state_ = State::CLIENT_HELLO; + return APIError::OK; +} +// Helper for handling handshake frame errors +APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { + if (aerr == APIError::BAD_INDICATOR) { + send_explicit_handshake_reject_("Bad indicator byte"); + } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { + send_explicit_handshake_reject_("Bad handshake packet len"); + } + return aerr; +} + +// Helper for handling noise library errors +APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) { + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str()); + return api_err; + } + return APIError::OK; +} + +/// Run through handshake messages (if in that phase) +APIError APINoiseFrameHelper::loop() { + // During handshake phase, process as many actions as possible until we can't progress + // socket_->ready() stays true until next main loop, but state_action() will return + // WOULD_BLOCK when no more data is available to read + while (state_ != State::DATA && this->socket_->ready()) { + APIError err = state_action_(); + if (err == APIError::WOULD_BLOCK) { + break; + } + if (err != APIError::OK) { + return err; + } + } + + // Use base class implementation for buffer sending + return APIFrameHelper::loop(); +} + +/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter + * + * @param frame: The struct to hold the frame information in. + * msg_start: points to the start of the payload - this pointer is only valid until the next + * try_receive_raw_ call + * + * @return 0 if a full packet is in rx_buf_ + * @return -1 if error, check errno. 
+ * + * errno EWOULDBLOCK: Packet could not be read without blocking. Try again later. + * errno ENOMEM: Not enough memory for reading packet. + * errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. + * errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase. + */ +APIError APINoiseFrameHelper::try_read_frame_(std::vector *frame) { + if (frame == nullptr) { + HELPER_LOG("Bad argument for try_read_frame_"); + return APIError::BAD_ARG; + } + + // read header + if (rx_header_buf_len_ < 3) { + // no header information yet + uint8_t to_read = 3 - rx_header_buf_len_; + ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); + APIError err = handle_socket_read_result_(received); + if (err != APIError::OK) { + return err; + } + rx_header_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { + // not a full read + return APIError::WOULD_BLOCK; + } + + if (rx_header_buf_[0] != 0x01) { + state_ = State::FAILED; + HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]); + return APIError::BAD_INDICATOR; + } + // header reading done + } + + // read body + uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2]; + + if (state_ != State::DATA && msg_size > 128) { + // for handshake message only permit up to 128 bytes + state_ = State::FAILED; + HELPER_LOG("Bad packet len for handshake: %d", msg_size); + return APIError::BAD_HANDSHAKE_PACKET_LEN; + } + + // reserve space for body + if (rx_buf_.size() != msg_size) { + rx_buf_.resize(msg_size); + } + + if (rx_buf_len_ < msg_size) { + // more data to read + uint16_t to_read = msg_size - rx_buf_len_; + ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); + APIError err = handle_socket_read_result_(received); + if (err != APIError::OK) { + return err; + } + rx_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { + // not all read + return APIError::WOULD_BLOCK; + } + } + + 
LOG_PACKET_RECEIVED(rx_buf_); + *frame = std::move(rx_buf_); + // consume msg + rx_buf_ = {}; + rx_buf_len_ = 0; + rx_header_buf_len_ = 0; + return APIError::OK; +} + +/** To be called from read/write methods. + * + * This method runs through the internal handshake methods, if in that state. + * + * If the handshake is still active when this method returns and a read/write can't take place at + * the moment, returns WOULD_BLOCK. + * If an error occurred, returns that error. Only returns OK if the transport is ready for data + * traffic. + */ +APIError APINoiseFrameHelper::state_action_() { + int err; + APIError aerr; + if (state_ == State::INITIALIZE) { + HELPER_LOG("Bad state for method: %d", (int) state_); + return APIError::BAD_STATE; + } + if (state_ == State::CLIENT_HELLO) { + // waiting for client hello + std::vector frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) { + return handle_handshake_frame_error_(aerr); + } + // ignore contents, may be used in future for flags + // Reserve space for: existing prologue + 2 size bytes + frame data + prologue_.reserve(prologue_.size() + 2 + frame.size()); + prologue_.push_back((uint8_t) (frame.size() >> 8)); + prologue_.push_back((uint8_t) frame.size()); + prologue_.insert(prologue_.end(), frame.begin(), frame.end()); + + state_ = State::SERVER_HELLO; + } + if (state_ == State::SERVER_HELLO) { + // send server hello + const std::string &name = App.get_name(); + const std::string &mac = get_mac_address(); + + std::vector msg; + // Reserve space for: 1 byte proto + name + null + mac + null + msg.reserve(1 + name.size() + 1 + mac.size() + 1); + + // chosen proto + msg.push_back(0x01); + + // node name, terminated by null byte + const uint8_t *name_ptr = reinterpret_cast(name.c_str()); + msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1); + // node mac, terminated by null byte + const uint8_t *mac_ptr = reinterpret_cast(mac.c_str()); + msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1); + 
+ aerr = write_frame_(msg.data(), msg.size()); + if (aerr != APIError::OK) + return aerr; + + // start handshake + aerr = init_handshake_(); + if (aerr != APIError::OK) + return aerr; + + state_ = State::HANDSHAKE; + } + if (state_ == State::HANDSHAKE) { + int action = noise_handshakestate_get_action(handshake_); + if (action == NOISE_ACTION_READ_MESSAGE) { + // waiting for handshake msg + std::vector frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) { + return handle_handshake_frame_error_(aerr); + } + + if (frame.empty()) { + send_explicit_handshake_reject_("Empty handshake message"); + return APIError::BAD_HANDSHAKE_ERROR_BYTE; + } else if (frame[0] != 0x00) { + HELPER_LOG("Bad handshake error byte: %u", frame[0]); + send_explicit_handshake_reject_("Bad handshake error byte"); + return APIError::BAD_HANDSHAKE_ERROR_BYTE; + } + + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1); + err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); + if (err != 0) { + // Special handling for MAC failure + send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? 
"Handshake MAC failure" : "Handshake error"); + return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED); + } + + aerr = check_handshake_finished_(); + if (aerr != APIError::OK) + return aerr; + } else if (action == NOISE_ACTION_WRITE_MESSAGE) { + uint8_t buffer[65]; + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); + + err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); + APIError aerr_write = + handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED); + if (aerr_write != APIError::OK) + return aerr_write; + buffer[0] = 0x00; // success + + aerr = write_frame_(buffer, mbuf.size + 1); + if (aerr != APIError::OK) + return aerr; + aerr = check_handshake_finished_(); + if (aerr != APIError::OK) + return aerr; + } else { + // bad state for action + state_ = State::FAILED; + HELPER_LOG("Bad action for handshake: %d", action); + return APIError::HANDSHAKESTATE_BAD_STATE; + } + } + if (state_ == State::CLOSED || state_ == State::FAILED) { + return APIError::BAD_STATE; + } + return APIError::OK; +} +void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { + std::vector data; + data.resize(reason.length() + 1); + data[0] = 0x01; // failure + + // Copy error message in bulk + if (!reason.empty()) { + std::memcpy(data.data() + 1, reason.c_str(), reason.length()); + } + + // temporarily remove failed state + auto orig_state = state_; + state_ = State::EXPLICIT_REJECT; + write_frame_(data.data(), data.size()); + state_ = orig_state; +} +APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { + int err; + APIError aerr; + aerr = state_action_(); + if (aerr != APIError::OK) { + return aerr; + } + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + std::vector frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) + return aerr; + + 
NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); + err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); + APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED); + if (decrypt_err != APIError::OK) + return decrypt_err; + + uint16_t msg_size = mbuf.size; + uint8_t *msg_data = frame.data(); + if (msg_size < 4) { + state_ = State::FAILED; + HELPER_LOG("Bad data packet: size %d too short", msg_size); + return APIError::BAD_DATA_PACKET; + } + + uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1]; + uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3]; + if (data_len > msg_size - 4) { + state_ = State::FAILED; + HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size); + return APIError::BAD_DATA_PACKET; + } + + buffer->container = std::move(frame); + buffer->data_offset = 4; + buffer->data_len = data_len; + buffer->type = type; + return APIError::OK; +} +APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) { + // Resize to include MAC space (required for Noise encryption) + buffer.get_buffer()->resize(buffer.get_buffer()->size() + frame_footer_size_); + PacketInfo packet{type, 0, + static_cast(buffer.get_buffer()->size() - frame_header_padding_ - frame_footer_size_)}; + return write_protobuf_packets(buffer, std::span(&packet, 1)); +} + +APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) { + APIError aerr = state_action_(); + if (aerr != APIError::OK) { + return aerr; + } + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + if (packets.empty()) { + return APIError::OK; + } + + std::vector *raw_buffer = buffer.get_buffer(); + uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer + + this->reusable_iovs_.clear(); + this->reusable_iovs_.reserve(packets.size()); + uint16_t 
total_write_len = 0; + + // We need to encrypt each packet in place + for (const auto &packet : packets) { + // The buffer already has padding at offset + uint8_t *buf_start = buffer_data + packet.offset; + + // Write noise header + buf_start[0] = 0x01; // indicator + // buf_start[1], buf_start[2] to be set after encryption + + // Write message header (to be encrypted) + const uint8_t msg_offset = 3; + buf_start[msg_offset] = static_cast(packet.message_type >> 8); // type high byte + buf_start[msg_offset + 1] = static_cast(packet.message_type); // type low byte + buf_start[msg_offset + 2] = static_cast(packet.payload_size >> 8); // data_len high byte + buf_start[msg_offset + 3] = static_cast(packet.payload_size); // data_len low byte + // payload data is already in the buffer starting at offset + 7 + + // Make sure we have space for MAC + // The buffer should already have been sized appropriately + + // Encrypt the message in place + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + packet.payload_size, + 4 + packet.payload_size + frame_footer_size_); + + int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); + APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED); + if (aerr != APIError::OK) + return aerr; + + // Fill in the encrypted size + buf_start[1] = static_cast(mbuf.size >> 8); + buf_start[2] = static_cast(mbuf.size); + + // Add iovec for this encrypted packet + size_t packet_len = static_cast(3 + mbuf.size); // indicator + size + encrypted data + this->reusable_iovs_.push_back({buf_start, packet_len}); + total_write_len += packet_len; + } + + // Send all encrypted packets in one writev call + return this->write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len); +} + +APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) { + uint8_t header[3]; + header[0] = 0x01; // indicator + header[1] = 
(uint8_t) (len >> 8); + header[2] = (uint8_t) len; + + struct iovec iov[2]; + iov[0].iov_base = header; + iov[0].iov_len = 3; + if (len == 0) { + return this->write_raw_(iov, 1, 3); // Just header + } + iov[1].iov_base = const_cast(data); + iov[1].iov_len = len; + + return this->write_raw_(iov, 2, 3 + len); // Header + data +} + +/** Initiate the data structures for the handshake. + * + * @return 0 on success, -1 on error (check errno) + */ +APIError APINoiseFrameHelper::init_handshake_() { + int err; + memset(&nid_, 0, sizeof(nid_)); + // const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; + // err = noise_protocol_name_to_id(&nid_, proto, strlen(proto)); + nid_.pattern_id = NOISE_PATTERN_NN; + nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY; + nid_.dh_id = NOISE_DH_CURVE25519; + nid_.prefix_id = NOISE_PREFIX_STANDARD; + nid_.hybrid_id = NOISE_DH_NONE; + nid_.hash_id = NOISE_HASH_SHA256; + nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; + + err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); + APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED); + if (aerr != APIError::OK) + return aerr; + + const auto &psk = ctx_->get_psk(); + err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); + aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED); + if (aerr != APIError::OK) + return aerr; + + err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); + aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED); + if (aerr != APIError::OK) + return aerr; + // set_prologue copies it into handshakestate, so we can get rid of it now + prologue_ = {}; + + err = noise_handshakestate_start(handshake_); + aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED); + if (aerr != 
APIError::OK) + return aerr; + return APIError::OK; +} + +APIError APINoiseFrameHelper::check_handshake_finished_() { + assert(state_ == State::HANDSHAKE); + + int action = noise_handshakestate_get_action(handshake_); + if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE) + return APIError::OK; + if (action != NOISE_ACTION_SPLIT) { + state_ = State::FAILED; + HELPER_LOG("Bad action for handshake: %d", action); + return APIError::HANDSHAKESTATE_BAD_STATE; + } + int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); + APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED); + if (aerr != APIError::OK) + return aerr; + + frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_); + + HELPER_LOG("Handshake complete!"); + noise_handshakestate_free(handshake_); + handshake_ = nullptr; + state_ = State::DATA; + return APIError::OK; +} + +APINoiseFrameHelper::~APINoiseFrameHelper() { + if (handshake_ != nullptr) { + noise_handshakestate_free(handshake_); + handshake_ = nullptr; + } + if (send_cipher_ != nullptr) { + noise_cipherstate_free(send_cipher_); + send_cipher_ = nullptr; + } + if (recv_cipher_ != nullptr) { + noise_cipherstate_free(recv_cipher_); + recv_cipher_ = nullptr; + } +} + +extern "C" { +// declare how noise generates random bytes (here with a good HWRNG based on the RF system) +void noise_rand_bytes(void *output, size_t len) { + if (!esphome::random_bytes(reinterpret_cast(output), len)) { + ESP_LOGE(TAG, "Acquiring random bytes failed; rebooting"); + arch_restart(); + } +} +} + +} // namespace api +} // namespace esphome +#endif // USE_API_NOISE +#endif // USE_API diff --git a/esphome/components/api/api_frame_helper_noise.h b/esphome/components/api/api_frame_helper_noise.h new file mode 100644 index 0000000000..ed5141d625 --- /dev/null +++ b/esphome/components/api/api_frame_helper_noise.h @@ -0,0 +1,70 @@ +#pragma once +#include 
"api_frame_helper.h" +#ifdef USE_API +#ifdef USE_API_NOISE +#include "noise/protocol.h" +#include "api_noise_context.h" + +namespace esphome { +namespace api { + +class APINoiseFrameHelper : public APIFrameHelper { + public: + APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx, + const ClientInfo *client_info) + : APIFrameHelper(std::move(socket), client_info), ctx_(std::move(ctx)) { + // Noise header structure: + // Pos 0: indicator (0x01) + // Pos 1-2: encrypted payload size (16-bit big-endian) + // Pos 3-6: encrypted type (16-bit) + data_len (16-bit) + // Pos 7+: actual payload data + frame_header_padding_ = 7; + } + ~APINoiseFrameHelper() override; + APIError init() override; + APIError loop() override; + APIError read_packet(ReadPacketBuffer *buffer) override; + APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override; + APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) override; + // Get the frame header padding required by this protocol + uint8_t frame_header_padding() override { return frame_header_padding_; } + // Get the frame footer size required by this protocol + uint8_t frame_footer_size() override { return frame_footer_size_; } + + protected: + APIError state_action_(); + APIError try_read_frame_(std::vector *frame); + APIError write_frame_(const uint8_t *data, uint16_t len); + APIError init_handshake_(); + APIError check_handshake_finished_(); + void send_explicit_handshake_reject_(const std::string &reason); + APIError handle_handshake_frame_error_(APIError aerr); + APIError handle_noise_error_(int err, const char *func_name, APIError api_err); + + // Pointers first (4 bytes each) + NoiseHandshakeState *handshake_{nullptr}; + NoiseCipherState *send_cipher_{nullptr}; + NoiseCipherState *recv_cipher_{nullptr}; + + // Shared pointer (8 bytes on 32-bit = 4 bytes control block pointer + 4 bytes object pointer) + std::shared_ptr ctx_; + + // Vector (12 bytes on 32-bit) + std::vector prologue_; 
+ + // NoiseProtocolId (size depends on implementation) + NoiseProtocolId nid_; + + // Group small types together + // Fixed-size header buffer for noise protocol: + // 1 byte for indicator + 2 bytes for message size (16-bit value, not varint) + // Note: Maximum message size is UINT16_MAX (65535), with a limit of 128 bytes during handshake phase + uint8_t rx_header_buf_[3]; + uint8_t rx_header_buf_len_ = 0; + // 4 bytes total, no padding +}; + +} // namespace api +} // namespace esphome +#endif // USE_API_NOISE +#endif // USE_API diff --git a/esphome/components/api/api_frame_helper_plaintext.cpp b/esphome/components/api/api_frame_helper_plaintext.cpp new file mode 100644 index 0000000000..d0bc631e1b --- /dev/null +++ b/esphome/components/api/api_frame_helper_plaintext.cpp @@ -0,0 +1,292 @@ +#include "api_frame_helper_plaintext.h" +#ifdef USE_API +#ifdef USE_API_PLAINTEXT +#include "api_connection.h" // For ClientInfo struct +#include "esphome/core/application.h" +#include "esphome/core/hal.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" +#include "proto.h" +#include +#include + +namespace esphome { +namespace api { + +static const char *const TAG = "api.plaintext"; + +#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) + +#ifdef HELPER_LOG_PACKETS +#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) +#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str()) +#else +#define LOG_PACKET_RECEIVED(buffer) ((void) 0) +#define LOG_PACKET_SENDING(data, len) ((void) 0) +#endif + +/// Initialize the frame helper, returns OK if successful. 
+APIError APIPlaintextFrameHelper::init() { + APIError err = init_common_(); + if (err != APIError::OK) { + return err; + } + + state_ = State::DATA; + return APIError::OK; +} +APIError APIPlaintextFrameHelper::loop() { + if (state_ != State::DATA) { + return APIError::BAD_STATE; + } + // Use base class implementation for buffer sending + return APIFrameHelper::loop(); +} + +/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter + * + * @param frame: The struct to hold the frame information in. + * msg: store the parsed frame in that struct + * + * @return See APIError + * + * error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. + */ +APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { + if (frame == nullptr) { + HELPER_LOG("Bad argument for try_read_frame_"); + return APIError::BAD_ARG; + } + + // read header + while (!rx_header_parsed_) { + // Now that we know when the socket is ready, we can read up to 3 bytes + // into the rx_header_buf_ before we have to switch back to reading + // one byte at a time to ensure we don't read past the message and + // into the next one. + + // Read directly into rx_header_buf_ at the current position + // Try to get to at least 3 bytes total (indicator + 2 varint bytes), then read one byte at a time + ssize_t received = + this->socket_->read(&rx_header_buf_[rx_header_buf_pos_], rx_header_buf_pos_ < 3 ? 
3 - rx_header_buf_pos_ : 1); + APIError err = handle_socket_read_result_(received); + if (err != APIError::OK) { + return err; + } + + // If this was the first read, validate the indicator byte + if (rx_header_buf_pos_ == 0 && received > 0) { + if (rx_header_buf_[0] != 0x00) { + state_ = State::FAILED; + HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]); + return APIError::BAD_INDICATOR; + } + } + + rx_header_buf_pos_ += received; + + // Check for buffer overflow + if (rx_header_buf_pos_ >= sizeof(rx_header_buf_)) { + state_ = State::FAILED; + HELPER_LOG("Header buffer overflow"); + return APIError::BAD_DATA_PACKET; + } + + // Need at least 3 bytes total (indicator + 2 varint bytes) before trying to parse + if (rx_header_buf_pos_ < 3) { + continue; + } + + // At this point, we have at least 3 bytes total: + // - Validated indicator byte (0x00) stored at position 0 + // - At least 2 bytes in the buffer for the varints + // Buffer layout: + // [0]: indicator byte (0x00) + // [1-3]: Message size varint (variable length) + // - 2 bytes would only allow up to 16383, which is less than noise's UINT16_MAX (65535) + // - 3 bytes allows up to 2097151, ensuring we support at least as much as noise + // [2-5]: Message type varint (variable length) + // We now attempt to parse both varints. If either is incomplete, + // we'll continue reading more bytes. 
+ + // Skip indicator byte at position 0 + uint8_t varint_pos = 1; + uint32_t consumed = 0; + + auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed); + if (!msg_size_varint.has_value()) { + // not enough data there yet + continue; + } + + if (msg_size_varint->as_uint32() > std::numeric_limits::max()) { + state_ = State::FAILED; + HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(), + std::numeric_limits::max()); + return APIError::BAD_DATA_PACKET; + } + rx_header_parsed_len_ = msg_size_varint->as_uint16(); + + // Move to next varint position + varint_pos += consumed; + + auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed); + if (!msg_type_varint.has_value()) { + // not enough data there yet + continue; + } + if (msg_type_varint->as_uint32() > std::numeric_limits::max()) { + state_ = State::FAILED; + HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum %u", msg_type_varint->as_uint32(), + std::numeric_limits::max()); + return APIError::BAD_DATA_PACKET; + } + rx_header_parsed_type_ = msg_type_varint->as_uint16(); + rx_header_parsed_ = true; + } + // header reading done + + // reserve space for body + if (rx_buf_.size() != rx_header_parsed_len_) { + rx_buf_.resize(rx_header_parsed_len_); + } + + if (rx_buf_len_ < rx_header_parsed_len_) { + // more data to read + uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_; + ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); + APIError err = handle_socket_read_result_(received); + if (err != APIError::OK) { + return err; + } + rx_buf_len_ += static_cast(received); + if (static_cast(received) != to_read) { + // not all read + return APIError::WOULD_BLOCK; + } + } + + LOG_PACKET_RECEIVED(rx_buf_); + *frame = std::move(rx_buf_); + // consume msg + rx_buf_ = {}; + rx_buf_len_ = 0; + rx_header_buf_pos_ = 0; + rx_header_parsed_ 
= false; + return APIError::OK; +} +APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { + APIError aerr; + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + std::vector frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) { + if (aerr == APIError::BAD_INDICATOR) { + // Make sure to tell the remote that we don't + // understand the indicator byte so it knows + // we do not support it. + struct iovec iov[1]; + // The \x00 first byte is the marker for plaintext. + // + // The remote will know how to handle the indicator byte, + // but it likely won't understand the rest of the message. + // + // We must send at least 3 bytes to be read, so we add + // a message after the indicator byte to ensure it's long + // enough and can aid in debugging. + const char msg[] = "\x00" + "Bad indicator byte"; + iov[0].iov_base = (void *) msg; + iov[0].iov_len = 19; + this->write_raw_(iov, 1, 19); + } + return aerr; + } + + buffer->container = std::move(frame); + buffer->data_offset = 0; + buffer->data_len = rx_header_parsed_len_; + buffer->type = rx_header_parsed_type_; + return APIError::OK; +} +APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) { + PacketInfo packet{type, 0, static_cast(buffer.get_buffer()->size() - frame_header_padding_)}; + return write_protobuf_packets(buffer, std::span(&packet, 1)); +} + +APIError APIPlaintextFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) { + if (state_ != State::DATA) { + return APIError::BAD_STATE; + } + + if (packets.empty()) { + return APIError::OK; + } + + std::vector *raw_buffer = buffer.get_buffer(); + uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer + + this->reusable_iovs_.clear(); + this->reusable_iovs_.reserve(packets.size()); + uint16_t total_write_len = 0; + + for (const auto &packet : packets) { + // Calculate varint sizes for header layout + uint8_t size_varint_len = 
api::ProtoSize::varint(static_cast(packet.payload_size)); + uint8_t type_varint_len = api::ProtoSize::varint(static_cast(packet.message_type)); + uint8_t total_header_len = 1 + size_varint_len + type_varint_len; + + // Calculate where to start writing the header + // The header starts at the latest possible position to minimize unused padding + // + // Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3 + // [0-2] - Unused padding + // [3] - 0x00 indicator byte + // [4] - Payload size varint (1 byte, for sizes 0-127) + // [5] - Message type varint (1 byte, for types 0-127) + // [6...] - Actual payload data + // + // Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2 + // [0-1] - Unused padding + // [2] - 0x00 indicator byte + // [3-4] - Payload size varint (2 bytes, for sizes 128-16383) + // [5] - Message type varint (1 byte, for types 0-127) + // [6...] - Actual payload data + // + // Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0 + // [0] - 0x00 indicator byte + // [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151) + // [4-5] - Message type varint (2 bytes, for types 128-16383) + // [6...] 
- Actual payload data + // + // The message starts at offset + frame_header_padding_ + // So we write the header starting at offset + frame_header_padding_ - total_header_len + uint8_t *buf_start = buffer_data + packet.offset; + uint32_t header_offset = frame_header_padding_ - total_header_len; + + // Write the plaintext header + buf_start[header_offset] = 0x00; // indicator + + // Encode varints directly into buffer + ProtoVarInt(packet.payload_size).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len); + ProtoVarInt(packet.message_type) + .encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len); + + // Add iovec for this packet (header + payload) + size_t packet_len = static_cast(total_header_len + packet.payload_size); + this->reusable_iovs_.push_back({buf_start + header_offset, packet_len}); + total_write_len += packet_len; + } + + // Send all packets in one writev call + return write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len); +} + +} // namespace api +} // namespace esphome +#endif // USE_API_PLAINTEXT +#endif // USE_API diff --git a/esphome/components/api/api_frame_helper_plaintext.h b/esphome/components/api/api_frame_helper_plaintext.h new file mode 100644 index 0000000000..465ceae827 --- /dev/null +++ b/esphome/components/api/api_frame_helper_plaintext.h @@ -0,0 +1,55 @@ +#pragma once +#include "api_frame_helper.h" +#ifdef USE_API +#ifdef USE_API_PLAINTEXT + +namespace esphome { +namespace api { + +class APIPlaintextFrameHelper : public APIFrameHelper { + public: + APIPlaintextFrameHelper(std::unique_ptr socket, const ClientInfo *client_info) + : APIFrameHelper(std::move(socket), client_info) { + // Plaintext header structure (worst case): + // Pos 0: indicator (0x00) + // Pos 1-3: payload size varint (up to 3 bytes) + // Pos 4-5: message type varint (up to 2 bytes) + // Pos 6+: actual payload data + frame_header_padding_ = 6; + } + ~APIPlaintextFrameHelper() 
override = default; + APIError init() override; + APIError loop() override; + APIError read_packet(ReadPacketBuffer *buffer) override; + APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override; + APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) override; + uint8_t frame_header_padding() override { return frame_header_padding_; } + // Get the frame footer size required by this protocol + uint8_t frame_footer_size() override { return frame_footer_size_; } + + protected: + APIError try_read_frame_(std::vector *frame); + + // Group 2-byte aligned types + uint16_t rx_header_parsed_type_ = 0; + uint16_t rx_header_parsed_len_ = 0; + + // Group 1-byte types together + // Fixed-size header buffer for plaintext protocol: + // We now store the indicator byte + the two varints. + // To match noise protocol's maximum message size (UINT16_MAX = 65535), we need: + // 1 byte for indicator + 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint + // + // While varints could theoretically be up to 10 bytes each for 64-bit values, + // attempting to process messages with headers that large would likely crash the + // ESP32 due to memory constraints. + uint8_t rx_header_buf_[6]; // 1 byte indicator + 5 bytes for varints (3 for size + 2 for type) + uint8_t rx_header_buf_pos_ = 0; + bool rx_header_parsed_ = false; + // 12 bytes total for the fields above (2 + 2 + 6 + 1 + 1, assuming a 1-byte bool), no padding needed +}; + +} // namespace api +} // namespace esphome +#endif // USE_API_PLAINTEXT +#endif // USE_API