diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index a9356ab8b7e..e2e00a2652a 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -1,6 +1,7 @@ """Support for functionality to have conversations with Home Assistant.""" from __future__ import annotations +import asyncio import logging import re from typing import Any @@ -12,7 +13,7 @@ from homeassistant.components import http, websocket_api from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_validation as cv, intent +from homeassistant.helpers import config_validation as cv, intent, singleton from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass @@ -23,20 +24,31 @@ _LOGGER = logging.getLogger(__name__) ATTR_TEXT = "text" ATTR_LANGUAGE = "language" +ATTR_AGENT_ID = "agent_id" DOMAIN = "conversation" REGEX_TYPE = type(re.compile("")) -DATA_AGENT = "conversation_agent" DATA_CONFIG = "conversation_config" SERVICE_PROCESS = "process" SERVICE_RELOAD = "reload" + +def agent_id_validator(value: Any) -> str: + """Validate agent ID.""" + hass = core.async_get_hass() + manager = _get_agent_manager(hass) + if not manager.async_is_valid_agent_id(cv.string(value)): + raise vol.Invalid("invalid agent ID") + return value + + SERVICE_PROCESS_SCHEMA = vol.Schema( { vol.Required(ATTR_TEXT): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string, + vol.Optional(ATTR_AGENT_ID): agent_id_validator, } ) @@ -44,6 +56,7 @@ SERVICE_PROCESS_SCHEMA = vol.Schema( SERVICE_RELOAD_SCHEMA = vol.Schema( { vol.Optional(ATTR_LANGUAGE): cv.string, + vol.Optional(ATTR_AGENT_ID): agent_id_validator, } ) @@ -61,6 +74,13 @@ CONFIG_SCHEMA = vol.Schema( ) +@singleton.singleton("conversation_agent") +@core.callback +def _get_agent_manager(hass: HomeAssistant) -> AgentManager: + """Get the active agent.""" + return AgentManager(hass) + + @core.callback @bind_hass def async_set_agent( @@ -69,7 +89,7 @@ def async_set_agent( agent: AbstractConversationAgent, ): """Set the agent to handle the conversations.""" - hass.data[DATA_AGENT] = agent + _get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent) @core.callback @@ -79,11 +99,13 @@ def async_unset_agent( config_entry: ConfigEntry, ): """Set the agent to handle the conversations.""" - hass.data[DATA_AGENT] = None + _get_agent_manager(hass).async_unset_agent(config_entry.entry_id) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" + agent_manager = _get_agent_manager(hass) + if config_intents := config.get(DOMAIN, {}).get("intents"): hass.data[DATA_CONFIG] = config_intents @@ -91,22 +113,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Parse text into commands.""" text = service.data[ATTR_TEXT] _LOGGER.debug("Processing: <%s>", text) - agent = await _get_agent(hass) try: - await agent.async_process( - ConversationInput( - text=text, - context=service.context, - conversation_id=None, - language=service.data.get(ATTR_LANGUAGE, hass.config.language), - ) + await async_converse( + hass=hass, + text=text, + conversation_id=None, + context=service.context, + language=service.data.get(ATTR_LANGUAGE), + agent_id=service.data.get(ATTR_AGENT_ID), ) except intent.IntentHandleError as err: _LOGGER.error("Error processing %s: %s", text, err) async def handle_reload(service: core.ServiceCall) -> None: """Reload intents.""" - agent = await _get_agent(hass) + agent = await agent_manager.async_get_agent() await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) hass.services.async_register( @@ -119,6 +140,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, websocket_process) websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_get_agent_info) + websocket_api.async_register_command(hass, websocket_list_agents) return True @@ -129,6 +151,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: vol.Required("text"): str, vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, } ) @websocket_api.async_response @@ -139,11 +162,12 @@ async def websocket_process( ) -> None: """Process text.""" result = await async_converse( - hass, - msg["text"], - msg.get("conversation_id"), - connection.context(msg), - msg.get("language"), + hass=hass, + text=msg["text"], + conversation_id=msg.get("conversation_id"), + context=connection.context(msg), + language=msg.get("language"), + agent_id=msg.get("agent_id"), ) connection.send_result(msg["id"], result.as_dict()) @@ -152,6 +176,7 @@ async def websocket_process( { "type": "conversation/prepare", vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, } ) @websocket_api.async_response @@ -161,7 +186,8 @@ async def websocket_prepare( msg: dict[str, Any], ) -> None: """Reload intents.""" - agent = await _get_agent(hass) + manager = _get_agent_manager(hass) + agent = await manager.async_get_agent(msg.get("agent_id")) await agent.async_prepare(msg.get("language")) connection.send_result(msg["id"]) @@ -169,6 +195,7 @@ async def websocket_prepare( @websocket_api.websocket_command( { vol.Required("type"): "conversation/agent/info", + vol.Optional("agent_id"): agent_id_validator, } ) @websocket_api.async_response @@ -178,7 +205,7 @@ async def websocket_get_agent_info( msg: dict[str, Any], ) -> None: """Info about the agent in use.""" - agent = await _get_agent(hass) + agent = await _get_agent_manager(hass).async_get_agent(msg.get("agent_id")) connection.send_result( msg["id"], @@ -188,6 +215,29 @@ async def websocket_get_agent_info( ) +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/agent/list", + } +) +@core.callback +def websocket_list_agents( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """List available agents.""" + manager = _get_agent_manager(hass) + + connection.send_result( + msg["id"], + { + "default_agent": manager.default_agent, + "agents": manager.async_get_agent_info(), + }, + ) + + class ConversationProcessView(http.HomeAssistantView): """View to process text.""" @@ -200,43 +250,41 @@ class ConversationProcessView(http.HomeAssistantView): vol.Required("text"): str, vol.Optional("conversation_id"): str, vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, } ) ) async def post(self, request, data): """Send a request for processing.""" hass = request.app["hass"] + result = await async_converse( hass, text=data["text"], conversation_id=data.get("conversation_id"), context=self.context(request), language=data.get("language"), + agent_id=data.get("agent_id"), ) return self.json(result.as_dict()) -async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent: - """Get the active conversation agent.""" - if (agent := hass.data.get(DATA_AGENT)) is None: - agent = hass.data[DATA_AGENT] = DefaultAgent(hass) - await agent.async_initialize(hass.data.get(DATA_CONFIG)) - return agent - - async def async_converse( hass: core.HomeAssistant, text: str, conversation_id: str | None, context: core.Context, language: str | None = None, + agent_id: str | None = None, ) -> ConversationResult: """Process text and get intent.""" - agent = await _get_agent(hass) + agent = await _get_agent_manager(hass).async_get_agent(agent_id) + if language is None: language = hass.config.language + _LOGGER.debug("Processing in %s: %s", language, text) result = await agent.async_process( ConversationInput( text=text, @@ -246,3 +294,88 @@ async def async_converse( ) ) return result + + +class AgentManager: + """Class to manage conversation agents.""" + + HOME_ASSISTANT_AGENT = "homeassistant" + + default_agent: str = HOME_ASSISTANT_AGENT + _builtin_agent: AbstractConversationAgent | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the conversation agents.""" + self.hass = hass + self._agents: dict[str, AbstractConversationAgent] = {} + self._default_agent_init_lock = asyncio.Lock() + + async def async_get_agent( + self, agent_id: str | None = None + ) -> AbstractConversationAgent: + """Get the agent.""" + if agent_id is None: + agent_id = self.default_agent + + if agent_id == AgentManager.HOME_ASSISTANT_AGENT: + if self._builtin_agent is not None: + return self._builtin_agent + + async with self._default_agent_init_lock: + if self._builtin_agent is not None: + return self._builtin_agent + + self._builtin_agent = DefaultAgent(self.hass) + await self._builtin_agent.async_initialize( + self.hass.data.get(DATA_CONFIG) + ) + + return self._builtin_agent + + return self._agents[agent_id] + + @core.callback + def async_get_agent_info(self) -> list[dict[str, Any]]: + """List all agents.""" + agents = [ + { + "id": AgentManager.HOME_ASSISTANT_AGENT, + "name": "Home Assistant", + } + ] + for agent_id, agent in self._agents.items(): + config_entry = self.hass.config_entries.async_get_entry(agent_id) + + # This is a bug, agent should have been unset when config entry was unloaded + if config_entry is None: + _LOGGER.warning( + "Agent was still loaded while config entry is gone: %s", agent + ) + continue + + agents.append( + { + "id": agent_id, + "name": config_entry.title, + } + ) + return agents + + @core.callback + def async_is_valid_agent_id(self, agent_id: str) -> bool: + """Check if the agent id is valid.""" + return agent_id in self._agents or agent_id == AgentManager.HOME_ASSISTANT_AGENT + + @core.callback + def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: + """Set the agent.""" + self._agents[agent_id] = agent + if self.default_agent == AgentManager.HOME_ASSISTANT_AGENT: + self.default_agent = agent_id + + @core.callback + def async_unset_agent(self, agent_id: str) -> None: + """Unset the agent.""" + if self.default_agent == agent_id: + self.default_agent = AgentManager.HOME_ASSISTANT_AGENT + self._agents.pop(agent_id, None) diff --git a/homeassistant/components/conversation/services.yaml b/homeassistant/components/conversation/services.yaml index edba9ffb0b9..6b031ff7142 100644 --- a/homeassistant/components/conversation/services.yaml +++ b/homeassistant/components/conversation/services.yaml @@ -9,3 +9,15 @@ process: example: Turn all lights on selector: text: + language: + name: Language + description: Language of text. Defaults to server language + example: NL + selector: + text: + agent_id: + name: Agent + description: Assist engine to process your request + example: homeassistant + selector: + text: diff --git a/tests/components/conversation/__init__.py b/tests/components/conversation/__init__.py index 8c5371f8cbe..fac9ae95e61 100644 --- a/tests/components/conversation/__init__.py +++ b/tests/components/conversation/__init__.py @@ -8,8 +8,9 @@ from homeassistant.helpers import intent class MockAgent(conversation.AbstractConversationAgent): """Test Agent.""" - def __init__(self) -> None: + def __init__(self, agent_id: str) -> None: """Initialize the agent.""" + self.agent_id = agent_id self.calls = [] self.response = "Test response" diff --git a/tests/components/conversation/conftest.py b/tests/components/conversation/conftest.py index 35f9937e5a0..46f57dbcab9 100644 --- a/tests/components/conversation/conftest.py +++ b/tests/components/conversation/conftest.py @@ -6,10 +6,14 @@ from homeassistant.components import conversation from . import MockAgent +from tests.common import MockConfigEntry + @pytest.fixture def mock_agent(hass): """Mock agent.""" - agent = MockAgent() - conversation.async_set_agent(hass, None, agent) + entry = MockConfigEntry(entry_id="mock-entry") + entry.add_to_hass(hass) + agent = MockAgent(entry.entry_id) + conversation.async_set_agent(hass, entry, agent) return agent diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 54fed8a6139..f5928ddfa35 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -3,6 +3,7 @@ from http import HTTPStatus from unittest.mock import patch import pytest +import voluptuous as vol from homeassistant.components import conversation from homeassistant.components.cover import SERVICE_OPEN_COVER @@ -18,6 +19,8 @@ from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry, async_mock_service +AGENT_ID_OPTIONS = [None, conversation.AgentManager.HOME_ASSISTANT_AGENT] + class OrderBeerIntentHandler(intent.IntentHandler): """Handle OrderBeer intent.""" @@ -40,8 +43,9 @@ async def init_components(hass): assert await async_setup_component(hass, "intent", {}) +@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) async def test_http_processing_intent( - hass, init_components, hass_client, hass_admin_user + hass, init_components, hass_client, hass_admin_user, agent_id ): """Test processing intent via HTTP API.""" # Add an alias @@ -50,9 +54,52 @@ async def test_http_processing_intent( entities.async_update_entity("light.kitchen", aliases={"my cool light"}) hass.states.async_set("light.kitchen", "off") + client = await hass_client() + data = {"text": "turn on my cool light"} + if agent_id: + data["agent_id"] = agent_id + resp = await client.post("/api/conversation/process", json=data) + + assert resp.status == HTTPStatus.OK + data = await resp.json() + + assert data == { + "response": { + "response_type": "action_done", + "card": {}, + "speech": { + "plain": { + "extra_data": None, + "speech": "Turned on my cool light", + } + }, + "language": hass.config.language, + "data": { + "targets": [], + "success": [ + {"id": "light.kitchen", "name": "kitchen", "type": "entity"} + ], + "failed": [], + }, + }, + "conversation_id": None, + } + + +async def test_http_processing_intent_target_ha_agent( + hass, init_components, hass_client, hass_admin_user, mock_agent +): + """Test processing intent can be processed via HTTP API with picking agent.""" + # Add an alias + entities = entity_registry.async_get(hass) + entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen") + entities.async_update_entity("light.kitchen", aliases={"my cool light"}) + hass.states.async_set("light.kitchen", "off") + client = await hass_client() resp = await client.post( - "/api/conversation/process", json={"text": "turn on my cool light"} + "/api/conversation/process", + json={"text": "turn on my cool light", "agent_id": "homeassistant"}, ) assert resp.status == HTTPStatus.OK @@ -218,15 +265,17 @@ async def test_http_processing_intent_entity_added( } +@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) @pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on")) -async def test_turn_on_intent(hass, init_components, sentence): +async def test_turn_on_intent(hass, init_components, sentence, agent_id): """Test calling the turn on intent.""" hass.states.async_set("light.kitchen", "off") calls = async_mock_service(hass, HASS_DOMAIN, "turn_on") - await hass.services.async_call( - "conversation", "process", {conversation.ATTR_TEXT: sentence} - ) + data = {conversation.ATTR_TEXT: sentence} + if agent_id is not None: + data[conversation.ATTR_AGENT_ID] = agent_id + await hass.services.async_call("conversation", "process", data) await hass.async_block_till_done() assert len(calls) == 1 @@ -254,46 +303,6 @@ async def test_turn_off_intent(hass, init_components, sentence): assert call.data == {"entity_id": "light.kitchen"} -async def test_http_api(hass, init_components, hass_client): - """Test the HTTP conversation API.""" - client = await hass_client() - hass.states.async_set("light.kitchen", "off") - calls = async_mock_service(hass, HASS_DOMAIN, "turn_on") - - resp = await client.post( - "/api/conversation/process", json={"text": "Turn the kitchen on"} - ) - assert resp.status == HTTPStatus.OK - data = await resp.json() - - assert data == { - "response": { - "card": {}, - "speech": {"plain": {"extra_data": None, "speech": "Turned on kitchen"}}, - "language": hass.config.language, - "response_type": "action_done", - "data": { - "targets": [], - "success": [ - { - "type": "entity", - "name": "kitchen", - "id": "light.kitchen", - }, - ], - "failed": [], - }, - }, - "conversation_id": None, - } - - assert len(calls) == 1 - call = calls[0] - assert call.domain == HASS_DOMAIN - assert call.service == "turn_on" - assert call.data == {"entity_id": "light.kitchen"} - - async def test_http_api_no_match(hass, init_components, hass_client): """Test the HTTP conversation API with an intent match failure.""" client = await hass_client() @@ -406,20 +415,22 @@ async def test_http_api_wrong_data(hass, init_components, hass_client): assert resp.status == HTTPStatus.BAD_REQUEST -async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent): +@pytest.mark.parametrize("agent_id", (None, "mock-entry")) +async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent, agent_id): """Test a custom conversation agent.""" assert await async_setup_component(hass, "conversation", {}) client = await hass_client() - resp = await client.post( - "/api/conversation/process", - json={ - "text": "Test Text", - "conversation_id": "test-conv-id", - "language": "test-language", - }, - ) + data = { + "text": "Test Text", + "conversation_id": "test-conv-id", + "language": "test-language", + } + if agent_id is not None: + data["agent_id"] = agent_id + + resp = await client.post("/api/conversation/process", json=data) assert resp.status == HTTPStatus.OK assert await resp.json() == { "response": { @@ -443,6 +454,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent): assert mock_agent.calls[0].conversation_id == "test-conv-id" assert mock_agent.calls[0].language == "test-language" + conversation.async_unset_agent( + hass, hass.config_entries.async_get_entry(mock_agent.agent_id) + ) + @pytest.mark.parametrize( "payload", @@ -467,6 +482,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent): "conversation_id": "test-conv-id", "language": "test-language", }, + { + "text": "Test Text", + "agent_id": "homeassistant", + }, ], ) async def test_ws_api(hass, hass_ws_client, payload): @@ -496,10 +515,11 @@ async def test_ws_api(hass, hass_ws_client, payload): } -async def test_ws_prepare(hass, hass_ws_client): +@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) +async def test_ws_prepare(hass, hass_ws_client, agent_id): """Test the Websocket prepare conversation API.""" assert await async_setup_component(hass, "conversation", {}) - agent = await conversation._get_agent(hass) + agent = await conversation._get_agent_manager(hass).async_get_agent() assert isinstance(agent, conversation.DefaultAgent) # No intents should be loaded yet @@ -507,12 +527,13 @@ async def test_ws_prepare(hass, hass_ws_client): client = await hass_ws_client(hass) - await client.send_json( - { - "id": 5, - "type": "conversation/prepare", - } - ) + msg = { + "id": 5, + "type": "conversation/prepare", + } + if agent_id is not None: + msg["agent_id"] = agent_id + await client.send_json(msg) msg = await client.receive_json() @@ -618,7 +639,7 @@ async def test_prepare_reload(hass): assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await conversation._get_agent(hass) + agent = await conversation._get_agent_manager(hass).async_get_agent() assert isinstance(agent, conversation.DefaultAgent) await agent.async_prepare(language) @@ -638,7 +659,7 @@ async def test_prepare_fail(hass): assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await conversation._get_agent(hass) + agent = await conversation._get_agent_manager(hass).async_get_agent() assert isinstance(agent, conversation.DefaultAgent) await agent.async_prepare("not-a-language") @@ -676,7 +697,7 @@ async def test_reload_on_new_component(hass): assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await conversation._get_agent(hass) + agent = await conversation._get_agent_manager(hass).async_get_agent() assert isinstance(agent, conversation.DefaultAgent) await agent.async_prepare() @@ -700,7 +721,7 @@ async def test_non_default_response(hass, init_components): hass.states.async_set("cover.front_door", "closed") async_mock_service(hass, "cover", SERVICE_OPEN_COVER) - agent = await conversation._get_agent(hass) + agent = await conversation._get_agent_manager(hass).async_get_agent() assert isinstance(agent, conversation.DefaultAgent) result = await agent.async_process( @@ -826,3 +847,30 @@ async def test_light_area_same_name(hass, init_components): assert call.domain == HASS_DOMAIN assert call.service == "turn_on" assert call.data == {"entity_id": kitchen_light.entity_id} + + +async def test_agent_id_validator_invalid_agent(hass): + """Test validating agent id.""" + with pytest.raises(vol.Invalid): + conversation.agent_id_validator("invalid_agent") + + conversation.agent_id_validator(conversation.AgentManager.HOME_ASSISTANT_AGENT) + + +async def test_get_agent_list(hass, init_components, mock_agent, hass_ws_client): + """Test getting agent info.""" + client = await hass_ws_client(hass) + + await client.send_json({"id": 5, "type": "conversation/agent/list"}) + + msg = await client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == "result" + assert msg["success"] + assert msg["result"] == { + "agents": [ + {"id": "homeassistant", "name": "Home Assistant"}, + {"id": "mock-entry", "name": "Mock Title"}, + ], + "default_agent": "mock-entry", + }