mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Ensure we prepare conversation pipeline when speech-to-text starts (#114665)
* Ensure we prepare conversation pipeline when speech-to-text starts * Add lock around recognize * Update homeassistant/components/conversation/default_agent.py * Add lock around load intents
This commit is contained in:
parent
4498bf9ec4
commit
01861cd240
@ -865,6 +865,15 @@ class PipelineRun:
|
||||
stream: AsyncIterable[ProcessedAudioChunk],
|
||||
) -> str:
|
||||
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
||||
# Create a background task to prepare the conversation agent
|
||||
if self.end_stage >= PipelineStage.INTENT:
|
||||
self.hass.async_create_background_task(
|
||||
conversation.async_prepare_agent(
|
||||
self.hass, self.intent_agent, self.language
|
||||
),
|
||||
f"prepare conversation agent {self.intent_agent}",
|
||||
)
|
||||
|
||||
if isinstance(self.stt_provider, stt.Provider):
|
||||
engine = self.stt_provider.name
|
||||
else:
|
||||
@ -967,8 +976,6 @@ class PipelineRun:
|
||||
"""Prepare recognizing an intent."""
|
||||
agent_info = conversation.async_get_agent_info(
|
||||
self.hass,
|
||||
# If no conversation engine is set, use the Home Assistant agent
|
||||
# (the conversation integration default is currently the last one set)
|
||||
self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
|
||||
)
|
||||
|
||||
|
@ -191,6 +191,18 @@ def async_get_agent_info(
|
||||
return None
|
||||
|
||||
|
||||
async def async_prepare_agent(
|
||||
hass: HomeAssistant, agent_id: str | None, language: str
|
||||
) -> None:
|
||||
"""Prepare given agent."""
|
||||
agent = async_get_agent(hass, agent_id)
|
||||
|
||||
if agent is None:
|
||||
raise ValueError("Invalid agent specified")
|
||||
|
||||
await agent.async_prepare(language)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
entity_component: EntityComponent[ConversationEntity] = EntityComponent(
|
||||
|
@ -165,6 +165,7 @@ class DefaultAgent(ConversationEntity):
|
||||
self._trigger_sentences: list[TriggerData] = []
|
||||
self._trigger_intents: Intents | None = None
|
||||
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
|
||||
self._load_intents_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
@ -636,22 +637,33 @@ class DefaultAgent(ConversationEntity):
|
||||
else cast(LanguageIntents, lang_intents)
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
async with self._load_intents_lock:
|
||||
# In case it was loaded now
|
||||
if lang_intents := self._lang_intents.get(language):
|
||||
return (
|
||||
None
|
||||
if lang_intents is ERROR_SENTINEL
|
||||
else cast(LanguageIntents, lang_intents)
|
||||
)
|
||||
|
||||
result = await self.hass.async_add_executor_job(self._load_intents, language)
|
||||
start = time.monotonic()
|
||||
|
||||
if result is None:
|
||||
self._lang_intents[language] = ERROR_SENTINEL
|
||||
else:
|
||||
self._lang_intents[language] = result
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._load_intents, language
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Full intents load completed for language=%s in %.2f seconds",
|
||||
language,
|
||||
time.monotonic() - start,
|
||||
)
|
||||
if result is None:
|
||||
self._lang_intents[language] = ERROR_SENTINEL
|
||||
else:
|
||||
self._lang_intents[language] = result
|
||||
|
||||
return result
|
||||
_LOGGER.debug(
|
||||
"Full intents load completed for language=%s in %.2f seconds",
|
||||
language,
|
||||
time.monotonic() - start,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _load_intents(self, language: str) -> LanguageIntents | None:
|
||||
"""Load all intents for language (run inside executor)."""
|
||||
|
@ -213,3 +213,18 @@ async def test_get_agent_info(
|
||||
|
||||
agent_info = conversation.async_get_agent_info(hass)
|
||||
assert agent_info == snapshot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||
async def test_prepare_agent(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Test prepare agent."""
|
||||
with patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_prepare"
|
||||
) as mock_prepare:
|
||||
await conversation.async_prepare_agent(hass, agent_id, "en")
|
||||
|
||||
assert len(mock_prepare.mock_calls) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user