mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Add conversation reload service (#86175)
* Add preload and reload service calls to conversation * Add conversation preload/reload to websocket API * Merge prepare into reload * reload service and prepare websocket API * Add preload and reload service calls to conversation * Add conversation preload/reload to websocket API * Merge prepare into reload * reload service and prepare websocket API * Add language lock for loading intents * Add more tests for code coverage
This commit is contained in:
parent
ca885f3fab
commit
2f98485ae7
@ -30,11 +30,16 @@ DATA_AGENT = "conversation_agent"
|
||||
DATA_CONFIG = "conversation_config"
|
||||
|
||||
SERVICE_PROCESS = "process"
|
||||
SERVICE_RELOAD = "reload"
|
||||
|
||||
SERVICE_PROCESS_SCHEMA = vol.Schema(
|
||||
{vol.Required(ATTR_TEXT): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string}
|
||||
)
|
||||
|
||||
|
||||
SERVICE_RELOAD_SCHEMA = vol.Schema({vol.Optional(ATTR_LANGUAGE): cv.string})
|
||||
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
@ -62,7 +67,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
hass.data[DATA_CONFIG] = config
|
||||
|
||||
async def handle_service(service: core.ServiceCall) -> None:
|
||||
async def handle_process(service: core.ServiceCall) -> None:
|
||||
"""Parse text into commands."""
|
||||
text = service.data[ATTR_TEXT]
|
||||
_LOGGER.debug("Processing: <%s>", text)
|
||||
@ -74,11 +79,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
except intent.IntentHandleError as err:
|
||||
_LOGGER.error("Error processing %s: %s", text, err)
|
||||
|
||||
async def handle_reload(service: core.ServiceCall) -> None:
|
||||
"""Reload intents."""
|
||||
agent = await _get_agent(hass)
|
||||
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_PROCESS, handle_service, schema=SERVICE_PROCESS_SCHEMA
|
||||
DOMAIN, SERVICE_PROCESS, handle_process, schema=SERVICE_PROCESS_SCHEMA
|
||||
)
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA
|
||||
)
|
||||
hass.http.register_view(ConversationProcessView())
|
||||
websocket_api.async_register_command(hass, websocket_process)
|
||||
websocket_api.async_register_command(hass, websocket_prepare)
|
||||
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
||||
websocket_api.async_register_command(hass, websocket_set_onboarding)
|
||||
|
||||
@ -110,6 +124,24 @@ async def websocket_process(
|
||||
connection.send_result(msg["id"], result.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
"type": "conversation/prepare",
|
||||
vol.Optional("language"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_prepare(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Reload intents."""
|
||||
agent = await _get_agent(hass)
|
||||
await agent.async_prepare(msg.get("language"))
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.websocket_command({"type": "conversation/agent/info"})
|
||||
@websocket_api.async_response
|
||||
async def websocket_get_agent_info(
|
||||
|
@ -49,3 +49,9 @@ class AbstractConversationAgent(ABC):
|
||||
language: str | None = None,
|
||||
) -> ConversationResult | None:
|
||||
"""Process a sentence."""
|
||||
|
||||
async def async_reload(self, language: str | None = None):
|
||||
"""Clear cached intents for a language."""
|
||||
|
||||
async def async_prepare(self, language: str | None = None):
|
||||
"""Load intents for a language."""
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Standard conversation implementation for Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -58,6 +60,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
"""Initialize the default agent."""
|
||||
self.hass = hass
|
||||
self._lang_intents: dict[str, LanguageIntents] = {}
|
||||
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
|
||||
async def async_initialize(self, config):
|
||||
"""Initialize the default agent."""
|
||||
@ -88,10 +91,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
lang_intents.loaded_components - self.hass.config.components
|
||||
):
|
||||
# Load intents in executor
|
||||
lang_intents = await self.hass.async_add_executor_job(
|
||||
self.get_or_load_intents,
|
||||
language,
|
||||
)
|
||||
lang_intents = await self.async_get_or_load_intents(language)
|
||||
|
||||
if lang_intents is None:
|
||||
# No intents loaded
|
||||
@ -121,8 +121,35 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def get_or_load_intents(self, language: str) -> LanguageIntents | None:
|
||||
"""Load all intents for language."""
|
||||
async def async_reload(self, language: str | None = None):
|
||||
"""Clear cached intents for a language."""
|
||||
if language is None:
|
||||
language = self.hass.config.language
|
||||
|
||||
self._lang_intents.pop(language, None)
|
||||
_LOGGER.debug("Cleared intents for language: %s", language)
|
||||
|
||||
async def async_prepare(self, language: str | None = None):
|
||||
"""Load intents for a language."""
|
||||
if language is None:
|
||||
language = self.hass.config.language
|
||||
|
||||
lang_intents = await self.async_get_or_load_intents(language)
|
||||
|
||||
if lang_intents is None:
|
||||
# No intents loaded
|
||||
_LOGGER.warning("No intents were loaded for language: %s", language)
|
||||
|
||||
async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None:
|
||||
"""Load all intents of a language with lock."""
|
||||
async with self._lang_lock[language]:
|
||||
return await self.hass.async_add_executor_job(
|
||||
self._get_or_load_intents,
|
||||
language,
|
||||
)
|
||||
|
||||
def _get_or_load_intents(self, language: str) -> LanguageIntents | None:
|
||||
"""Load all intents for language (run inside executor)."""
|
||||
lang_intents = self._lang_intents.get(language)
|
||||
|
||||
if lang_intents is None:
|
||||
|
@ -329,6 +329,34 @@ async def test_ws_api(hass, hass_ws_client, payload):
|
||||
}
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
async def test_ws_prepare(hass, hass_ws_client):
|
||||
"""Test the Websocket prepare conversation API."""
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
agent = await conversation._get_agent(hass)
|
||||
assert isinstance(agent, conversation.DefaultAgent)
|
||||
|
||||
# No intents should be loaded yet
|
||||
assert not agent._lang_intents.get(hass.config.language)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "conversation/prepare",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["id"] == 5
|
||||
|
||||
# Intents should now be load
|
||||
assert agent._lang_intents.get(hass.config.language)
|
||||
|
||||
|
||||
async def test_custom_sentences(hass, hass_client, hass_admin_user):
|
||||
"""Test custom sentences with a custom intent."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
@ -367,3 +395,39 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
|
||||
},
|
||||
"conversation_id": None,
|
||||
}
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
async def test_prepare_reload(hass):
|
||||
"""Test calling the reload service."""
|
||||
language = hass.config.language
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
# Load intents
|
||||
agent = await conversation._get_agent(hass)
|
||||
assert isinstance(agent, conversation.DefaultAgent)
|
||||
await agent.async_prepare(language)
|
||||
|
||||
# Confirm intents are loaded
|
||||
assert agent._lang_intents.get(language)
|
||||
|
||||
# Clear cache
|
||||
await hass.services.async_call("conversation", "reload", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Confirm intent cache is cleared
|
||||
assert not agent._lang_intents.get(language)
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
async def test_prepare_fail(hass):
|
||||
"""Test calling prepare with a non-existent language."""
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
# Load intents
|
||||
agent = await conversation._get_agent(hass)
|
||||
assert isinstance(agent, conversation.DefaultAgent)
|
||||
await agent.async_prepare("not-a-language")
|
||||
|
||||
# Confirm no intents were loaded
|
||||
assert not agent._lang_intents.get("not-a-language")
|
||||
|
Loading…
x
Reference in New Issue
Block a user