diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index b95b9361624..15f2e6d8322 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -30,11 +30,16 @@ DATA_AGENT = "conversation_agent" DATA_CONFIG = "conversation_config" SERVICE_PROCESS = "process" +SERVICE_RELOAD = "reload" SERVICE_PROCESS_SCHEMA = vol.Schema( {vol.Required(ATTR_TEXT): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string} ) + +SERVICE_RELOAD_SCHEMA = vol.Schema({vol.Optional(ATTR_LANGUAGE): cv.string}) + + CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( @@ -62,7 +67,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" hass.data[DATA_CONFIG] = config - async def handle_service(service: core.ServiceCall) -> None: + async def handle_process(service: core.ServiceCall) -> None: """Parse text into commands.""" text = service.data[ATTR_TEXT] _LOGGER.debug("Processing: <%s>", text) @@ -74,11 +79,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: except intent.IntentHandleError as err: _LOGGER.error("Error processing %s: %s", text, err) + async def handle_reload(service: core.ServiceCall) -> None: + """Reload intents.""" + agent = await _get_agent(hass) + await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) + hass.services.async_register( - DOMAIN, SERVICE_PROCESS, handle_service, schema=SERVICE_PROCESS_SCHEMA + DOMAIN, SERVICE_PROCESS, handle_process, schema=SERVICE_PROCESS_SCHEMA + ) + hass.services.async_register( + DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA ) hass.http.register_view(ConversationProcessView()) websocket_api.async_register_command(hass, websocket_process) + websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_get_agent_info) websocket_api.async_register_command(hass, websocket_set_onboarding) @@ -110,6 +124,24 @@ async def websocket_process( connection.send_result(msg["id"], result.as_dict()) +@websocket_api.websocket_command( + { + "type": "conversation/prepare", + vol.Optional("language"): str, + } +) +@websocket_api.async_response +async def websocket_prepare( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Reload intents.""" + agent = await _get_agent(hass) + await agent.async_prepare(msg.get("language")) + connection.send_result(msg["id"]) + + @websocket_api.websocket_command({"type": "conversation/agent/info"}) @websocket_api.async_response async def websocket_get_agent_info( diff --git a/homeassistant/components/conversation/agent.py b/homeassistant/components/conversation/agent.py index 0bd3f018589..889412996aa 100644 --- a/homeassistant/components/conversation/agent.py +++ b/homeassistant/components/conversation/agent.py @@ -49,3 +49,9 @@ class AbstractConversationAgent(ABC): language: str | None = None, ) -> ConversationResult | None: """Process a sentence.""" + + async def async_reload(self, language: str | None = None): + """Clear cached intents for a language.""" + + async def async_prepare(self, language: str | None = None): + """Load intents for a language.""" diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 94af03a9e90..a87d6606db9 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -1,6 +1,8 @@ """Standard conversation implementation for Home Assistant.""" from __future__ import annotations +import asyncio +from collections import defaultdict from dataclasses import dataclass import logging from pathlib import Path @@ -58,6 +60,7 @@ class DefaultAgent(AbstractConversationAgent): """Initialize the default agent.""" self.hass = hass self._lang_intents: dict[str, LanguageIntents] = {} + self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) async def async_initialize(self, config): """Initialize the default agent.""" @@ -88,10 +91,7 @@ class DefaultAgent(AbstractConversationAgent): lang_intents.loaded_components - self.hass.config.components ): # Load intents in executor - lang_intents = await self.hass.async_add_executor_job( - self.get_or_load_intents, - language, - ) + lang_intents = await self.async_get_or_load_intents(language) if lang_intents is None: # No intents loaded @@ -121,8 +121,35 @@ class DefaultAgent(AbstractConversationAgent): response=intent_response, conversation_id=conversation_id ) - def get_or_load_intents(self, language: str) -> LanguageIntents | None: - """Load all intents for language.""" + async def async_reload(self, language: str | None = None): + """Clear cached intents for a language.""" + if language is None: + language = self.hass.config.language + + self._lang_intents.pop(language, None) + _LOGGER.debug("Cleared intents for language: %s", language) + + async def async_prepare(self, language: str | None = None): + """Load intents for a language.""" + if language is None: + language = self.hass.config.language + + lang_intents = await self.async_get_or_load_intents(language) + + if lang_intents is None: + # No intents loaded + _LOGGER.warning("No intents were loaded for language: %s", language) + + async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None: + """Load all intents of a language with lock.""" + async with self._lang_lock[language]: + return await self.hass.async_add_executor_job( + self._get_or_load_intents, + language, + ) + + def _get_or_load_intents(self, language: str) -> LanguageIntents | None: + """Load all intents for language (run inside executor).""" lang_intents = self._lang_intents.get(language) if lang_intents is None: diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 22ea6208214..88ca0a078f6 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -329,6 +329,34 @@ async def test_ws_api(hass, hass_ws_client, payload): } +# pylint: disable=protected-access +async def test_ws_prepare(hass, hass_ws_client): + """Test the Websocket prepare conversation API.""" + assert await async_setup_component(hass, "conversation", {}) + agent = await conversation._get_agent(hass) + assert isinstance(agent, conversation.DefaultAgent) + + # No intents should be loaded yet + assert not agent._lang_intents.get(hass.config.language) + + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "conversation/prepare", + } + ) + + msg = await client.receive_json() + + assert msg["success"] + assert msg["id"] == 5 + + # Intents should now be load + assert agent._lang_intents.get(hass.config.language) + + async def test_custom_sentences(hass, hass_client, hass_admin_user): """Test custom sentences with a custom intent.""" assert await async_setup_component(hass, "homeassistant", {}) @@ -367,3 +395,39 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user): }, "conversation_id": None, } + + +# pylint: disable=protected-access +async def test_prepare_reload(hass): + """Test calling the reload service.""" + language = hass.config.language + assert await async_setup_component(hass, "conversation", {}) + + # Load intents + agent = await conversation._get_agent(hass) + assert isinstance(agent, conversation.DefaultAgent) + await agent.async_prepare(language) + + # Confirm intents are loaded + assert agent._lang_intents.get(language) + + # Clear cache + await hass.services.async_call("conversation", "reload", {}) + await hass.async_block_till_done() + + # Confirm intent cache is cleared + assert not agent._lang_intents.get(language) + + +# pylint: disable=protected-access +async def test_prepare_fail(hass): + """Test calling prepare with a non-existent language.""" + assert await async_setup_component(hass, "conversation", {}) + + # Load intents + agent = await conversation._get_agent(hass) + assert isinstance(agent, conversation.DefaultAgent) + await agent.async_prepare("not-a-language") + + # Confirm no intents were loaded + assert not agent._lang_intents.get("not-a-language")