mirror of
https://github.com/esphome/esphome.git
synced 2025-07-28 14:16:40 +00:00
[api] Split frame helper implementation into protocol-specific files (#9746)
This commit is contained in:
parent
46da075226
commit
a45a45c688
@ -323,9 +323,10 @@ async def api_connected_to_code(config, condition_id, template_arg, args):
|
||||
|
||||
|
||||
def FILTER_SOURCE_FILES() -> list[str]:
|
||||
"""Filter out api_pb2_dump.cpp when proto message dumping is not enabled
|
||||
and user_services.cpp when no services are defined."""
|
||||
files_to_filter = []
|
||||
"""Filter out api_pb2_dump.cpp when proto message dumping is not enabled,
|
||||
user_services.cpp when no services are defined, and protocol-specific
|
||||
implementations based on encryption configuration."""
|
||||
files_to_filter: list[str] = []
|
||||
|
||||
# api_pb2_dump.cpp is only needed when HAS_PROTO_MESSAGE_DUMP is defined
|
||||
# This is a particularly large file that still needs to be opened and read
|
||||
@ -341,4 +342,16 @@ def FILTER_SOURCE_FILES() -> list[str]:
|
||||
if config and not config.get(CONF_ACTIONS) and not config[CONF_CUSTOM_SERVICES]:
|
||||
files_to_filter.append("user_services.cpp")
|
||||
|
||||
# Filter protocol-specific implementations based on encryption configuration
|
||||
encryption_config = config.get(CONF_ENCRYPTION) if config else None
|
||||
|
||||
# If encryption is not configured at all, we only need plaintext
|
||||
if encryption_config is None:
|
||||
files_to_filter.append("api_frame_helper_noise.cpp")
|
||||
# If encryption is configured with a key, we only need noise
|
||||
elif encryption_config.get(CONF_KEY):
|
||||
files_to_filter.append("api_frame_helper_plaintext.cpp")
|
||||
# If encryption is configured but no key is provided, we need both
|
||||
# (this allows a plaintext client to provide a noise key)
|
||||
|
||||
return files_to_filter
|
||||
|
@ -1,5 +1,11 @@
|
||||
#include "api_connection.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_NOISE
|
||||
#include "api_frame_helper_noise.h"
|
||||
#endif
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
#include "api_frame_helper_plaintext.h"
|
||||
#endif
|
||||
#include <cerrno>
|
||||
#include <cinttypes>
|
||||
#include <utility>
|
||||
|
@ -12,18 +12,24 @@
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *const TAG = "api.socket";
|
||||
static const char *const TAG = "api.frame_helper";
|
||||
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__)
|
||||
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str())
|
||||
#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str())
|
||||
#else
|
||||
#define LOG_PACKET_RECEIVED(buffer) ((void) 0)
|
||||
#define LOG_PACKET_SENDING(data, len) ((void) 0)
|
||||
#endif
|
||||
|
||||
const char *api_error_to_str(APIError err) {
|
||||
// not using switch to ensure compiler doesn't try to build a big table out of it
|
||||
if (err == APIError::OK) {
|
||||
return "OK";
|
||||
} else if (err == APIError::WOULD_BLOCK) {
|
||||
return "WOULD_BLOCK";
|
||||
} else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
return "BAD_HANDSHAKE_PACKET_LEN";
|
||||
} else if (err == APIError::BAD_INDICATOR) {
|
||||
return "BAD_INDICATOR";
|
||||
} else if (err == APIError::BAD_DATA_PACKET) {
|
||||
@ -44,6 +50,14 @@ const char *api_error_to_str(APIError err) {
|
||||
return "SOCKET_READ_FAILED";
|
||||
} else if (err == APIError::SOCKET_WRITE_FAILED) {
|
||||
return "SOCKET_WRITE_FAILED";
|
||||
} else if (err == APIError::OUT_OF_MEMORY) {
|
||||
return "OUT_OF_MEMORY";
|
||||
} else if (err == APIError::CONNECTION_CLOSED) {
|
||||
return "CONNECTION_CLOSED";
|
||||
}
|
||||
#ifdef USE_API_NOISE
|
||||
else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
return "BAD_HANDSHAKE_PACKET_LEN";
|
||||
} else if (err == APIError::HANDSHAKESTATE_READ_FAILED) {
|
||||
return "HANDSHAKESTATE_READ_FAILED";
|
||||
} else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) {
|
||||
@ -54,17 +68,14 @@ const char *api_error_to_str(APIError err) {
|
||||
return "CIPHERSTATE_DECRYPT_FAILED";
|
||||
} else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) {
|
||||
return "CIPHERSTATE_ENCRYPT_FAILED";
|
||||
} else if (err == APIError::OUT_OF_MEMORY) {
|
||||
return "OUT_OF_MEMORY";
|
||||
} else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) {
|
||||
return "HANDSHAKESTATE_SETUP_FAILED";
|
||||
} else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) {
|
||||
return "HANDSHAKESTATE_SPLIT_FAILED";
|
||||
} else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) {
|
||||
return "BAD_HANDSHAKE_ERROR_BYTE";
|
||||
} else if (err == APIError::CONNECTION_CLOSED) {
|
||||
return "CONNECTION_CLOSED";
|
||||
}
|
||||
#endif
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
@ -125,8 +136,7 @@ APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_
|
||||
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
for (int i = 0; i < iovcnt; i++) {
|
||||
ESP_LOGVV(TAG, "Sending raw: %s",
|
||||
format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str());
|
||||
LOG_PACKET_SENDING(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -236,829 +246,6 @@ APIError APIFrameHelper::handle_socket_read_result_(ssize_t received) {
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
// uncomment to log raw packets
|
||||
//#define HELPER_LOG_PACKETS
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
|
||||
|
||||
/// Convert a noise error code to a readable error
|
||||
std::string noise_err_to_str(int err) {
|
||||
if (err == NOISE_ERROR_NO_MEMORY)
|
||||
return "NO_MEMORY";
|
||||
if (err == NOISE_ERROR_UNKNOWN_ID)
|
||||
return "UNKNOWN_ID";
|
||||
if (err == NOISE_ERROR_UNKNOWN_NAME)
|
||||
return "UNKNOWN_NAME";
|
||||
if (err == NOISE_ERROR_MAC_FAILURE)
|
||||
return "MAC_FAILURE";
|
||||
if (err == NOISE_ERROR_NOT_APPLICABLE)
|
||||
return "NOT_APPLICABLE";
|
||||
if (err == NOISE_ERROR_SYSTEM)
|
||||
return "SYSTEM";
|
||||
if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
|
||||
return "REMOTE_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
|
||||
return "LOCAL_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_PSK_REQUIRED)
|
||||
return "PSK_REQUIRED";
|
||||
if (err == NOISE_ERROR_INVALID_LENGTH)
|
||||
return "INVALID_LENGTH";
|
||||
if (err == NOISE_ERROR_INVALID_PARAM)
|
||||
return "INVALID_PARAM";
|
||||
if (err == NOISE_ERROR_INVALID_STATE)
|
||||
return "INVALID_STATE";
|
||||
if (err == NOISE_ERROR_INVALID_NONCE)
|
||||
return "INVALID_NONCE";
|
||||
if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
|
||||
return "INVALID_PRIVATE_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
|
||||
return "INVALID_PUBLIC_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_FORMAT)
|
||||
return "INVALID_FORMAT";
|
||||
if (err == NOISE_ERROR_INVALID_SIGNATURE)
|
||||
return "INVALID_SIGNATURE";
|
||||
return to_string(err);
|
||||
}
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APINoiseFrameHelper::init() {
|
||||
APIError err = init_common_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// init prologue
|
||||
prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT));
|
||||
|
||||
state_ = State::CLIENT_HELLO;
|
||||
return APIError::OK;
|
||||
}
|
||||
// Helper for handling handshake frame errors
|
||||
APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) {
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
send_explicit_handshake_reject_("Bad indicator byte");
|
||||
} else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
send_explicit_handshake_reject_("Bad handshake packet len");
|
||||
}
|
||||
return aerr;
|
||||
}
|
||||
|
||||
// Helper for handling noise library errors
|
||||
APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) {
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str());
|
||||
return api_err;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/// Run through handshake messages (if in that phase)
|
||||
APIError APINoiseFrameHelper::loop() {
|
||||
// During handshake phase, process as many actions as possible until we can't progress
|
||||
// socket_->ready() stays true until next main loop, but state_action() will return
|
||||
// WOULD_BLOCK when no more data is available to read
|
||||
while (state_ != State::DATA && this->socket_->ready()) {
|
||||
APIError err = state_action_();
|
||||
if (err == APIError::WOULD_BLOCK) {
|
||||
break;
|
||||
}
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
// Use base class implementation for buffer sending
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg_start: points to the start of the payload - this pointer is only valid until the next
|
||||
* try_receive_raw_ call
|
||||
*
|
||||
* @return 0 if a full packet is in rx_buf_
|
||||
* @return -1 if error, check errno.
|
||||
*
|
||||
* errno EWOULDBLOCK: Packet could not be read without blocking. Try again later.
|
||||
* errno ENOMEM: Not enough memory for reading packet.
|
||||
* errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
* errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
if (rx_header_buf_len_ < 3) {
|
||||
// no header information yet
|
||||
uint8_t to_read = 3 - rx_header_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_header_buf_len_ += static_cast<uint8_t>(received);
|
||||
if (static_cast<uint8_t>(received) != to_read) {
|
||||
// not a full read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
if (rx_header_buf_[0] != 0x01) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
// header reading done
|
||||
}
|
||||
|
||||
// read body
|
||||
uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2];
|
||||
|
||||
if (state_ != State::DATA && msg_size > 128) {
|
||||
// for handshake message only permit up to 128 bytes
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet len for handshake: %d", msg_size);
|
||||
return APIError::BAD_HANDSHAKE_PACKET_LEN;
|
||||
}
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != msg_size) {
|
||||
rx_buf_.resize(msg_size);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < msg_size) {
|
||||
// more data to read
|
||||
uint16_t to_read = msg_size - rx_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_buf_len_ += static_cast<uint16_t>(received);
|
||||
if (static_cast<uint16_t>(received) != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
// uncomment for even more debugging
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str());
|
||||
#endif
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_len_ = 0;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** To be called from read/write methods.
|
||||
*
|
||||
* This method runs through the internal handshake methods, if in that state.
|
||||
*
|
||||
* If the handshake is still active when this method returns and a read/write can't take place at
|
||||
* the moment, returns WOULD_BLOCK.
|
||||
* If an error occurred, returns that error. Only returns OK if the transport is ready for data
|
||||
* traffic.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::state_action_() {
|
||||
int err;
|
||||
APIError aerr;
|
||||
if (state_ == State::INITIALIZE) {
|
||||
HELPER_LOG("Bad state for method: %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
if (state_ == State::CLIENT_HELLO) {
|
||||
// waiting for client hello
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
// ignore contents, may be used in future for flags
|
||||
// Reserve space for: existing prologue + 2 size bytes + frame data
|
||||
prologue_.reserve(prologue_.size() + 2 + frame.size());
|
||||
prologue_.push_back((uint8_t) (frame.size() >> 8));
|
||||
prologue_.push_back((uint8_t) frame.size());
|
||||
prologue_.insert(prologue_.end(), frame.begin(), frame.end());
|
||||
|
||||
state_ = State::SERVER_HELLO;
|
||||
}
|
||||
if (state_ == State::SERVER_HELLO) {
|
||||
// send server hello
|
||||
const std::string &name = App.get_name();
|
||||
const std::string &mac = get_mac_address();
|
||||
|
||||
std::vector<uint8_t> msg;
|
||||
// Reserve space for: 1 byte proto + name + null + mac + null
|
||||
msg.reserve(1 + name.size() + 1 + mac.size() + 1);
|
||||
|
||||
// chosen proto
|
||||
msg.push_back(0x01);
|
||||
|
||||
// node name, terminated by null byte
|
||||
const uint8_t *name_ptr = reinterpret_cast<const uint8_t *>(name.c_str());
|
||||
msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1);
|
||||
// node mac, terminated by null byte
|
||||
const uint8_t *mac_ptr = reinterpret_cast<const uint8_t *>(mac.c_str());
|
||||
msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1);
|
||||
|
||||
aerr = write_frame_(msg.data(), msg.size());
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
// start handshake
|
||||
aerr = init_handshake_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
state_ = State::HANDSHAKE;
|
||||
}
|
||||
if (state_ == State::HANDSHAKE) {
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE) {
|
||||
// waiting for handshake msg
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
|
||||
if (frame.empty()) {
|
||||
send_explicit_handshake_reject_("Empty handshake message");
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
} else if (frame[0] != 0x00) {
|
||||
HELPER_LOG("Bad handshake error byte: %u", frame[0]);
|
||||
send_explicit_handshake_reject_("Bad handshake error byte");
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
}
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
// Special handling for MAC failure
|
||||
send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error");
|
||||
return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED);
|
||||
}
|
||||
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
|
||||
uint8_t buffer[65];
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
|
||||
|
||||
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
|
||||
APIError aerr_write =
|
||||
handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED);
|
||||
if (aerr_write != APIError::OK)
|
||||
return aerr_write;
|
||||
buffer[0] = 0x00; // success
|
||||
|
||||
aerr = write_frame_(buffer, mbuf.size + 1);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else {
|
||||
// bad state for action
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
}
|
||||
if (state_ == State::CLOSED || state_ == State::FAILED) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
|
||||
std::vector<uint8_t> data;
|
||||
data.resize(reason.length() + 1);
|
||||
data[0] = 0x01; // failure
|
||||
|
||||
// Copy error message in bulk
|
||||
if (!reason.empty()) {
|
||||
std::memcpy(data.data() + 1, reason.c_str(), reason.length());
|
||||
}
|
||||
|
||||
// temporarily remove failed state
|
||||
auto orig_state = state_;
|
||||
state_ = State::EXPLICIT_REJECT;
|
||||
write_frame_(data.data(), data.size());
|
||||
state_ = orig_state;
|
||||
}
|
||||
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size());
|
||||
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
|
||||
APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED);
|
||||
if (decrypt_err != APIError::OK)
|
||||
return decrypt_err;
|
||||
|
||||
uint16_t msg_size = mbuf.size;
|
||||
uint8_t *msg_data = frame.data();
|
||||
if (msg_size < 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: size %d too short", msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
|
||||
uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
|
||||
if (data_len > msg_size - 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->data_offset = 4;
|
||||
buffer->data_len = data_len;
|
||||
buffer->type = type;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
// Resize to include MAC space (required for Noise encryption)
|
||||
buffer.get_buffer()->resize(buffer.get_buffer()->size() + frame_footer_size_);
|
||||
PacketInfo packet{type, 0,
|
||||
static_cast<uint16_t>(buffer.get_buffer()->size() - frame_header_padding_ - frame_footer_size_)};
|
||||
return write_protobuf_packets(buffer, std::span<const PacketInfo>(&packet, 1));
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) {
|
||||
APIError aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer
|
||||
|
||||
this->reusable_iovs_.clear();
|
||||
this->reusable_iovs_.reserve(packets.size());
|
||||
uint16_t total_write_len = 0;
|
||||
|
||||
// We need to encrypt each packet in place
|
||||
for (const auto &packet : packets) {
|
||||
// The buffer already has padding at offset
|
||||
uint8_t *buf_start = buffer_data + packet.offset;
|
||||
|
||||
// Write noise header
|
||||
buf_start[0] = 0x01; // indicator
|
||||
// buf_start[1], buf_start[2] to be set after encryption
|
||||
|
||||
// Write message header (to be encrypted)
|
||||
const uint8_t msg_offset = 3;
|
||||
buf_start[msg_offset] = static_cast<uint8_t>(packet.message_type >> 8); // type high byte
|
||||
buf_start[msg_offset + 1] = static_cast<uint8_t>(packet.message_type); // type low byte
|
||||
buf_start[msg_offset + 2] = static_cast<uint8_t>(packet.payload_size >> 8); // data_len high byte
|
||||
buf_start[msg_offset + 3] = static_cast<uint8_t>(packet.payload_size); // data_len low byte
|
||||
// payload data is already in the buffer starting at offset + 7
|
||||
|
||||
// Make sure we have space for MAC
|
||||
// The buffer should already have been sized appropriately
|
||||
|
||||
// Encrypt the message in place
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + packet.payload_size,
|
||||
4 + packet.payload_size + frame_footer_size_);
|
||||
|
||||
int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
// Fill in the encrypted size
|
||||
buf_start[1] = static_cast<uint8_t>(mbuf.size >> 8);
|
||||
buf_start[2] = static_cast<uint8_t>(mbuf.size);
|
||||
|
||||
// Add iovec for this encrypted packet
|
||||
size_t packet_len = static_cast<size_t>(3 + mbuf.size); // indicator + size + encrypted data
|
||||
this->reusable_iovs_.push_back({buf_start, packet_len});
|
||||
total_write_len += packet_len;
|
||||
}
|
||||
|
||||
// Send all encrypted packets in one writev call
|
||||
return this->write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len);
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
header[1] = (uint8_t) (len >> 8);
|
||||
header[2] = (uint8_t) len;
|
||||
|
||||
struct iovec iov[2];
|
||||
iov[0].iov_base = header;
|
||||
iov[0].iov_len = 3;
|
||||
if (len == 0) {
|
||||
return this->write_raw_(iov, 1, 3); // Just header
|
||||
}
|
||||
iov[1].iov_base = const_cast<uint8_t *>(data);
|
||||
iov[1].iov_len = len;
|
||||
|
||||
return this->write_raw_(iov, 2, 3 + len); // Header + data
|
||||
}
|
||||
|
||||
/** Initiate the data structures for the handshake.
|
||||
*
|
||||
* @return 0 on success, -1 on error (check errno)
|
||||
*/
|
||||
APIError APINoiseFrameHelper::init_handshake_() {
|
||||
int err;
|
||||
memset(&nid_, 0, sizeof(nid_));
|
||||
// const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
|
||||
// err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
|
||||
nid_.pattern_id = NOISE_PATTERN_NN;
|
||||
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
|
||||
nid_.dh_id = NOISE_DH_CURVE25519;
|
||||
nid_.prefix_id = NOISE_PREFIX_STANDARD;
|
||||
nid_.hybrid_id = NOISE_DH_NONE;
|
||||
nid_.hash_id = NOISE_HASH_SHA256;
|
||||
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
|
||||
|
||||
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
|
||||
APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
const auto &psk = ctx_->get_psk();
|
||||
err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
// set_prologue copies it into handshakestate, so we can get rid of it now
|
||||
prologue_ = {};
|
||||
|
||||
err = noise_handshakestate_start(handshake_);
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::check_handshake_finished_() {
|
||||
assert(state_ == State::HANDSHAKE);
|
||||
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE)
|
||||
return APIError::OK;
|
||||
if (action != NOISE_ACTION_SPLIT) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
|
||||
APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_);
|
||||
|
||||
HELPER_LOG("Handshake complete!");
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APINoiseFrameHelper::~APINoiseFrameHelper() {
|
||||
if (handshake_ != nullptr) {
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
}
|
||||
if (send_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(send_cipher_);
|
||||
send_cipher_ = nullptr;
|
||||
}
|
||||
if (recv_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(recv_cipher_);
|
||||
recv_cipher_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
|
||||
void noise_rand_bytes(void *output, size_t len) {
|
||||
if (!esphome::random_bytes(reinterpret_cast<uint8_t *>(output), len)) {
|
||||
ESP_LOGE(TAG, "Acquiring random bytes failed; rebooting");
|
||||
arch_restart();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APIPlaintextFrameHelper::init() {
|
||||
APIError err = init_common_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::loop() {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
// Use base class implementation for buffer sending
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg: store the parsed frame in that struct
|
||||
*
|
||||
* @return See APIError
|
||||
*
|
||||
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
while (!rx_header_parsed_) {
|
||||
// Now that we know when the socket is ready, we can read up to 3 bytes
|
||||
// into the rx_header_buf_ before we have to switch back to reading
|
||||
// one byte at a time to ensure we don't read past the message and
|
||||
// into the next one.
|
||||
|
||||
// Read directly into rx_header_buf_ at the current position
|
||||
// Try to get to at least 3 bytes total (indicator + 2 varint bytes), then read one byte at a time
|
||||
ssize_t received =
|
||||
this->socket_->read(&rx_header_buf_[rx_header_buf_pos_], rx_header_buf_pos_ < 3 ? 3 - rx_header_buf_pos_ : 1);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// If this was the first read, validate the indicator byte
|
||||
if (rx_header_buf_pos_ == 0 && received > 0) {
|
||||
if (rx_header_buf_[0] != 0x00) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
}
|
||||
|
||||
rx_header_buf_pos_ += received;
|
||||
|
||||
// Check for buffer overflow
|
||||
if (rx_header_buf_pos_ >= sizeof(rx_header_buf_)) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Header buffer overflow");
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
// Need at least 3 bytes total (indicator + 2 varint bytes) before trying to parse
|
||||
if (rx_header_buf_pos_ < 3) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// At this point, we have at least 3 bytes total:
|
||||
// - Validated indicator byte (0x00) stored at position 0
|
||||
// - At least 2 bytes in the buffer for the varints
|
||||
// Buffer layout:
|
||||
// [0]: indicator byte (0x00)
|
||||
// [1-3]: Message size varint (variable length)
|
||||
// - 2 bytes would only allow up to 16383, which is less than noise's UINT16_MAX (65535)
|
||||
// - 3 bytes allows up to 2097151, ensuring we support at least as much as noise
|
||||
// [2-5]: Message type varint (variable length)
|
||||
// We now attempt to parse both varints. If either is incomplete,
|
||||
// we'll continue reading more bytes.
|
||||
|
||||
// Skip indicator byte at position 0
|
||||
uint8_t varint_pos = 1;
|
||||
uint32_t consumed = 0;
|
||||
|
||||
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
|
||||
if (!msg_size_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg_size_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(),
|
||||
std::numeric_limits<uint16_t>::max());
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
rx_header_parsed_len_ = msg_size_varint->as_uint16();
|
||||
|
||||
// Move to next varint position
|
||||
varint_pos += consumed;
|
||||
|
||||
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
|
||||
if (!msg_type_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
if (msg_type_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum %u", msg_type_varint->as_uint32(),
|
||||
std::numeric_limits<uint16_t>::max());
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
rx_header_parsed_type_ = msg_type_varint->as_uint16();
|
||||
rx_header_parsed_ = true;
|
||||
}
|
||||
// header reading done
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != rx_header_parsed_len_) {
|
||||
rx_buf_.resize(rx_header_parsed_len_);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < rx_header_parsed_len_) {
|
||||
// more data to read
|
||||
uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_buf_len_ += static_cast<uint16_t>(received);
|
||||
if (static_cast<uint16_t>(received) != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
// uncomment for even more debugging
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str());
|
||||
#endif
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_pos_ = 0;
|
||||
rx_header_parsed_ = false;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
// Make sure to tell the remote that we don't
|
||||
// understand the indicator byte so it knows
|
||||
// we do not support it.
|
||||
struct iovec iov[1];
|
||||
// The \x00 first byte is the marker for plaintext.
|
||||
//
|
||||
// The remote will know how to handle the indicator byte,
|
||||
// but it likely won't understand the rest of the message.
|
||||
//
|
||||
// We must send at least 3 bytes to be read, so we add
|
||||
// a message after the indicator byte to ensures its long
|
||||
// enough and can aid in debugging.
|
||||
const char msg[] = "\x00"
|
||||
"Bad indicator byte";
|
||||
iov[0].iov_base = (void *) msg;
|
||||
iov[0].iov_len = 19;
|
||||
this->write_raw_(iov, 1, 19);
|
||||
}
|
||||
return aerr;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->data_offset = 0;
|
||||
buffer->data_len = rx_header_parsed_len_;
|
||||
buffer->type = rx_header_parsed_type_;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
PacketInfo packet{type, 0, static_cast<uint16_t>(buffer.get_buffer()->size() - frame_header_padding_)};
|
||||
return write_protobuf_packets(buffer, std::span<const PacketInfo>(&packet, 1));
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer
|
||||
|
||||
this->reusable_iovs_.clear();
|
||||
this->reusable_iovs_.reserve(packets.size());
|
||||
uint16_t total_write_len = 0;
|
||||
|
||||
for (const auto &packet : packets) {
|
||||
// Calculate varint sizes for header layout
|
||||
uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(packet.payload_size));
|
||||
uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(packet.message_type));
|
||||
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
|
||||
|
||||
// Calculate where to start writing the header
|
||||
// The header starts at the latest possible position to minimize unused padding
|
||||
//
|
||||
// Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
|
||||
// [0-2] - Unused padding
|
||||
// [3] - 0x00 indicator byte
|
||||
// [4] - Payload size varint (1 byte, for sizes 0-127)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
|
||||
// [0-1] - Unused padding
|
||||
// [2] - 0x00 indicator byte
|
||||
// [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
|
||||
// [0] - 0x00 indicator byte
|
||||
// [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151)
|
||||
// [4-5] - Message type varint (2 bytes, for types 128-32767)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// The message starts at offset + frame_header_padding_
|
||||
// So we write the header starting at offset + frame_header_padding_ - total_header_len
|
||||
uint8_t *buf_start = buffer_data + packet.offset;
|
||||
uint32_t header_offset = frame_header_padding_ - total_header_len;
|
||||
|
||||
// Write the plaintext header
|
||||
buf_start[header_offset] = 0x00; // indicator
|
||||
|
||||
// Encode varints directly into buffer
|
||||
ProtoVarInt(packet.payload_size).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len);
|
||||
ProtoVarInt(packet.message_type)
|
||||
.encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
|
||||
|
||||
// Add iovec for this packet (header + payload)
|
||||
size_t packet_len = static_cast<size_t>(total_header_len + packet.payload_size);
|
||||
this->reusable_iovs_.push_back({buf_start + header_offset, packet_len});
|
||||
total_write_len += packet_len;
|
||||
}
|
||||
|
||||
// Send all packets in one writev call
|
||||
return write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len);
|
||||
}
|
||||
|
||||
#endif // USE_API_PLAINTEXT
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
|
@ -8,17 +8,16 @@
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_NOISE
|
||||
#include "noise/protocol.h"
|
||||
#endif
|
||||
|
||||
#include "api_noise_context.h"
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
// uncomment to log raw packets
|
||||
//#define HELPER_LOG_PACKETS
|
||||
|
||||
// Forward declaration
|
||||
struct ClientInfo;
|
||||
|
||||
@ -43,7 +42,6 @@ struct PacketInfo {
|
||||
enum class APIError : uint16_t {
|
||||
OK = 0,
|
||||
WOULD_BLOCK = 1001,
|
||||
BAD_HANDSHAKE_PACKET_LEN = 1002,
|
||||
BAD_INDICATOR = 1003,
|
||||
BAD_DATA_PACKET = 1004,
|
||||
TCP_NODELAY_FAILED = 1005,
|
||||
@ -54,16 +52,19 @@ enum class APIError : uint16_t {
|
||||
BAD_ARG = 1010,
|
||||
SOCKET_READ_FAILED = 1011,
|
||||
SOCKET_WRITE_FAILED = 1012,
|
||||
OUT_OF_MEMORY = 1018,
|
||||
CONNECTION_CLOSED = 1022,
|
||||
#ifdef USE_API_NOISE
|
||||
BAD_HANDSHAKE_PACKET_LEN = 1002,
|
||||
HANDSHAKESTATE_READ_FAILED = 1013,
|
||||
HANDSHAKESTATE_WRITE_FAILED = 1014,
|
||||
HANDSHAKESTATE_BAD_STATE = 1015,
|
||||
CIPHERSTATE_DECRYPT_FAILED = 1016,
|
||||
CIPHERSTATE_ENCRYPT_FAILED = 1017,
|
||||
OUT_OF_MEMORY = 1018,
|
||||
HANDSHAKESTATE_SETUP_FAILED = 1019,
|
||||
HANDSHAKESTATE_SPLIT_FAILED = 1020,
|
||||
BAD_HANDSHAKE_ERROR_BYTE = 1021,
|
||||
CONNECTION_CLOSED = 1022,
|
||||
#endif
|
||||
};
|
||||
|
||||
const char *api_error_to_str(APIError err);
|
||||
@ -183,109 +184,7 @@ class APIFrameHelper {
|
||||
APIError handle_socket_read_result_(ssize_t received);
|
||||
};
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
class APINoiseFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx,
|
||||
const ClientInfo *client_info)
|
||||
: APIFrameHelper(std::move(socket), client_info), ctx_(std::move(ctx)) {
|
||||
// Noise header structure:
|
||||
// Pos 0: indicator (0x01)
|
||||
// Pos 1-2: encrypted payload size (16-bit big-endian)
|
||||
// Pos 3-6: encrypted type (16-bit) + data_len (16-bit)
|
||||
// Pos 7+: actual payload data
|
||||
frame_header_padding_ = 7;
|
||||
}
|
||||
~APINoiseFrameHelper() override;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) override;
|
||||
// Get the frame header padding required by this protocol
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
|
||||
protected:
|
||||
APIError state_action_();
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
APIError write_frame_(const uint8_t *data, uint16_t len);
|
||||
APIError init_handshake_();
|
||||
APIError check_handshake_finished_();
|
||||
void send_explicit_handshake_reject_(const std::string &reason);
|
||||
APIError handle_handshake_frame_error_(APIError aerr);
|
||||
APIError handle_noise_error_(int err, const char *func_name, APIError api_err);
|
||||
|
||||
// Pointers first (4 bytes each)
|
||||
NoiseHandshakeState *handshake_{nullptr};
|
||||
NoiseCipherState *send_cipher_{nullptr};
|
||||
NoiseCipherState *recv_cipher_{nullptr};
|
||||
|
||||
// Shared pointer (8 bytes on 32-bit = 4 bytes control block pointer + 4 bytes object pointer)
|
||||
std::shared_ptr<APINoiseContext> ctx_;
|
||||
|
||||
// Vector (12 bytes on 32-bit)
|
||||
std::vector<uint8_t> prologue_;
|
||||
|
||||
// NoiseProtocolId (size depends on implementation)
|
||||
NoiseProtocolId nid_;
|
||||
|
||||
// Group small types together
|
||||
// Fixed-size header buffer for noise protocol:
|
||||
// 1 byte for indicator + 2 bytes for message size (16-bit value, not varint)
|
||||
// Note: Maximum message size is UINT16_MAX (65535), with a limit of 128 bytes during handshake phase
|
||||
uint8_t rx_header_buf_[3];
|
||||
uint8_t rx_header_buf_len_ = 0;
|
||||
// 4 bytes total, no padding
|
||||
};
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
class APIPlaintextFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket, const ClientInfo *client_info)
|
||||
: APIFrameHelper(std::move(socket), client_info) {
|
||||
// Plaintext header structure (worst case):
|
||||
// Pos 0: indicator (0x00)
|
||||
// Pos 1-3: payload size varint (up to 3 bytes)
|
||||
// Pos 4-5: message type varint (up to 2 bytes)
|
||||
// Pos 6+: actual payload data
|
||||
frame_header_padding_ = 6;
|
||||
}
|
||||
~APIPlaintextFrameHelper() override = default;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) override;
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
|
||||
protected:
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
|
||||
// Group 2-byte aligned types
|
||||
uint16_t rx_header_parsed_type_ = 0;
|
||||
uint16_t rx_header_parsed_len_ = 0;
|
||||
|
||||
// Group 1-byte types together
|
||||
// Fixed-size header buffer for plaintext protocol:
|
||||
// We now store the indicator byte + the two varints.
|
||||
// To match noise protocol's maximum message size (UINT16_MAX = 65535), we need:
|
||||
// 1 byte for indicator + 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint
|
||||
//
|
||||
// While varints could theoretically be up to 10 bytes each for 64-bit values,
|
||||
// attempting to process messages with headers that large would likely crash the
|
||||
// ESP32 due to memory constraints.
|
||||
uint8_t rx_header_buf_[6]; // 1 byte indicator + 5 bytes for varints (3 for size + 2 for type)
|
||||
uint8_t rx_header_buf_pos_ = 0;
|
||||
bool rx_header_parsed_ = false;
|
||||
// 8 bytes total, no padding needed
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif
|
||||
|
||||
#endif // USE_API
|
||||
|
577
esphome/components/api/api_frame_helper_noise.cpp
Normal file
577
esphome/components/api/api_frame_helper_noise.cpp
Normal file
@ -0,0 +1,577 @@
|
||||
#include "api_frame_helper_noise.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_NOISE
|
||||
#include "api_connection.h" // For ClientInfo struct
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/hal.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "proto.h"
|
||||
#include <cstring>
|
||||
#include <cinttypes>
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *const TAG = "api.noise";
|
||||
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
|
||||
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__)
|
||||
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str())
|
||||
#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str())
|
||||
#else
|
||||
#define LOG_PACKET_RECEIVED(buffer) ((void) 0)
|
||||
#define LOG_PACKET_SENDING(data, len) ((void) 0)
|
||||
#endif
|
||||
|
||||
/// Convert a noise error code to a readable error
|
||||
std::string noise_err_to_str(int err) {
|
||||
if (err == NOISE_ERROR_NO_MEMORY)
|
||||
return "NO_MEMORY";
|
||||
if (err == NOISE_ERROR_UNKNOWN_ID)
|
||||
return "UNKNOWN_ID";
|
||||
if (err == NOISE_ERROR_UNKNOWN_NAME)
|
||||
return "UNKNOWN_NAME";
|
||||
if (err == NOISE_ERROR_MAC_FAILURE)
|
||||
return "MAC_FAILURE";
|
||||
if (err == NOISE_ERROR_NOT_APPLICABLE)
|
||||
return "NOT_APPLICABLE";
|
||||
if (err == NOISE_ERROR_SYSTEM)
|
||||
return "SYSTEM";
|
||||
if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
|
||||
return "REMOTE_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
|
||||
return "LOCAL_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_PSK_REQUIRED)
|
||||
return "PSK_REQUIRED";
|
||||
if (err == NOISE_ERROR_INVALID_LENGTH)
|
||||
return "INVALID_LENGTH";
|
||||
if (err == NOISE_ERROR_INVALID_PARAM)
|
||||
return "INVALID_PARAM";
|
||||
if (err == NOISE_ERROR_INVALID_STATE)
|
||||
return "INVALID_STATE";
|
||||
if (err == NOISE_ERROR_INVALID_NONCE)
|
||||
return "INVALID_NONCE";
|
||||
if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
|
||||
return "INVALID_PRIVATE_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
|
||||
return "INVALID_PUBLIC_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_FORMAT)
|
||||
return "INVALID_FORMAT";
|
||||
if (err == NOISE_ERROR_INVALID_SIGNATURE)
|
||||
return "INVALID_SIGNATURE";
|
||||
return to_string(err);
|
||||
}
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APINoiseFrameHelper::init() {
|
||||
APIError err = init_common_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// init prologue
|
||||
prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT));
|
||||
|
||||
state_ = State::CLIENT_HELLO;
|
||||
return APIError::OK;
|
||||
}
|
||||
// Helper for handling handshake frame errors
|
||||
APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) {
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
send_explicit_handshake_reject_("Bad indicator byte");
|
||||
} else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
send_explicit_handshake_reject_("Bad handshake packet len");
|
||||
}
|
||||
return aerr;
|
||||
}
|
||||
|
||||
// Helper for handling noise library errors
|
||||
APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) {
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str());
|
||||
return api_err;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/// Run through handshake messages (if in that phase)
|
||||
APIError APINoiseFrameHelper::loop() {
|
||||
// During handshake phase, process as many actions as possible until we can't progress
|
||||
// socket_->ready() stays true until next main loop, but state_action() will return
|
||||
// WOULD_BLOCK when no more data is available to read
|
||||
while (state_ != State::DATA && this->socket_->ready()) {
|
||||
APIError err = state_action_();
|
||||
if (err == APIError::WOULD_BLOCK) {
|
||||
break;
|
||||
}
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
// Use base class implementation for buffer sending
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg_start: points to the start of the payload - this pointer is only valid until the next
|
||||
* try_receive_raw_ call
|
||||
*
|
||||
* @return 0 if a full packet is in rx_buf_
|
||||
* @return -1 if error, check errno.
|
||||
*
|
||||
* errno EWOULDBLOCK: Packet could not be read without blocking. Try again later.
|
||||
* errno ENOMEM: Not enough memory for reading packet.
|
||||
* errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
* errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
if (rx_header_buf_len_ < 3) {
|
||||
// no header information yet
|
||||
uint8_t to_read = 3 - rx_header_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_header_buf_len_ += static_cast<uint8_t>(received);
|
||||
if (static_cast<uint8_t>(received) != to_read) {
|
||||
// not a full read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
if (rx_header_buf_[0] != 0x01) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
// header reading done
|
||||
}
|
||||
|
||||
// read body
|
||||
uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2];
|
||||
|
||||
if (state_ != State::DATA && msg_size > 128) {
|
||||
// for handshake message only permit up to 128 bytes
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet len for handshake: %d", msg_size);
|
||||
return APIError::BAD_HANDSHAKE_PACKET_LEN;
|
||||
}
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != msg_size) {
|
||||
rx_buf_.resize(msg_size);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < msg_size) {
|
||||
// more data to read
|
||||
uint16_t to_read = msg_size - rx_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_buf_len_ += static_cast<uint16_t>(received);
|
||||
if (static_cast<uint16_t>(received) != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_PACKET_RECEIVED(rx_buf_);
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_len_ = 0;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** To be called from read/write methods.
|
||||
*
|
||||
* This method runs through the internal handshake methods, if in that state.
|
||||
*
|
||||
* If the handshake is still active when this method returns and a read/write can't take place at
|
||||
* the moment, returns WOULD_BLOCK.
|
||||
* If an error occurred, returns that error. Only returns OK if the transport is ready for data
|
||||
* traffic.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::state_action_() {
|
||||
int err;
|
||||
APIError aerr;
|
||||
if (state_ == State::INITIALIZE) {
|
||||
HELPER_LOG("Bad state for method: %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
if (state_ == State::CLIENT_HELLO) {
|
||||
// waiting for client hello
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
// ignore contents, may be used in future for flags
|
||||
// Reserve space for: existing prologue + 2 size bytes + frame data
|
||||
prologue_.reserve(prologue_.size() + 2 + frame.size());
|
||||
prologue_.push_back((uint8_t) (frame.size() >> 8));
|
||||
prologue_.push_back((uint8_t) frame.size());
|
||||
prologue_.insert(prologue_.end(), frame.begin(), frame.end());
|
||||
|
||||
state_ = State::SERVER_HELLO;
|
||||
}
|
||||
if (state_ == State::SERVER_HELLO) {
|
||||
// send server hello
|
||||
const std::string &name = App.get_name();
|
||||
const std::string &mac = get_mac_address();
|
||||
|
||||
std::vector<uint8_t> msg;
|
||||
// Reserve space for: 1 byte proto + name + null + mac + null
|
||||
msg.reserve(1 + name.size() + 1 + mac.size() + 1);
|
||||
|
||||
// chosen proto
|
||||
msg.push_back(0x01);
|
||||
|
||||
// node name, terminated by null byte
|
||||
const uint8_t *name_ptr = reinterpret_cast<const uint8_t *>(name.c_str());
|
||||
msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1);
|
||||
// node mac, terminated by null byte
|
||||
const uint8_t *mac_ptr = reinterpret_cast<const uint8_t *>(mac.c_str());
|
||||
msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1);
|
||||
|
||||
aerr = write_frame_(msg.data(), msg.size());
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
// start handshake
|
||||
aerr = init_handshake_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
state_ = State::HANDSHAKE;
|
||||
}
|
||||
if (state_ == State::HANDSHAKE) {
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE) {
|
||||
// waiting for handshake msg
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
|
||||
if (frame.empty()) {
|
||||
send_explicit_handshake_reject_("Empty handshake message");
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
} else if (frame[0] != 0x00) {
|
||||
HELPER_LOG("Bad handshake error byte: %u", frame[0]);
|
||||
send_explicit_handshake_reject_("Bad handshake error byte");
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
}
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
// Special handling for MAC failure
|
||||
send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error");
|
||||
return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED);
|
||||
}
|
||||
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
|
||||
uint8_t buffer[65];
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
|
||||
|
||||
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
|
||||
APIError aerr_write =
|
||||
handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED);
|
||||
if (aerr_write != APIError::OK)
|
||||
return aerr_write;
|
||||
buffer[0] = 0x00; // success
|
||||
|
||||
aerr = write_frame_(buffer, mbuf.size + 1);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else {
|
||||
// bad state for action
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
}
|
||||
if (state_ == State::CLOSED || state_ == State::FAILED) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
|
||||
std::vector<uint8_t> data;
|
||||
data.resize(reason.length() + 1);
|
||||
data[0] = 0x01; // failure
|
||||
|
||||
// Copy error message in bulk
|
||||
if (!reason.empty()) {
|
||||
std::memcpy(data.data() + 1, reason.c_str(), reason.length());
|
||||
}
|
||||
|
||||
// temporarily remove failed state
|
||||
auto orig_state = state_;
|
||||
state_ = State::EXPLICIT_REJECT;
|
||||
write_frame_(data.data(), data.size());
|
||||
state_ = orig_state;
|
||||
}
|
||||
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size());
|
||||
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
|
||||
APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED);
|
||||
if (decrypt_err != APIError::OK)
|
||||
return decrypt_err;
|
||||
|
||||
uint16_t msg_size = mbuf.size;
|
||||
uint8_t *msg_data = frame.data();
|
||||
if (msg_size < 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: size %d too short", msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
|
||||
uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
|
||||
if (data_len > msg_size - 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->data_offset = 4;
|
||||
buffer->data_len = data_len;
|
||||
buffer->type = type;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
// Resize to include MAC space (required for Noise encryption)
|
||||
buffer.get_buffer()->resize(buffer.get_buffer()->size() + frame_footer_size_);
|
||||
PacketInfo packet{type, 0,
|
||||
static_cast<uint16_t>(buffer.get_buffer()->size() - frame_header_padding_ - frame_footer_size_)};
|
||||
return write_protobuf_packets(buffer, std::span<const PacketInfo>(&packet, 1));
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) {
|
||||
APIError aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer
|
||||
|
||||
this->reusable_iovs_.clear();
|
||||
this->reusable_iovs_.reserve(packets.size());
|
||||
uint16_t total_write_len = 0;
|
||||
|
||||
// We need to encrypt each packet in place
|
||||
for (const auto &packet : packets) {
|
||||
// The buffer already has padding at offset
|
||||
uint8_t *buf_start = buffer_data + packet.offset;
|
||||
|
||||
// Write noise header
|
||||
buf_start[0] = 0x01; // indicator
|
||||
// buf_start[1], buf_start[2] to be set after encryption
|
||||
|
||||
// Write message header (to be encrypted)
|
||||
const uint8_t msg_offset = 3;
|
||||
buf_start[msg_offset] = static_cast<uint8_t>(packet.message_type >> 8); // type high byte
|
||||
buf_start[msg_offset + 1] = static_cast<uint8_t>(packet.message_type); // type low byte
|
||||
buf_start[msg_offset + 2] = static_cast<uint8_t>(packet.payload_size >> 8); // data_len high byte
|
||||
buf_start[msg_offset + 3] = static_cast<uint8_t>(packet.payload_size); // data_len low byte
|
||||
// payload data is already in the buffer starting at offset + 7
|
||||
|
||||
// Make sure we have space for MAC
|
||||
// The buffer should already have been sized appropriately
|
||||
|
||||
// Encrypt the message in place
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + packet.payload_size,
|
||||
4 + packet.payload_size + frame_footer_size_);
|
||||
|
||||
int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
// Fill in the encrypted size
|
||||
buf_start[1] = static_cast<uint8_t>(mbuf.size >> 8);
|
||||
buf_start[2] = static_cast<uint8_t>(mbuf.size);
|
||||
|
||||
// Add iovec for this encrypted packet
|
||||
size_t packet_len = static_cast<size_t>(3 + mbuf.size); // indicator + size + encrypted data
|
||||
this->reusable_iovs_.push_back({buf_start, packet_len});
|
||||
total_write_len += packet_len;
|
||||
}
|
||||
|
||||
// Send all encrypted packets in one writev call
|
||||
return this->write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len);
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
header[1] = (uint8_t) (len >> 8);
|
||||
header[2] = (uint8_t) len;
|
||||
|
||||
struct iovec iov[2];
|
||||
iov[0].iov_base = header;
|
||||
iov[0].iov_len = 3;
|
||||
if (len == 0) {
|
||||
return this->write_raw_(iov, 1, 3); // Just header
|
||||
}
|
||||
iov[1].iov_base = const_cast<uint8_t *>(data);
|
||||
iov[1].iov_len = len;
|
||||
|
||||
return this->write_raw_(iov, 2, 3 + len); // Header + data
|
||||
}
|
||||
|
||||
/** Initiate the data structures for the handshake.
|
||||
*
|
||||
* @return 0 on success, -1 on error (check errno)
|
||||
*/
|
||||
APIError APINoiseFrameHelper::init_handshake_() {
|
||||
int err;
|
||||
memset(&nid_, 0, sizeof(nid_));
|
||||
// const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
|
||||
// err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
|
||||
nid_.pattern_id = NOISE_PATTERN_NN;
|
||||
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
|
||||
nid_.dh_id = NOISE_DH_CURVE25519;
|
||||
nid_.prefix_id = NOISE_PREFIX_STANDARD;
|
||||
nid_.hybrid_id = NOISE_DH_NONE;
|
||||
nid_.hash_id = NOISE_HASH_SHA256;
|
||||
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
|
||||
|
||||
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
|
||||
APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
const auto &psk = ctx_->get_psk();
|
||||
err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
// set_prologue copies it into handshakestate, so we can get rid of it now
|
||||
prologue_ = {};
|
||||
|
||||
err = noise_handshakestate_start(handshake_);
|
||||
aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::check_handshake_finished_() {
|
||||
assert(state_ == State::HANDSHAKE);
|
||||
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE)
|
||||
return APIError::OK;
|
||||
if (action != NOISE_ACTION_SPLIT) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
|
||||
APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_);
|
||||
|
||||
HELPER_LOG("Handshake complete!");
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APINoiseFrameHelper::~APINoiseFrameHelper() {
|
||||
if (handshake_ != nullptr) {
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
}
|
||||
if (send_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(send_cipher_);
|
||||
send_cipher_ = nullptr;
|
||||
}
|
||||
if (recv_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(recv_cipher_);
|
||||
recv_cipher_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
|
||||
void noise_rand_bytes(void *output, size_t len) {
|
||||
if (!esphome::random_bytes(reinterpret_cast<uint8_t *>(output), len)) {
|
||||
ESP_LOGE(TAG, "Acquiring random bytes failed; rebooting");
|
||||
arch_restart();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif // USE_API_NOISE
|
||||
#endif // USE_API
|
70
esphome/components/api/api_frame_helper_noise.h
Normal file
70
esphome/components/api/api_frame_helper_noise.h
Normal file
@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
#include "api_frame_helper.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_NOISE
|
||||
#include "noise/protocol.h"
|
||||
#include "api_noise_context.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
class APINoiseFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx,
|
||||
const ClientInfo *client_info)
|
||||
: APIFrameHelper(std::move(socket), client_info), ctx_(std::move(ctx)) {
|
||||
// Noise header structure:
|
||||
// Pos 0: indicator (0x01)
|
||||
// Pos 1-2: encrypted payload size (16-bit big-endian)
|
||||
// Pos 3-6: encrypted type (16-bit) + data_len (16-bit)
|
||||
// Pos 7+: actual payload data
|
||||
frame_header_padding_ = 7;
|
||||
}
|
||||
~APINoiseFrameHelper() override;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) override;
|
||||
// Get the frame header padding required by this protocol
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
|
||||
protected:
|
||||
APIError state_action_();
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
APIError write_frame_(const uint8_t *data, uint16_t len);
|
||||
APIError init_handshake_();
|
||||
APIError check_handshake_finished_();
|
||||
void send_explicit_handshake_reject_(const std::string &reason);
|
||||
APIError handle_handshake_frame_error_(APIError aerr);
|
||||
APIError handle_noise_error_(int err, const char *func_name, APIError api_err);
|
||||
|
||||
// Pointers first (4 bytes each)
|
||||
NoiseHandshakeState *handshake_{nullptr};
|
||||
NoiseCipherState *send_cipher_{nullptr};
|
||||
NoiseCipherState *recv_cipher_{nullptr};
|
||||
|
||||
// Shared pointer (8 bytes on 32-bit = 4 bytes control block pointer + 4 bytes object pointer)
|
||||
std::shared_ptr<APINoiseContext> ctx_;
|
||||
|
||||
// Vector (12 bytes on 32-bit)
|
||||
std::vector<uint8_t> prologue_;
|
||||
|
||||
// NoiseProtocolId (size depends on implementation)
|
||||
NoiseProtocolId nid_;
|
||||
|
||||
// Group small types together
|
||||
// Fixed-size header buffer for noise protocol:
|
||||
// 1 byte for indicator + 2 bytes for message size (16-bit value, not varint)
|
||||
// Note: Maximum message size is UINT16_MAX (65535), with a limit of 128 bytes during handshake phase
|
||||
uint8_t rx_header_buf_[3];
|
||||
uint8_t rx_header_buf_len_ = 0;
|
||||
// 4 bytes total, no padding
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif // USE_API_NOISE
|
||||
#endif // USE_API
|
292
esphome/components/api/api_frame_helper_plaintext.cpp
Normal file
292
esphome/components/api/api_frame_helper_plaintext.cpp
Normal file
@ -0,0 +1,292 @@
|
||||
#include "api_frame_helper_plaintext.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
#include "api_connection.h" // For ClientInfo struct
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/hal.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "proto.h"
|
||||
#include <cstring>
|
||||
#include <cinttypes>
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *const TAG = "api.plaintext";
|
||||
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__)
|
||||
|
||||
#ifdef HELPER_LOG_PACKETS
|
||||
#define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str())
|
||||
#define LOG_PACKET_SENDING(data, len) ESP_LOGVV(TAG, "Sending raw: %s", format_hex_pretty(data, len).c_str())
|
||||
#else
|
||||
#define LOG_PACKET_RECEIVED(buffer) ((void) 0)
|
||||
#define LOG_PACKET_SENDING(data, len) ((void) 0)
|
||||
#endif
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APIPlaintextFrameHelper::init() {
|
||||
APIError err = init_common_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::loop() {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
// Use base class implementation for buffer sending
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg: store the parsed frame in that struct
|
||||
*
|
||||
* @return See APIError
|
||||
*
|
||||
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
while (!rx_header_parsed_) {
|
||||
// Now that we know when the socket is ready, we can read up to 3 bytes
|
||||
// into the rx_header_buf_ before we have to switch back to reading
|
||||
// one byte at a time to ensure we don't read past the message and
|
||||
// into the next one.
|
||||
|
||||
// Read directly into rx_header_buf_ at the current position
|
||||
// Try to get to at least 3 bytes total (indicator + 2 varint bytes), then read one byte at a time
|
||||
ssize_t received =
|
||||
this->socket_->read(&rx_header_buf_[rx_header_buf_pos_], rx_header_buf_pos_ < 3 ? 3 - rx_header_buf_pos_ : 1);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// If this was the first read, validate the indicator byte
|
||||
if (rx_header_buf_pos_ == 0 && received > 0) {
|
||||
if (rx_header_buf_[0] != 0x00) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
}
|
||||
|
||||
rx_header_buf_pos_ += received;
|
||||
|
||||
// Check for buffer overflow
|
||||
if (rx_header_buf_pos_ >= sizeof(rx_header_buf_)) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Header buffer overflow");
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
// Need at least 3 bytes total (indicator + 2 varint bytes) before trying to parse
|
||||
if (rx_header_buf_pos_ < 3) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// At this point, we have at least 3 bytes total:
|
||||
// - Validated indicator byte (0x00) stored at position 0
|
||||
// - At least 2 bytes in the buffer for the varints
|
||||
// Buffer layout:
|
||||
// [0]: indicator byte (0x00)
|
||||
// [1-3]: Message size varint (variable length)
|
||||
// - 2 bytes would only allow up to 16383, which is less than noise's UINT16_MAX (65535)
|
||||
// - 3 bytes allows up to 2097151, ensuring we support at least as much as noise
|
||||
// [2-5]: Message type varint (variable length)
|
||||
// We now attempt to parse both varints. If either is incomplete,
|
||||
// we'll continue reading more bytes.
|
||||
|
||||
// Skip indicator byte at position 0
|
||||
uint8_t varint_pos = 1;
|
||||
uint32_t consumed = 0;
|
||||
|
||||
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
|
||||
if (!msg_size_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg_size_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(),
|
||||
std::numeric_limits<uint16_t>::max());
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
rx_header_parsed_len_ = msg_size_varint->as_uint16();
|
||||
|
||||
// Move to next varint position
|
||||
varint_pos += consumed;
|
||||
|
||||
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
|
||||
if (!msg_type_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
if (msg_type_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum %u", msg_type_varint->as_uint32(),
|
||||
std::numeric_limits<uint16_t>::max());
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
rx_header_parsed_type_ = msg_type_varint->as_uint16();
|
||||
rx_header_parsed_ = true;
|
||||
}
|
||||
// header reading done
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != rx_header_parsed_len_) {
|
||||
rx_buf_.resize(rx_header_parsed_len_);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < rx_header_parsed_len_) {
|
||||
// more data to read
|
||||
uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_;
|
||||
ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
APIError err = handle_socket_read_result_(received);
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
rx_buf_len_ += static_cast<uint16_t>(received);
|
||||
if (static_cast<uint16_t>(received) != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_PACKET_RECEIVED(rx_buf_);
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_pos_ = 0;
|
||||
rx_header_parsed_ = false;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK) {
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
// Make sure to tell the remote that we don't
|
||||
// understand the indicator byte so it knows
|
||||
// we do not support it.
|
||||
struct iovec iov[1];
|
||||
// The \x00 first byte is the marker for plaintext.
|
||||
//
|
||||
// The remote will know how to handle the indicator byte,
|
||||
// but it likely won't understand the rest of the message.
|
||||
//
|
||||
// We must send at least 3 bytes to be read, so we add
|
||||
// a message after the indicator byte to ensures its long
|
||||
// enough and can aid in debugging.
|
||||
const char msg[] = "\x00"
|
||||
"Bad indicator byte";
|
||||
iov[0].iov_base = (void *) msg;
|
||||
iov[0].iov_len = 19;
|
||||
this->write_raw_(iov, 1, 19);
|
||||
}
|
||||
return aerr;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->data_offset = 0;
|
||||
buffer->data_len = rx_header_parsed_len_;
|
||||
buffer->type = rx_header_parsed_type_;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
PacketInfo packet{type, 0, static_cast<uint16_t>(buffer.get_buffer()->size() - frame_header_padding_)};
|
||||
return write_protobuf_packets(buffer, std::span<const PacketInfo>(&packet, 1));
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
|
||||
if (packets.empty()) {
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
|
||||
uint8_t *buffer_data = raw_buffer->data(); // Cache buffer pointer
|
||||
|
||||
this->reusable_iovs_.clear();
|
||||
this->reusable_iovs_.reserve(packets.size());
|
||||
uint16_t total_write_len = 0;
|
||||
|
||||
for (const auto &packet : packets) {
|
||||
// Calculate varint sizes for header layout
|
||||
uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(packet.payload_size));
|
||||
uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(packet.message_type));
|
||||
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
|
||||
|
||||
// Calculate where to start writing the header
|
||||
// The header starts at the latest possible position to minimize unused padding
|
||||
//
|
||||
// Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
|
||||
// [0-2] - Unused padding
|
||||
// [3] - 0x00 indicator byte
|
||||
// [4] - Payload size varint (1 byte, for sizes 0-127)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
|
||||
// [0-1] - Unused padding
|
||||
// [2] - 0x00 indicator byte
|
||||
// [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
|
||||
// [5] - Message type varint (1 byte, for types 0-127)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
|
||||
// [0] - 0x00 indicator byte
|
||||
// [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151)
|
||||
// [4-5] - Message type varint (2 bytes, for types 128-32767)
|
||||
// [6...] - Actual payload data
|
||||
//
|
||||
// The message starts at offset + frame_header_padding_
|
||||
// So we write the header starting at offset + frame_header_padding_ - total_header_len
|
||||
uint8_t *buf_start = buffer_data + packet.offset;
|
||||
uint32_t header_offset = frame_header_padding_ - total_header_len;
|
||||
|
||||
// Write the plaintext header
|
||||
buf_start[header_offset] = 0x00; // indicator
|
||||
|
||||
// Encode varints directly into buffer
|
||||
ProtoVarInt(packet.payload_size).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len);
|
||||
ProtoVarInt(packet.message_type)
|
||||
.encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
|
||||
|
||||
// Add iovec for this packet (header + payload)
|
||||
size_t packet_len = static_cast<size_t>(total_header_len + packet.payload_size);
|
||||
this->reusable_iovs_.push_back({buf_start + header_offset, packet_len});
|
||||
total_write_len += packet_len;
|
||||
}
|
||||
|
||||
// Send all packets in one writev call
|
||||
return write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size(), total_write_len);
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif // USE_API_PLAINTEXT
|
||||
#endif // USE_API
|
55
esphome/components/api/api_frame_helper_plaintext.h
Normal file
55
esphome/components/api/api_frame_helper_plaintext.h
Normal file
@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
#include "api_frame_helper.h"
|
||||
#ifdef USE_API
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
class APIPlaintextFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket, const ClientInfo *client_info)
|
||||
: APIFrameHelper(std::move(socket), client_info) {
|
||||
// Plaintext header structure (worst case):
|
||||
// Pos 0: indicator (0x00)
|
||||
// Pos 1-3: payload size varint (up to 3 bytes)
|
||||
// Pos 4-5: message type varint (up to 2 bytes)
|
||||
// Pos 6+: actual payload data
|
||||
frame_header_padding_ = 6;
|
||||
}
|
||||
~APIPlaintextFrameHelper() override = default;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) override;
|
||||
uint8_t frame_header_padding() override { return frame_header_padding_; }
|
||||
// Get the frame footer size required by this protocol
|
||||
uint8_t frame_footer_size() override { return frame_footer_size_; }
|
||||
|
||||
protected:
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
|
||||
// Group 2-byte aligned types
|
||||
uint16_t rx_header_parsed_type_ = 0;
|
||||
uint16_t rx_header_parsed_len_ = 0;
|
||||
|
||||
// Group 1-byte types together
|
||||
// Fixed-size header buffer for plaintext protocol:
|
||||
// We now store the indicator byte + the two varints.
|
||||
// To match noise protocol's maximum message size (UINT16_MAX = 65535), we need:
|
||||
// 1 byte for indicator + 3 bytes for message size varint (supports up to 2097151) + 2 bytes for message type varint
|
||||
//
|
||||
// While varints could theoretically be up to 10 bytes each for 64-bit values,
|
||||
// attempting to process messages with headers that large would likely crash the
|
||||
// ESP32 due to memory constraints.
|
||||
uint8_t rx_header_buf_[6]; // 1 byte indicator + 5 bytes for varints (3 for size + 2 for type)
|
||||
uint8_t rx_header_buf_pos_ = 0;
|
||||
bool rx_header_parsed_ = false;
|
||||
// 8 bytes total, no padding needed
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
#endif // USE_API_PLAINTEXT
|
||||
#endif // USE_API
|
Loading…
x
Reference in New Issue
Block a user