[api] Allow noise encryption key to be set at runtime (#7296)

Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
Keith Burzinski 2025-04-16 20:15:25 -05:00 committed by GitHub
parent ca4838a5f4
commit 2fd5f9ac58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 373 additions and 78 deletions

View File

@ -82,6 +82,19 @@ ACTIONS_SCHEMA = automation.validate_automation(
), ),
) )
ENCRYPTION_SCHEMA = cv.Schema(
{
cv.Optional(CONF_KEY): validate_encryption_key,
}
)
def _encryption_schema(config):
if config is None:
config = {}
return ENCRYPTION_SCHEMA(config)
CONFIG_SCHEMA = cv.All( CONFIG_SCHEMA = cv.All(
cv.Schema( cv.Schema(
{ {
@ -95,11 +108,7 @@ CONFIG_SCHEMA = cv.All(
CONF_SERVICES, group_of_exclusion=CONF_ACTIONS CONF_SERVICES, group_of_exclusion=CONF_ACTIONS
): ACTIONS_SCHEMA, ): ACTIONS_SCHEMA,
cv.Exclusive(CONF_ACTIONS, group_of_exclusion=CONF_ACTIONS): ACTIONS_SCHEMA, cv.Exclusive(CONF_ACTIONS, group_of_exclusion=CONF_ACTIONS): ACTIONS_SCHEMA,
cv.Optional(CONF_ENCRYPTION): cv.Schema( cv.Optional(CONF_ENCRYPTION): _encryption_schema,
{
cv.Required(CONF_KEY): validate_encryption_key,
}
),
cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation( cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation(
single=True single=True
), ),
@ -151,9 +160,17 @@ async def to_code(config):
config[CONF_ON_CLIENT_DISCONNECTED], config[CONF_ON_CLIENT_DISCONNECTED],
) )
if encryption_config := config.get(CONF_ENCRYPTION): if (encryption_config := config.get(CONF_ENCRYPTION, None)) is not None:
decoded = base64.b64decode(encryption_config[CONF_KEY]) if key := encryption_config.get(CONF_KEY):
cg.add(var.set_noise_psk(list(decoded))) decoded = base64.b64decode(key)
cg.add(var.set_noise_psk(list(decoded)))
else:
# No key provided, but encryption desired
# This will allow a plaintext client to provide a noise key,
# send it to the device, and then switch to noise.
# The key will be saved in flash and used for future connections
# and plaintext disabled. Only a factory reset can remove it.
cg.add_define("USE_API_PLAINTEXT")
cg.add_define("USE_API_NOISE") cg.add_define("USE_API_NOISE")
cg.add_library("esphome/noise-c", "0.1.6") cg.add_library("esphome/noise-c", "0.1.6")
else: else:

View File

@ -31,6 +31,7 @@ service APIConnection {
option (needs_authentication) = false; option (needs_authentication) = false;
} }
rpc execute_service (ExecuteServiceRequest) returns (void) {} rpc execute_service (ExecuteServiceRequest) returns (void) {}
rpc noise_encryption_set_key (NoiseEncryptionSetKeyRequest) returns (NoiseEncryptionSetKeyResponse) {}
rpc cover_command (CoverCommandRequest) returns (void) {} rpc cover_command (CoverCommandRequest) returns (void) {}
rpc fan_command (FanCommandRequest) returns (void) {} rpc fan_command (FanCommandRequest) returns (void) {}
@ -230,6 +231,9 @@ message DeviceInfoResponse {
// The Bluetooth mac address of the device. For example "AC:BC:32:89:0E:AA" // The Bluetooth mac address of the device. For example "AC:BC:32:89:0E:AA"
string bluetooth_mac_address = 18; string bluetooth_mac_address = 18;
// Supports receiving and saving api encryption key
bool api_encryption_supported = 19;
} }
message ListEntitiesRequest { message ListEntitiesRequest {
@ -654,6 +658,23 @@ message SubscribeLogsResponse {
bool send_failed = 4; bool send_failed = 4;
} }
// ==================== NOISE ENCRYPTION ====================
message NoiseEncryptionSetKeyRequest {
option (id) = 124;
option (source) = SOURCE_CLIENT;
option (ifdef) = "USE_API_NOISE";
bytes key = 1;
}
message NoiseEncryptionSetKeyResponse {
option (id) = 125;
option (source) = SOURCE_SERVER;
option (ifdef) = "USE_API_NOISE";
bool success = 1;
}
// ==================== HOMEASSISTANT.SERVICE ==================== // ==================== HOMEASSISTANT.SERVICE ====================
message SubscribeHomeassistantServicesRequest { message SubscribeHomeassistantServicesRequest {
option (id) = 34; option (id) = 34;

View File

@ -62,7 +62,14 @@ APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *pa
: parent_(parent), deferred_message_queue_(this), initial_state_iterator_(this), list_entities_iterator_(this) { : parent_(parent), deferred_message_queue_(this), initial_state_iterator_(this), list_entities_iterator_(this) {
this->proto_write_buffer_.reserve(64); this->proto_write_buffer_.reserve(64);
#if defined(USE_API_PLAINTEXT) #if defined(USE_API_PLAINTEXT) && defined(USE_API_NOISE)
auto noise_ctx = parent->get_noise_ctx();
if (noise_ctx->has_psk()) {
this->helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), noise_ctx)};
} else {
this->helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
}
#elif defined(USE_API_PLAINTEXT)
this->helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))}; this->helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
#elif defined(USE_API_NOISE) #elif defined(USE_API_NOISE)
this->helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; this->helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
@ -1848,6 +1855,9 @@ DeviceInfoResponse APIConnection::device_info(const DeviceInfoRequest &msg) {
#ifdef USE_VOICE_ASSISTANT #ifdef USE_VOICE_ASSISTANT
resp.legacy_voice_assistant_version = voice_assistant::global_voice_assistant->get_legacy_version(); resp.legacy_voice_assistant_version = voice_assistant::global_voice_assistant->get_legacy_version();
resp.voice_assistant_feature_flags = voice_assistant::global_voice_assistant->get_feature_flags(); resp.voice_assistant_feature_flags = voice_assistant::global_voice_assistant->get_feature_flags();
#endif
#ifdef USE_API_NOISE
resp.api_encryption_supported = true;
#endif #endif
return resp; return resp;
} }
@ -1869,6 +1879,26 @@ void APIConnection::execute_service(const ExecuteServiceRequest &msg) {
ESP_LOGV(TAG, "Could not find matching service!"); ESP_LOGV(TAG, "Could not find matching service!");
} }
} }
#ifdef USE_API_NOISE
NoiseEncryptionSetKeyResponse APIConnection::noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) {
psk_t psk{};
NoiseEncryptionSetKeyResponse resp;
if (base64_decode(msg.key, psk.data(), msg.key.size()) != psk.size()) {
ESP_LOGW(TAG, "Invalid encryption key length");
resp.success = false;
return resp;
}
if (!this->parent_->save_noise_psk(psk, true)) {
ESP_LOGW(TAG, "Failed to save encryption key");
resp.success = false;
return resp;
}
resp.success = true;
return resp;
}
#endif
void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) { void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) {
state_subs_at_ = 0; state_subs_at_ = 0;
} }

View File

@ -300,6 +300,9 @@ class APIConnection : public APIServerConnection {
return {}; return {};
} }
void execute_service(const ExecuteServiceRequest &msg) override; void execute_service(const ExecuteServiceRequest &msg) override;
#ifdef USE_API_NOISE
NoiseEncryptionSetKeyResponse noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) override;
#endif
bool is_authenticated() override { return this->connection_state_ == ConnectionState::AUTHENTICATED; } bool is_authenticated() override { return this->connection_state_ == ConnectionState::AUTHENTICATED; }
bool is_connection_setup() override { bool is_connection_setup() override {

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <cstdint>
#include <array> #include <array>
#include <cstdint>
#include "esphome/core/defines.h" #include "esphome/core/defines.h"
namespace esphome { namespace esphome {
@ -11,11 +11,20 @@ using psk_t = std::array<uint8_t, 32>;
class APINoiseContext { class APINoiseContext {
public: public:
void set_psk(psk_t psk) { psk_ = psk; } void set_psk(psk_t psk) {
const psk_t &get_psk() const { return psk_; } this->psk_ = psk;
bool has_psk = false;
for (auto i : psk) {
has_psk |= i;
}
this->has_psk_ = has_psk;
}
const psk_t &get_psk() const { return this->psk_; }
bool has_psk() const { return this->has_psk_; }
protected: protected:
psk_t psk_; psk_t psk_{};
bool has_psk_{false};
}; };
#endif // USE_API_NOISE #endif // USE_API_NOISE

View File

@ -792,6 +792,10 @@ bool DeviceInfoResponse::decode_varint(uint32_t field_id, ProtoVarInt value) {
this->voice_assistant_feature_flags = value.as_uint32(); this->voice_assistant_feature_flags = value.as_uint32();
return true; return true;
} }
case 19: {
this->api_encryption_supported = value.as_bool();
return true;
}
default: default:
return false; return false;
} }
@ -865,6 +869,7 @@ void DeviceInfoResponse::encode(ProtoWriteBuffer buffer) const {
buffer.encode_uint32(17, this->voice_assistant_feature_flags); buffer.encode_uint32(17, this->voice_assistant_feature_flags);
buffer.encode_string(16, this->suggested_area); buffer.encode_string(16, this->suggested_area);
buffer.encode_string(18, this->bluetooth_mac_address); buffer.encode_string(18, this->bluetooth_mac_address);
buffer.encode_bool(19, this->api_encryption_supported);
} }
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
void DeviceInfoResponse::dump_to(std::string &out) const { void DeviceInfoResponse::dump_to(std::string &out) const {
@ -946,6 +951,10 @@ void DeviceInfoResponse::dump_to(std::string &out) const {
out.append(" bluetooth_mac_address: "); out.append(" bluetooth_mac_address: ");
out.append("'").append(this->bluetooth_mac_address).append("'"); out.append("'").append(this->bluetooth_mac_address).append("'");
out.append("\n"); out.append("\n");
out.append(" api_encryption_supported: ");
out.append(YESNO(this->api_encryption_supported));
out.append("\n");
out.append("}"); out.append("}");
} }
#endif #endif
@ -3009,6 +3018,48 @@ void SubscribeLogsResponse::dump_to(std::string &out) const {
out.append("}"); out.append("}");
} }
#endif #endif
bool NoiseEncryptionSetKeyRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
switch (field_id) {
case 1: {
this->key = value.as_string();
return true;
}
default:
return false;
}
}
void NoiseEncryptionSetKeyRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->key); }
#ifdef HAS_PROTO_MESSAGE_DUMP
void NoiseEncryptionSetKeyRequest::dump_to(std::string &out) const {
__attribute__((unused)) char buffer[64];
out.append("NoiseEncryptionSetKeyRequest {\n");
out.append(" key: ");
out.append("'").append(this->key).append("'");
out.append("\n");
out.append("}");
}
#endif
bool NoiseEncryptionSetKeyResponse::decode_varint(uint32_t field_id, ProtoVarInt value) {
switch (field_id) {
case 1: {
this->success = value.as_bool();
return true;
}
default:
return false;
}
}
void NoiseEncryptionSetKeyResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->success); }
#ifdef HAS_PROTO_MESSAGE_DUMP
void NoiseEncryptionSetKeyResponse::dump_to(std::string &out) const {
__attribute__((unused)) char buffer[64];
out.append("NoiseEncryptionSetKeyResponse {\n");
out.append(" success: ");
out.append(YESNO(this->success));
out.append("\n");
out.append("}");
}
#endif
void SubscribeHomeassistantServicesRequest::encode(ProtoWriteBuffer buffer) const {} void SubscribeHomeassistantServicesRequest::encode(ProtoWriteBuffer buffer) const {}
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
void SubscribeHomeassistantServicesRequest::dump_to(std::string &out) const { void SubscribeHomeassistantServicesRequest::dump_to(std::string &out) const {

View File

@ -355,6 +355,7 @@ class DeviceInfoResponse : public ProtoMessage {
uint32_t voice_assistant_feature_flags{0}; uint32_t voice_assistant_feature_flags{0};
std::string suggested_area{}; std::string suggested_area{};
std::string bluetooth_mac_address{}; std::string bluetooth_mac_address{};
bool api_encryption_supported{false};
void encode(ProtoWriteBuffer buffer) const override; void encode(ProtoWriteBuffer buffer) const override;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override; void dump_to(std::string &out) const override;
@ -791,6 +792,28 @@ class SubscribeLogsResponse : public ProtoMessage {
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
bool decode_varint(uint32_t field_id, ProtoVarInt value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
}; };
class NoiseEncryptionSetKeyRequest : public ProtoMessage {
public:
std::string key{};
void encode(ProtoWriteBuffer buffer) const override;
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif
protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class NoiseEncryptionSetKeyResponse : public ProtoMessage {
public:
bool success{false};
void encode(ProtoWriteBuffer buffer) const override;
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif
protected:
bool decode_varint(uint32_t field_id, ProtoVarInt value) override;
};
class SubscribeHomeassistantServicesRequest : public ProtoMessage { class SubscribeHomeassistantServicesRequest : public ProtoMessage {
public: public:
void encode(ProtoWriteBuffer buffer) const override; void encode(ProtoWriteBuffer buffer) const override;

View File

@ -179,6 +179,16 @@ bool APIServerConnectionBase::send_text_sensor_state_response(const TextSensorSt
bool APIServerConnectionBase::send_subscribe_logs_response(const SubscribeLogsResponse &msg) { bool APIServerConnectionBase::send_subscribe_logs_response(const SubscribeLogsResponse &msg) {
return this->send_message_<SubscribeLogsResponse>(msg, 29); return this->send_message_<SubscribeLogsResponse>(msg, 29);
} }
#ifdef USE_API_NOISE
#endif
#ifdef USE_API_NOISE
bool APIServerConnectionBase::send_noise_encryption_set_key_response(const NoiseEncryptionSetKeyResponse &msg) {
#ifdef HAS_PROTO_MESSAGE_DUMP
ESP_LOGVV(TAG, "send_noise_encryption_set_key_response: %s", msg.dump().c_str());
#endif
return this->send_message_<NoiseEncryptionSetKeyResponse>(msg, 125);
}
#endif
bool APIServerConnectionBase::send_homeassistant_service_response(const HomeassistantServiceResponse &msg) { bool APIServerConnectionBase::send_homeassistant_service_response(const HomeassistantServiceResponse &msg) {
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
ESP_LOGVV(TAG, "send_homeassistant_service_response: %s", msg.dump().c_str()); ESP_LOGVV(TAG, "send_homeassistant_service_response: %s", msg.dump().c_str());
@ -1191,6 +1201,17 @@ bool APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type,
ESP_LOGVV(TAG, "on_voice_assistant_set_configuration: %s", msg.dump().c_str()); ESP_LOGVV(TAG, "on_voice_assistant_set_configuration: %s", msg.dump().c_str());
#endif #endif
this->on_voice_assistant_set_configuration(msg); this->on_voice_assistant_set_configuration(msg);
#endif
break;
}
case 124: {
#ifdef USE_API_NOISE
NoiseEncryptionSetKeyRequest msg;
msg.decode(msg_data, msg_size);
#ifdef HAS_PROTO_MESSAGE_DUMP
ESP_LOGVV(TAG, "on_noise_encryption_set_key_request: %s", msg.dump().c_str());
#endif
this->on_noise_encryption_set_key_request(msg);
#endif #endif
break; break;
} }
@ -1311,6 +1332,22 @@ void APIServerConnection::on_execute_service_request(const ExecuteServiceRequest
} }
this->execute_service(msg); this->execute_service(msg);
} }
#ifdef USE_API_NOISE
void APIServerConnection::on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &msg) {
if (!this->is_connection_setup()) {
this->on_no_setup_connection();
return;
}
if (!this->is_authenticated()) {
this->on_unauthenticated_access();
return;
}
NoiseEncryptionSetKeyResponse ret = this->noise_encryption_set_key(msg);
if (!this->send_noise_encryption_set_key_response(ret)) {
this->on_fatal_error();
}
}
#endif
#ifdef USE_COVER #ifdef USE_COVER
void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) { void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) {
if (!this->is_connection_setup()) { if (!this->is_connection_setup()) {

View File

@ -83,6 +83,12 @@ class APIServerConnectionBase : public ProtoService {
#endif #endif
virtual void on_subscribe_logs_request(const SubscribeLogsRequest &value){}; virtual void on_subscribe_logs_request(const SubscribeLogsRequest &value){};
bool send_subscribe_logs_response(const SubscribeLogsResponse &msg); bool send_subscribe_logs_response(const SubscribeLogsResponse &msg);
#ifdef USE_API_NOISE
virtual void on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &value){};
#endif
#ifdef USE_API_NOISE
bool send_noise_encryption_set_key_response(const NoiseEncryptionSetKeyResponse &msg);
#endif
virtual void on_subscribe_homeassistant_services_request(const SubscribeHomeassistantServicesRequest &value){}; virtual void on_subscribe_homeassistant_services_request(const SubscribeHomeassistantServicesRequest &value){};
bool send_homeassistant_service_response(const HomeassistantServiceResponse &msg); bool send_homeassistant_service_response(const HomeassistantServiceResponse &msg);
virtual void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &value){}; virtual void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &value){};
@ -349,6 +355,9 @@ class APIServerConnection : public APIServerConnectionBase {
virtual void subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) = 0; virtual void subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) = 0;
virtual GetTimeResponse get_time(const GetTimeRequest &msg) = 0; virtual GetTimeResponse get_time(const GetTimeRequest &msg) = 0;
virtual void execute_service(const ExecuteServiceRequest &msg) = 0; virtual void execute_service(const ExecuteServiceRequest &msg) = 0;
#ifdef USE_API_NOISE
virtual NoiseEncryptionSetKeyResponse noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) = 0;
#endif
#ifdef USE_COVER #ifdef USE_COVER
virtual void cover_command(const CoverCommandRequest &msg) = 0; virtual void cover_command(const CoverCommandRequest &msg) = 0;
#endif #endif
@ -457,6 +466,9 @@ class APIServerConnection : public APIServerConnectionBase {
void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &msg) override; void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &msg) override;
void on_get_time_request(const GetTimeRequest &msg) override; void on_get_time_request(const GetTimeRequest &msg) override;
void on_execute_service_request(const ExecuteServiceRequest &msg) override; void on_execute_service_request(const ExecuteServiceRequest &msg) override;
#ifdef USE_API_NOISE
void on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &msg) override;
#endif
#ifdef USE_COVER #ifdef USE_COVER
void on_cover_command_request(const CoverCommandRequest &msg) override; void on_cover_command_request(const CoverCommandRequest &msg) override;
#endif #endif

View File

@ -22,22 +22,40 @@ namespace api {
static const char *const TAG = "api"; static const char *const TAG = "api";
// APIServer // APIServer
APIServer *global_api_server = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
APIServer::APIServer() { global_api_server = this; }
void APIServer::setup() { void APIServer::setup() {
ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server..."); ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server...");
this->setup_controller(); this->setup_controller();
socket_ = socket::socket_ip(SOCK_STREAM, 0);
if (socket_ == nullptr) { #ifdef USE_API_NOISE
ESP_LOGW(TAG, "Could not create socket."); uint32_t hash = 88491486UL;
this->noise_pref_ = global_preferences->make_preference<SavedNoisePsk>(hash, true);
SavedNoisePsk noise_pref_saved{};
if (this->noise_pref_.load(&noise_pref_saved)) {
ESP_LOGD(TAG, "Loaded saved Noise PSK");
this->set_noise_psk(noise_pref_saved.psk);
}
#endif
this->socket_ = socket::socket_ip(SOCK_STREAM, 0);
if (this->socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket");
this->mark_failed(); this->mark_failed();
return; return;
} }
int enable = 1; int enable = 1;
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); int err = this->socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err); ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue // we can still continue
} }
err = socket_->setblocking(false); err = this->socket_->setblocking(false);
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err); ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed(); this->mark_failed();
@ -53,14 +71,14 @@ void APIServer::setup() {
return; return;
} }
err = socket_->bind((struct sockaddr *) &server, sl); err = this->socket_->bind((struct sockaddr *) &server, sl);
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno); ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed(); this->mark_failed();
return; return;
} }
err = socket_->listen(4); err = this->socket_->listen(4);
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno); ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed(); this->mark_failed();
@ -92,18 +110,19 @@ void APIServer::setup() {
} }
#endif #endif
} }
void APIServer::loop() { void APIServer::loop() {
// Accept new clients // Accept new clients
while (true) { while (true) {
struct sockaddr_storage source_addr; struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr); socklen_t addr_len = sizeof(source_addr);
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len); auto sock = this->socket_->accept((struct sockaddr *) &source_addr, &addr_len);
if (!sock) if (!sock)
break; break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str()); ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
auto *conn = new APIConnection(std::move(sock), this); auto *conn = new APIConnection(std::move(sock), this);
clients_.emplace_back(conn); this->clients_.emplace_back(conn);
conn->start(); conn->start();
} }
@ -136,16 +155,22 @@ void APIServer::loop() {
} }
} }
} }
void APIServer::dump_config() { void APIServer::dump_config() {
ESP_LOGCONFIG(TAG, "API Server:"); ESP_LOGCONFIG(TAG, "API Server:");
ESP_LOGCONFIG(TAG, " Address: %s:%u", network::get_use_address().c_str(), this->port_); ESP_LOGCONFIG(TAG, " Address: %s:%u", network::get_use_address().c_str(), this->port_);
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
ESP_LOGCONFIG(TAG, " Using noise encryption: YES"); ESP_LOGCONFIG(TAG, " Using noise encryption: %s", YESNO(this->noise_ctx_->has_psk()));
if (!this->noise_ctx_->has_psk()) {
ESP_LOGCONFIG(TAG, " Supports noise encryption: YES");
}
#else #else
ESP_LOGCONFIG(TAG, " Using noise encryption: NO"); ESP_LOGCONFIG(TAG, " Using noise encryption: NO");
#endif #endif
} }
bool APIServer::uses_password() const { return !this->password_.empty(); } bool APIServer::uses_password() const { return !this->password_.empty(); }
bool APIServer::check_password(const std::string &password) const { bool APIServer::check_password(const std::string &password) const {
// depend only on input password length // depend only on input password length
const char *a = this->password_.c_str(); const char *a = this->password_.c_str();
@ -174,7 +199,9 @@ bool APIServer::check_password(const std::string &password) const {
return result == 0; return result == 0;
} }
void APIServer::handle_disconnect(APIConnection *conn) {} void APIServer::handle_disconnect(APIConnection *conn) {}
#ifdef USE_BINARY_SENSOR #ifdef USE_BINARY_SENSOR
void APIServer::on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) { void APIServer::on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) {
if (obj->is_internal()) if (obj->is_internal())
@ -342,57 +369,6 @@ void APIServer::on_update(update::UpdateEntity *obj) {
} }
#endif #endif
float APIServer::get_setup_priority() const { return setup_priority::AFTER_WIFI; }
void APIServer::set_port(uint16_t port) { this->port_ = port; }
APIServer *global_api_server = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
void APIServer::set_password(const std::string &password) { this->password_ = password; }
void APIServer::send_homeassistant_service_call(const HomeassistantServiceResponse &call) {
for (auto &client : this->clients_) {
client->send_homeassistant_service_call(call);
}
}
APIServer::APIServer() { global_api_server = this; }
void APIServer::subscribe_home_assistant_state(std::string entity_id, optional<std::string> attribute,
std::function<void(std::string)> f) {
this->state_subs_.push_back(HomeAssistantStateSubscription{
.entity_id = std::move(entity_id),
.attribute = std::move(attribute),
.callback = std::move(f),
.once = false,
});
}
void APIServer::get_home_assistant_state(std::string entity_id, optional<std::string> attribute,
std::function<void(std::string)> f) {
this->state_subs_.push_back(HomeAssistantStateSubscription{
.entity_id = std::move(entity_id),
.attribute = std::move(attribute),
.callback = std::move(f),
.once = true,
});
};
const std::vector<APIServer::HomeAssistantStateSubscription> &APIServer::get_state_subs() const {
return this->state_subs_;
}
uint16_t APIServer::get_port() const { return this->port_; }
void APIServer::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; }
#ifdef USE_HOMEASSISTANT_TIME
void APIServer::request_time() {
for (auto &client : this->clients_) {
if (!client->remove_ && client->is_authenticated())
client->send_time_request();
}
}
#endif
bool APIServer::is_connected() const { return !this->clients_.empty(); }
void APIServer::on_shutdown() {
for (auto &c : this->clients_) {
c->send_disconnect_request(DisconnectRequest());
}
delay(10);
}
#ifdef USE_ALARM_CONTROL_PANEL #ifdef USE_ALARM_CONTROL_PANEL
void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) { void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) {
if (obj->is_internal()) if (obj->is_internal())
@ -402,6 +378,96 @@ void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlP
} }
#endif #endif
float APIServer::get_setup_priority() const { return setup_priority::AFTER_WIFI; }
void APIServer::set_port(uint16_t port) { this->port_ = port; }
void APIServer::set_password(const std::string &password) { this->password_ = password; }
void APIServer::send_homeassistant_service_call(const HomeassistantServiceResponse &call) {
for (auto &client : this->clients_) {
client->send_homeassistant_service_call(call);
}
}
void APIServer::subscribe_home_assistant_state(std::string entity_id, optional<std::string> attribute,
std::function<void(std::string)> f) {
this->state_subs_.push_back(HomeAssistantStateSubscription{
.entity_id = std::move(entity_id),
.attribute = std::move(attribute),
.callback = std::move(f),
.once = false,
});
}
void APIServer::get_home_assistant_state(std::string entity_id, optional<std::string> attribute,
std::function<void(std::string)> f) {
this->state_subs_.push_back(HomeAssistantStateSubscription{
.entity_id = std::move(entity_id),
.attribute = std::move(attribute),
.callback = std::move(f),
.once = true,
});
};
const std::vector<APIServer::HomeAssistantStateSubscription> &APIServer::get_state_subs() const {
return this->state_subs_;
}
uint16_t APIServer::get_port() const { return this->port_; }
void APIServer::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; }
#ifdef USE_API_NOISE
bool APIServer::save_noise_psk(psk_t psk, bool make_active) {
auto &old_psk = this->noise_ctx_->get_psk();
if (std::equal(old_psk.begin(), old_psk.end(), psk.begin())) {
ESP_LOGW(TAG, "New PSK matches old");
return true;
}
SavedNoisePsk new_saved_psk{psk};
if (!this->noise_pref_.save(&new_saved_psk)) {
ESP_LOGW(TAG, "Failed to save Noise PSK");
return false;
}
// ensure it's written immediately
if (!global_preferences->sync()) {
ESP_LOGW(TAG, "Failed to sync preferences");
return false;
}
ESP_LOGD(TAG, "Noise PSK saved");
if (make_active) {
this->set_timeout(100, [this, psk]() {
ESP_LOGW(TAG, "Disconnecting all clients to reset connections");
this->set_noise_psk(psk);
for (auto &c : this->clients_) {
c->send_disconnect_request(DisconnectRequest());
}
});
}
return true;
}
#endif
#ifdef USE_HOMEASSISTANT_TIME
void APIServer::request_time() {
for (auto &client : this->clients_) {
if (!client->remove_ && client->is_authenticated())
client->send_time_request();
}
}
#endif
bool APIServer::is_connected() const { return !this->clients_.empty(); }
void APIServer::on_shutdown() {
for (auto &c : this->clients_) {
c->send_disconnect_request(DisconnectRequest());
}
delay(10);
}
} // namespace api } // namespace api
} // namespace esphome } // namespace esphome
#endif #endif

View File

@ -19,6 +19,12 @@
namespace esphome { namespace esphome {
namespace api { namespace api {
#ifdef USE_API_NOISE
struct SavedNoisePsk {
psk_t psk;
} PACKED; // NOLINT
#endif
class APIServer : public Component, public Controller { class APIServer : public Component, public Controller {
public: public:
APIServer(); APIServer();
@ -35,6 +41,7 @@ class APIServer : public Component, public Controller {
void set_reboot_timeout(uint32_t reboot_timeout); void set_reboot_timeout(uint32_t reboot_timeout);
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
bool save_noise_psk(psk_t psk, bool make_active = true);
void set_noise_psk(psk_t psk) { noise_ctx_->set_psk(psk); } void set_noise_psk(psk_t psk) { noise_ctx_->set_psk(psk); }
std::shared_ptr<APINoiseContext> get_noise_ctx() { return noise_ctx_; } std::shared_ptr<APINoiseContext> get_noise_ctx() { return noise_ctx_; }
#endif // USE_API_NOISE #endif // USE_API_NOISE
@ -142,6 +149,7 @@ class APIServer : public Component, public Controller {
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>(); std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();
ESPPreferenceObject noise_pref_;
#endif // USE_API_NOISE #endif // USE_API_NOISE
}; };

View File

@ -1,9 +1,9 @@
#include "esphome/core/defines.h" #include "esphome/core/defines.h"
#ifdef USE_MDNS #ifdef USE_MDNS
#include "mdns_component.h"
#include "esphome/core/version.h"
#include "esphome/core/application.h" #include "esphome/core/application.h"
#include "esphome/core/log.h" #include "esphome/core/log.h"
#include "esphome/core/version.h"
#include "mdns_component.h"
#ifdef USE_API #ifdef USE_API
#include "esphome/components/api/api_server.h" #include "esphome/components/api/api_server.h"
@ -62,7 +62,11 @@ void MDNSComponent::compile_records_() {
#endif #endif
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
service.txt_records.push_back({"api_encryption", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"}); if (api::global_api_server->get_noise_ctx()->has_psk()) {
service.txt_records.push_back({"api_encryption", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"});
} else {
service.txt_records.push_back({"api_encryption_supported", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"});
}
#endif #endif
#ifdef ESPHOME_PROJECT_NAME #ifdef ESPHOME_PROJECT_NAME

View File

@ -138,7 +138,11 @@ void MQTTClientComponent::send_device_info_() {
#endif #endif
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
root["api_encryption"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; if (api::global_api_server->get_noise_ctx()->has_psk()) {
root["api_encryption"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
} else {
root["api_encryption_supported"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
}
#endif #endif
}, },
2, this->discovery_info_.retain); 2, this->discovery_info_.retain);

View File

@ -0,0 +1,10 @@
packages:
common: !include common.yaml
wifi:
ssid: MySSID
password: password1
api:
encryption:
key: !remove