diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 13c5b345b6..01f4552842 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1176,66 +1176,53 @@ void APIConnection::bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequ #endif #ifdef USE_VOICE_ASSISTANT +bool APIConnection::check_voice_assistant_api_connection_() const { + return voice_assistant::global_voice_assistant != nullptr && + voice_assistant::global_voice_assistant->get_api_connection() == this; +} + void APIConnection::subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) { if (voice_assistant::global_voice_assistant != nullptr) { voice_assistant::global_voice_assistant->client_subscription(this, msg.subscribe); } } void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } + if (!this->check_voice_assistant_api_connection_()) { + return; + } - if (msg.error) { - voice_assistant::global_voice_assistant->failed_to_start(); - return; - } - if (msg.port == 0) { - // Use API Audio - voice_assistant::global_voice_assistant->start_streaming(); - } else { - struct sockaddr_storage storage; - socklen_t len = sizeof(storage); - this->helper_->getpeername((struct sockaddr *) &storage, &len); - voice_assistant::global_voice_assistant->start_streaming(&storage, msg.port); - } + if (msg.error) { + voice_assistant::global_voice_assistant->failed_to_start(); + return; + } + if (msg.port == 0) { + // Use API Audio + voice_assistant::global_voice_assistant->start_streaming(); + } else { + struct sockaddr_storage storage; + socklen_t len = sizeof(storage); + this->helper_->getpeername((struct sockaddr *) &storage, &len); + voice_assistant::global_voice_assistant->start_streaming(&storage, msg.port); } }; void APIConnection::on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } - + if (this->check_voice_assistant_api_connection_()) { voice_assistant::global_voice_assistant->on_event(msg); } } void APIConnection::on_voice_assistant_audio(const VoiceAssistantAudio &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } - + if (this->check_voice_assistant_api_connection_()) { voice_assistant::global_voice_assistant->on_audio(msg); } }; void APIConnection::on_voice_assistant_timer_event_response(const VoiceAssistantTimerEventResponse &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } - + if (this->check_voice_assistant_api_connection_()) { voice_assistant::global_voice_assistant->on_timer_event(msg); } }; void APIConnection::on_voice_assistant_announce_request(const VoiceAssistantAnnounceRequest &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } - + if (this->check_voice_assistant_api_connection_()) { voice_assistant::global_voice_assistant->on_announce(msg); } } @@ -1243,35 +1230,29 @@ void APIConnection::on_voice_assistant_announce_request(const VoiceAssistantAnno VoiceAssistantConfigurationResponse APIConnection::voice_assistant_get_configuration( const VoiceAssistantConfigurationRequest &msg) { VoiceAssistantConfigurationResponse resp; - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return resp; - } - - auto &config = voice_assistant::global_voice_assistant->get_configuration(); - for (auto &wake_word : config.available_wake_words) { - VoiceAssistantWakeWord resp_wake_word; - resp_wake_word.id = wake_word.id; - resp_wake_word.wake_word = wake_word.wake_word; - for (const auto &lang : wake_word.trained_languages) { - resp_wake_word.trained_languages.push_back(lang); - } - resp.available_wake_words.push_back(std::move(resp_wake_word)); - } - for (auto &wake_word_id : config.active_wake_words) { - resp.active_wake_words.push_back(wake_word_id); - } - resp.max_active_wake_words = config.max_active_wake_words; + if (!this->check_voice_assistant_api_connection_()) { + return resp; } + + auto &config = voice_assistant::global_voice_assistant->get_configuration(); + for (auto &wake_word : config.available_wake_words) { + VoiceAssistantWakeWord resp_wake_word; + resp_wake_word.id = wake_word.id; + resp_wake_word.wake_word = wake_word.wake_word; + for (const auto &lang : wake_word.trained_languages) { + resp_wake_word.trained_languages.push_back(lang); + } + resp.available_wake_words.push_back(std::move(resp_wake_word)); + } + for (auto &wake_word_id : config.active_wake_words) { + resp.active_wake_words.push_back(wake_word_id); + } + resp.max_active_wake_words = config.max_active_wake_words; return resp; } void APIConnection::voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) { - if (voice_assistant::global_voice_assistant != nullptr) { - if (voice_assistant::global_voice_assistant->get_api_connection() != this) { - return; - } - + if (this->check_voice_assistant_api_connection_()) { voice_assistant::global_voice_assistant->on_set_configuration(msg.active_wake_words); } } diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 166dbc3656..aa323d339d 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -301,6 +301,11 @@ class APIConnection : public APIServerConnection { static uint16_t encode_message_to_buffer(ProtoMessage &msg, uint16_t message_type, APIConnection *conn, uint32_t remaining_size, bool is_single); +#ifdef USE_VOICE_ASSISTANT + // Helper to check voice assistant validity and connection ownership + inline bool check_voice_assistant_api_connection_() const; +#endif + // Helper method to process multiple entities from an iterator in a batch template void process_iterator_batch_(Iterator &iterator) { size_t initial_size = this->deferred_batch_.size();