Add conversation reload service (#86175)

* Add preload and reload service calls to conversation

* Add conversation preload/reload to websocket API

* Merge prepare into reload

* reload service and prepare websocket API

* Add preload and reload service calls to conversation

* Add conversation preload/reload to websocket API

* Merge prepare into reload

* reload service and prepare websocket API

* Add language lock for loading intents

* Add more tests for code coverage
This commit is contained in:
Michael Hansen 2023-01-18 19:36:51 -06:00 committed by GitHub
parent ca885f3fab
commit 2f98485ae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 8 deletions

View File

@ -30,11 +30,16 @@ DATA_AGENT = "conversation_agent"
DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload"
SERVICE_PROCESS_SCHEMA = vol.Schema(
{vol.Required(ATTR_TEXT): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string}
)
SERVICE_RELOAD_SCHEMA = vol.Schema({vol.Optional(ATTR_LANGUAGE): cv.string})
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
@ -62,7 +67,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
hass.data[DATA_CONFIG] = config
async def handle_service(service: core.ServiceCall) -> None:
async def handle_process(service: core.ServiceCall) -> None:
"""Parse text into commands."""
text = service.data[ATTR_TEXT]
_LOGGER.debug("Processing: <%s>", text)
@ -74,11 +79,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err)
async def handle_reload(service: core.ServiceCall) -> None:
"""Reload intents."""
agent = await _get_agent(hass)
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
hass.services.async_register(
DOMAIN, SERVICE_PROCESS, handle_service, schema=SERVICE_PROCESS_SCHEMA
DOMAIN, SERVICE_PROCESS, handle_process, schema=SERVICE_PROCESS_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA
)
hass.http.register_view(ConversationProcessView())
websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_get_agent_info)
websocket_api.async_register_command(hass, websocket_set_onboarding)
@ -110,6 +124,24 @@ async def websocket_process(
connection.send_result(msg["id"], result.as_dict())
@websocket_api.websocket_command(
{
"type": "conversation/prepare",
vol.Optional("language"): str,
}
)
@websocket_api.async_response
async def websocket_prepare(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Reload intents."""
agent = await _get_agent(hass)
await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"])
@websocket_api.websocket_command({"type": "conversation/agent/info"})
@websocket_api.async_response
async def websocket_get_agent_info(

View File

@ -49,3 +49,9 @@ class AbstractConversationAgent(ABC):
language: str | None = None,
) -> ConversationResult | None:
"""Process a sentence."""
async def async_reload(self, language: str | None = None):
"""Clear cached intents for a language."""
async def async_prepare(self, language: str | None = None):
"""Load intents for a language."""

View File

@ -1,6 +1,8 @@
"""Standard conversation implementation for Home Assistant."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from dataclasses import dataclass
import logging
from pathlib import Path
@ -58,6 +60,7 @@ class DefaultAgent(AbstractConversationAgent):
"""Initialize the default agent."""
self.hass = hass
self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
async def async_initialize(self, config):
"""Initialize the default agent."""
@ -88,10 +91,7 @@ class DefaultAgent(AbstractConversationAgent):
lang_intents.loaded_components - self.hass.config.components
):
# Load intents in executor
lang_intents = await self.hass.async_add_executor_job(
self.get_or_load_intents,
language,
)
lang_intents = await self.async_get_or_load_intents(language)
if lang_intents is None:
# No intents loaded
@ -121,8 +121,35 @@ class DefaultAgent(AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id
)
def get_or_load_intents(self, language: str) -> LanguageIntents | None:
"""Load all intents for language."""
async def async_reload(self, language: str | None = None):
"""Clear cached intents for a language."""
if language is None:
language = self.hass.config.language
self._lang_intents.pop(language, None)
_LOGGER.debug("Cleared intents for language: %s", language)
async def async_prepare(self, language: str | None = None):
"""Load intents for a language."""
if language is None:
language = self.hass.config.language
lang_intents = await self.async_get_or_load_intents(language)
if lang_intents is None:
# No intents loaded
_LOGGER.warning("No intents were loaded for language: %s", language)
async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None:
"""Load all intents of a language with lock."""
async with self._lang_lock[language]:
return await self.hass.async_add_executor_job(
self._get_or_load_intents,
language,
)
def _get_or_load_intents(self, language: str) -> LanguageIntents | None:
"""Load all intents for language (run inside executor)."""
lang_intents = self._lang_intents.get(language)
if lang_intents is None:

View File

@ -329,6 +329,34 @@ async def test_ws_api(hass, hass_ws_client, payload):
}
# pylint: disable=protected-access
async def test_ws_prepare(hass, hass_ws_client):
"""Test the Websocket prepare conversation API."""
assert await async_setup_component(hass, "conversation", {})
agent = await conversation._get_agent(hass)
assert isinstance(agent, conversation.DefaultAgent)
# No intents should be loaded yet
assert not agent._lang_intents.get(hass.config.language)
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "conversation/prepare",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["id"] == 5
# Intents should now be load
assert agent._lang_intents.get(hass.config.language)
async def test_custom_sentences(hass, hass_client, hass_admin_user):
"""Test custom sentences with a custom intent."""
assert await async_setup_component(hass, "homeassistant", {})
@ -367,3 +395,39 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
},
"conversation_id": None,
}
# pylint: disable=protected-access
async def test_prepare_reload(hass):
"""Test calling the reload service."""
language = hass.config.language
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await conversation._get_agent(hass)
assert isinstance(agent, conversation.DefaultAgent)
await agent.async_prepare(language)
# Confirm intents are loaded
assert agent._lang_intents.get(language)
# Clear cache
await hass.services.async_call("conversation", "reload", {})
await hass.async_block_till_done()
# Confirm intent cache is cleared
assert not agent._lang_intents.get(language)
# pylint: disable=protected-access
async def test_prepare_fail(hass):
"""Test calling prepare with a non-existent language."""
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await conversation._get_agent(hass)
assert isinstance(agent, conversation.DefaultAgent)
await agent.async_prepare("not-a-language")
# Confirm no intents were loaded
assert not agent._lang_intents.get("not-a-language")