From 01861cd240ce626858072644c9945ce9dfe9070f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 9 Jul 2024 16:58:52 +0200 Subject: [PATCH] Ensure we prepare conversation pipeline when speech-to-text starts (#114665) * Ensure we prepare conversation pipeline when speech-to-text starts * Add lock around recognize * Update homeassistant/components/conversation/default_agent.py * Add lock around load intents --- .../components/assist_pipeline/pipeline.py | 11 ++++-- .../components/conversation/__init__.py | 12 +++++++ .../components/conversation/default_agent.py | 36 ++++++++++++------- tests/components/conversation/test_init.py | 15 ++++++++ 4 files changed, 60 insertions(+), 14 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index ce6f3e8d024..d8fd15900b8 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -865,6 +865,15 @@ class PipelineRun: stream: AsyncIterable[ProcessedAudioChunk], ) -> str: """Run speech-to-text portion of pipeline. Returns the spoken text.""" + # Create a background task to prepare the conversation agent + if self.end_stage >= PipelineStage.INTENT: + self.hass.async_create_background_task( + conversation.async_prepare_agent( + self.hass, self.intent_agent, self.language + ), + f"prepare conversation agent {self.intent_agent}", + ) + if isinstance(self.stt_provider, stt.Provider): engine = self.stt_provider.name else: @@ -967,8 +976,6 @@ class PipelineRun: """Prepare recognizing an intent.""" agent_info = conversation.async_get_agent_info( self.hass, - # If no conversation engine is set, use the Home Assistant agent - # (the conversation integration default is currently the last one set) self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT, ) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 40b0cc54e99..a7b163d69bd 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -191,6 +191,18 @@ def async_get_agent_info( return None +async def async_prepare_agent( + hass: HomeAssistant, agent_id: str | None, language: str +) -> None: + """Prepare given agent.""" + agent = async_get_agent(hass, agent_id) + + if agent is None: + raise ValueError("Invalid agent specified") + + await agent.async_prepare(language) + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" entity_component: EntityComponent[ConversationEntity] = EntityComponent( diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 85343bb28ff..47a1755f8c1 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -165,6 +165,7 @@ class DefaultAgent(ConversationEntity): self._trigger_sentences: list[TriggerData] = [] self._trigger_intents: Intents | None = None self._unsub_clear_slot_list: list[Callable[[], None]] | None = None + self._load_intents_lock = asyncio.Lock() @property def supported_languages(self) -> list[str]: @@ -636,22 +637,33 @@ class DefaultAgent(ConversationEntity): else cast(LanguageIntents, lang_intents) ) - start = time.monotonic() + async with self._load_intents_lock: + # In case it was loaded now + if lang_intents := self._lang_intents.get(language): + return ( + None + if lang_intents is ERROR_SENTINEL + else cast(LanguageIntents, lang_intents) + ) - result = await self.hass.async_add_executor_job(self._load_intents, language) + start = time.monotonic() - if result is None: - self._lang_intents[language] = ERROR_SENTINEL - else: - self._lang_intents[language] = result + result = await self.hass.async_add_executor_job( + self._load_intents, language + ) - _LOGGER.debug( - "Full intents load completed for language=%s in %.2f seconds", - language, - time.monotonic() - start, - ) + if result is None: + self._lang_intents[language] = ERROR_SENTINEL + else: + self._lang_intents[language] = result - return result + _LOGGER.debug( + "Full intents load completed for language=%s in %.2f seconds", + language, + time.monotonic() - start, + ) + + return result def _load_intents(self, language: str) -> LanguageIntents | None: """Load all intents for language (run inside executor).""" diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 14826401605..34a8fce636d 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -213,3 +213,18 @@ async def test_get_agent_info( agent_info = conversation.async_get_agent_info(hass) assert agent_info == snapshot + + +@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) +async def test_prepare_agent( + hass: HomeAssistant, + init_components, + agent_id: str, +) -> None: + """Test prepare agent.""" + with patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_prepare" + ) as mock_prepare: + await conversation.async_prepare_agent(hass, agent_id, "en") + + assert len(mock_prepare.mock_calls) == 1