Add support for multiple conversation agents (#87337)

* Add support for multiple conversation agents

* Lock initializing default agent

* Allow unsetting agent when never set
This commit is contained in:
Paulus Schoutsen 2023-02-03 23:35:29 -05:00 committed by GitHub
parent 3f992ed31d
commit fc38b4327f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 297 additions and 99 deletions

View File

@ -1,6 +1,7 @@
"""Support for functionality to have conversations with Home Assistant.""" """Support for functionality to have conversations with Home Assistant."""
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import re import re
from typing import Any from typing import Any
@ -12,7 +13,7 @@ from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -23,20 +24,31 @@ _LOGGER = logging.getLogger(__name__)
ATTR_TEXT = "text" ATTR_TEXT = "text"
ATTR_LANGUAGE = "language" ATTR_LANGUAGE = "language"
ATTR_AGENT_ID = "agent_id"
DOMAIN = "conversation" DOMAIN = "conversation"
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
DATA_AGENT = "conversation_agent"
DATA_CONFIG = "conversation_config" DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process" SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload" SERVICE_RELOAD = "reload"
def agent_id_validator(value: Any) -> str:
"""Validate agent ID."""
hass = core.async_get_hass()
manager = _get_agent_manager(hass)
if not manager.async_is_valid_agent_id(cv.string(value)):
raise vol.Invalid("invalid agent ID")
return value
SERVICE_PROCESS_SCHEMA = vol.Schema( SERVICE_PROCESS_SCHEMA = vol.Schema(
{ {
vol.Required(ATTR_TEXT): cv.string, vol.Required(ATTR_TEXT): cv.string,
vol.Optional(ATTR_LANGUAGE): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_AGENT_ID): agent_id_validator,
} }
) )
@ -44,6 +56,7 @@ SERVICE_PROCESS_SCHEMA = vol.Schema(
SERVICE_RELOAD_SCHEMA = vol.Schema( SERVICE_RELOAD_SCHEMA = vol.Schema(
{ {
vol.Optional(ATTR_LANGUAGE): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_AGENT_ID): agent_id_validator,
} }
) )
@ -61,6 +74,13 @@ CONFIG_SCHEMA = vol.Schema(
) )
@singleton.singleton("conversation_agent")
@core.callback
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent."""
return AgentManager(hass)
@core.callback @core.callback
@bind_hass @bind_hass
def async_set_agent( def async_set_agent(
@ -69,7 +89,7 @@ def async_set_agent(
agent: AbstractConversationAgent, agent: AbstractConversationAgent,
): ):
"""Set the agent to handle the conversations.""" """Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = agent _get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
@core.callback @core.callback
@ -79,11 +99,13 @@ def async_unset_agent(
config_entry: ConfigEntry, config_entry: ConfigEntry,
): ):
"""Set the agent to handle the conversations.""" """Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = None _get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
agent_manager = _get_agent_manager(hass)
if config_intents := config.get(DOMAIN, {}).get("intents"): if config_intents := config.get(DOMAIN, {}).get("intents"):
hass.data[DATA_CONFIG] = config_intents hass.data[DATA_CONFIG] = config_intents
@ -91,22 +113,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Parse text into commands.""" """Parse text into commands."""
text = service.data[ATTR_TEXT] text = service.data[ATTR_TEXT]
_LOGGER.debug("Processing: <%s>", text) _LOGGER.debug("Processing: <%s>", text)
agent = await _get_agent(hass)
try: try:
await agent.async_process( await async_converse(
ConversationInput( hass=hass,
text=text, text=text,
context=service.context, conversation_id=None,
conversation_id=None, context=service.context,
language=service.data.get(ATTR_LANGUAGE, hass.config.language), language=service.data.get(ATTR_LANGUAGE),
) agent_id=service.data.get(ATTR_AGENT_ID),
) )
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err) _LOGGER.error("Error processing %s: %s", text, err)
async def handle_reload(service: core.ServiceCall) -> None: async def handle_reload(service: core.ServiceCall) -> None:
"""Reload intents.""" """Reload intents."""
agent = await _get_agent(hass) agent = await agent_manager.async_get_agent()
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
hass.services.async_register( hass.services.async_register(
@ -119,6 +140,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_process) websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_get_agent_info) websocket_api.async_register_command(hass, websocket_get_agent_info)
websocket_api.async_register_command(hass, websocket_list_agents)
return True return True
@ -129,6 +151,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Required("text"): str, vol.Required("text"): str,
vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("language"): str, vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@ -139,11 +162,12 @@ async def websocket_process(
) -> None: ) -> None:
"""Process text.""" """Process text."""
result = await async_converse( result = await async_converse(
hass, hass=hass,
msg["text"], text=msg["text"],
msg.get("conversation_id"), conversation_id=msg.get("conversation_id"),
connection.context(msg), context=connection.context(msg),
msg.get("language"), language=msg.get("language"),
agent_id=msg.get("agent_id"),
) )
connection.send_result(msg["id"], result.as_dict()) connection.send_result(msg["id"], result.as_dict())
@ -152,6 +176,7 @@ async def websocket_process(
{ {
"type": "conversation/prepare", "type": "conversation/prepare",
vol.Optional("language"): str, vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@ -161,7 +186,8 @@ async def websocket_prepare(
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Reload intents.""" """Reload intents."""
agent = await _get_agent(hass) manager = _get_agent_manager(hass)
agent = await manager.async_get_agent(msg.get("agent_id"))
await agent.async_prepare(msg.get("language")) await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"]) connection.send_result(msg["id"])
@ -169,6 +195,7 @@ async def websocket_prepare(
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "conversation/agent/info", vol.Required("type"): "conversation/agent/info",
vol.Optional("agent_id"): agent_id_validator,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@ -178,7 +205,7 @@ async def websocket_get_agent_info(
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Info about the agent in use.""" """Info about the agent in use."""
agent = await _get_agent(hass) agent = await _get_agent_manager(hass).async_get_agent(msg.get("agent_id"))
connection.send_result( connection.send_result(
msg["id"], msg["id"],
@ -188,6 +215,29 @@ async def websocket_get_agent_info(
) )
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/list",
}
)
@core.callback
def websocket_list_agents(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List available agents."""
manager = _get_agent_manager(hass)
connection.send_result(
msg["id"],
{
"default_agent": manager.default_agent,
"agents": manager.async_get_agent_info(),
},
)
class ConversationProcessView(http.HomeAssistantView): class ConversationProcessView(http.HomeAssistantView):
"""View to process text.""" """View to process text."""
@ -200,43 +250,41 @@ class ConversationProcessView(http.HomeAssistantView):
vol.Required("text"): str, vol.Required("text"): str,
vol.Optional("conversation_id"): str, vol.Optional("conversation_id"): str,
vol.Optional("language"): str, vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
} }
) )
) )
async def post(self, request, data): async def post(self, request, data):
"""Send a request for processing.""" """Send a request for processing."""
hass = request.app["hass"] hass = request.app["hass"]
result = await async_converse( result = await async_converse(
hass, hass,
text=data["text"], text=data["text"],
conversation_id=data.get("conversation_id"), conversation_id=data.get("conversation_id"),
context=self.context(request), context=self.context(request),
language=data.get("language"), language=data.get("language"),
agent_id=data.get("agent_id"),
) )
return self.json(result.as_dict()) return self.json(result.as_dict())
async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get the active conversation agent."""
if (agent := hass.data.get(DATA_AGENT)) is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize(hass.data.get(DATA_CONFIG))
return agent
async def async_converse( async def async_converse(
hass: core.HomeAssistant, hass: core.HomeAssistant,
text: str, text: str,
conversation_id: str | None, conversation_id: str | None,
context: core.Context, context: core.Context,
language: str | None = None, language: str | None = None,
agent_id: str | None = None,
) -> ConversationResult: ) -> ConversationResult:
"""Process text and get intent.""" """Process text and get intent."""
agent = await _get_agent(hass) agent = await _get_agent_manager(hass).async_get_agent(agent_id)
if language is None: if language is None:
language = hass.config.language language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
result = await agent.async_process( result = await agent.async_process(
ConversationInput( ConversationInput(
text=text, text=text,
@ -246,3 +294,88 @@ async def async_converse(
) )
) )
return result return result
class AgentManager:
"""Class to manage conversation agents."""
HOME_ASSISTANT_AGENT = "homeassistant"
default_agent: str = HOME_ASSISTANT_AGENT
_builtin_agent: AbstractConversationAgent | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the conversation agents."""
self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {}
self._default_agent_init_lock = asyncio.Lock()
async def async_get_agent(
self, agent_id: str | None = None
) -> AbstractConversationAgent:
"""Get the agent."""
if agent_id is None:
agent_id = self.default_agent
if agent_id == AgentManager.HOME_ASSISTANT_AGENT:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._default_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent
self._builtin_agent = DefaultAgent(self.hass)
await self._builtin_agent.async_initialize(
self.hass.data.get(DATA_CONFIG)
)
return self._builtin_agent
return self._agents[agent_id]
@core.callback
def async_get_agent_info(self) -> list[dict[str, Any]]:
"""List all agents."""
agents = [
{
"id": AgentManager.HOME_ASSISTANT_AGENT,
"name": "Home Assistant",
}
]
for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id)
# This is a bug, agent should have been unset when config entry was unloaded
if config_entry is None:
_LOGGER.warning(
"Agent was still loaded while config entry is gone: %s", agent
)
continue
agents.append(
{
"id": agent_id,
"name": config_entry.title,
}
)
return agents
@core.callback
def async_is_valid_agent_id(self, agent_id: str) -> bool:
"""Check if the agent id is valid."""
return agent_id in self._agents or agent_id == AgentManager.HOME_ASSISTANT_AGENT
@core.callback
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
"""Set the agent."""
self._agents[agent_id] = agent
if self.default_agent == AgentManager.HOME_ASSISTANT_AGENT:
self.default_agent = agent_id
@core.callback
def async_unset_agent(self, agent_id: str) -> None:
"""Unset the agent."""
if self.default_agent == agent_id:
self.default_agent = AgentManager.HOME_ASSISTANT_AGENT
self._agents.pop(agent_id, None)

View File

@ -9,3 +9,15 @@ process:
example: Turn all lights on example: Turn all lights on
selector: selector:
text: text:
language:
name: Language
description: Language of text. Defaults to server language
example: NL
selector:
text:
agent_id:
name: Agent
description: Assist engine to process your request
example: homeassistant
selector:
text:

View File

@ -8,8 +8,9 @@ from homeassistant.helpers import intent
class MockAgent(conversation.AbstractConversationAgent): class MockAgent(conversation.AbstractConversationAgent):
"""Test Agent.""" """Test Agent."""
def __init__(self) -> None: def __init__(self, agent_id: str) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.agent_id = agent_id
self.calls = [] self.calls = []
self.response = "Test response" self.response = "Test response"

View File

@ -6,10 +6,14 @@ from homeassistant.components import conversation
from . import MockAgent from . import MockAgent
from tests.common import MockConfigEntry
@pytest.fixture @pytest.fixture
def mock_agent(hass): def mock_agent(hass):
"""Mock agent.""" """Mock agent."""
agent = MockAgent() entry = MockConfigEntry(entry_id="mock-entry")
conversation.async_set_agent(hass, None, agent) entry.add_to_hass(hass)
agent = MockAgent(entry.entry_id)
conversation.async_set_agent(hass, entry, agent)
return agent return agent

View File

@ -3,6 +3,7 @@ from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
@ -18,6 +19,8 @@ from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, async_mock_service from tests.common import MockConfigEntry, async_mock_service
AGENT_ID_OPTIONS = [None, conversation.AgentManager.HOME_ASSISTANT_AGENT]
class OrderBeerIntentHandler(intent.IntentHandler): class OrderBeerIntentHandler(intent.IntentHandler):
"""Handle OrderBeer intent.""" """Handle OrderBeer intent."""
@ -40,8 +43,9 @@ async def init_components(hass):
assert await async_setup_component(hass, "intent", {}) assert await async_setup_component(hass, "intent", {})
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
async def test_http_processing_intent( async def test_http_processing_intent(
hass, init_components, hass_client, hass_admin_user hass, init_components, hass_client, hass_admin_user, agent_id
): ):
"""Test processing intent via HTTP API.""" """Test processing intent via HTTP API."""
# Add an alias # Add an alias
@ -50,9 +54,52 @@ async def test_http_processing_intent(
entities.async_update_entity("light.kitchen", aliases={"my cool light"}) entities.async_update_entity("light.kitchen", aliases={"my cool light"})
hass.states.async_set("light.kitchen", "off") hass.states.async_set("light.kitchen", "off")
client = await hass_client()
data = {"text": "turn on my cool light"}
if agent_id:
data["agent_id"] = agent_id
resp = await client.post("/api/conversation/process", json=data)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"response_type": "action_done",
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Turned on my cool light",
}
},
"language": hass.config.language,
"data": {
"targets": [],
"success": [
{"id": "light.kitchen", "name": "kitchen", "type": "entity"}
],
"failed": [],
},
},
"conversation_id": None,
}
async def test_http_processing_intent_target_ha_agent(
hass, init_components, hass_client, hass_admin_user, mock_agent
):
"""Test processing intent can be processed via HTTP API with picking agent."""
# Add an alias
entities = entity_registry.async_get(hass)
entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
entities.async_update_entity("light.kitchen", aliases={"my cool light"})
hass.states.async_set("light.kitchen", "off")
client = await hass_client() client = await hass_client()
resp = await client.post( resp = await client.post(
"/api/conversation/process", json={"text": "turn on my cool light"} "/api/conversation/process",
json={"text": "turn on my cool light", "agent_id": "homeassistant"},
) )
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
@ -218,15 +265,17 @@ async def test_http_processing_intent_entity_added(
} }
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on")) @pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
async def test_turn_on_intent(hass, init_components, sentence): async def test_turn_on_intent(hass, init_components, sentence, agent_id):
"""Test calling the turn on intent.""" """Test calling the turn on intent."""
hass.states.async_set("light.kitchen", "off") hass.states.async_set("light.kitchen", "off")
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on") calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
await hass.services.async_call( data = {conversation.ATTR_TEXT: sentence}
"conversation", "process", {conversation.ATTR_TEXT: sentence} if agent_id is not None:
) data[conversation.ATTR_AGENT_ID] = agent_id
await hass.services.async_call("conversation", "process", data)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
@ -254,46 +303,6 @@ async def test_turn_off_intent(hass, init_components, sentence):
assert call.data == {"entity_id": "light.kitchen"} assert call.data == {"entity_id": "light.kitchen"}
async def test_http_api(hass, init_components, hass_client):
"""Test the HTTP conversation API."""
client = await hass_client()
hass.states.async_set("light.kitchen", "off")
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
resp = await client.post(
"/api/conversation/process", json={"text": "Turn the kitchen on"}
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"card": {},
"speech": {"plain": {"extra_data": None, "speech": "Turned on kitchen"}},
"language": hass.config.language,
"response_type": "action_done",
"data": {
"targets": [],
"success": [
{
"type": "entity",
"name": "kitchen",
"id": "light.kitchen",
},
],
"failed": [],
},
},
"conversation_id": None,
}
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.kitchen"}
async def test_http_api_no_match(hass, init_components, hass_client): async def test_http_api_no_match(hass, init_components, hass_client):
"""Test the HTTP conversation API with an intent match failure.""" """Test the HTTP conversation API with an intent match failure."""
client = await hass_client() client = await hass_client()
@ -406,20 +415,22 @@ async def test_http_api_wrong_data(hass, init_components, hass_client):
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent): @pytest.mark.parametrize("agent_id", (None, "mock-entry"))
async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent, agent_id):
"""Test a custom conversation agent.""" """Test a custom conversation agent."""
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
client = await hass_client() client = await hass_client()
resp = await client.post( data = {
"/api/conversation/process", "text": "Test Text",
json={ "conversation_id": "test-conv-id",
"text": "Test Text", "language": "test-language",
"conversation_id": "test-conv-id", }
"language": "test-language", if agent_id is not None:
}, data["agent_id"] = agent_id
)
resp = await client.post("/api/conversation/process", json=data)
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
assert await resp.json() == { assert await resp.json() == {
"response": { "response": {
@ -443,6 +454,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
assert mock_agent.calls[0].conversation_id == "test-conv-id" assert mock_agent.calls[0].conversation_id == "test-conv-id"
assert mock_agent.calls[0].language == "test-language" assert mock_agent.calls[0].language == "test-language"
conversation.async_unset_agent(
hass, hass.config_entries.async_get_entry(mock_agent.agent_id)
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"payload", "payload",
@ -467,6 +482,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
"conversation_id": "test-conv-id", "conversation_id": "test-conv-id",
"language": "test-language", "language": "test-language",
}, },
{
"text": "Test Text",
"agent_id": "homeassistant",
},
], ],
) )
async def test_ws_api(hass, hass_ws_client, payload): async def test_ws_api(hass, hass_ws_client, payload):
@ -496,10 +515,11 @@ async def test_ws_api(hass, hass_ws_client, payload):
} }
async def test_ws_prepare(hass, hass_ws_client): @pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
async def test_ws_prepare(hass, hass_ws_client, agent_id):
"""Test the Websocket prepare conversation API.""" """Test the Websocket prepare conversation API."""
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
agent = await conversation._get_agent(hass) agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
# No intents should be loaded yet # No intents should be loaded yet
@ -507,12 +527,13 @@ async def test_ws_prepare(hass, hass_ws_client):
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json( msg = {
{ "id": 5,
"id": 5, "type": "conversation/prepare",
"type": "conversation/prepare", }
} if agent_id is not None:
) msg["agent_id"] = agent_id
await client.send_json(msg)
msg = await client.receive_json() msg = await client.receive_json()
@ -618,7 +639,7 @@ async def test_prepare_reload(hass):
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = await conversation._get_agent(hass) agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
await agent.async_prepare(language) await agent.async_prepare(language)
@ -638,7 +659,7 @@ async def test_prepare_fail(hass):
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = await conversation._get_agent(hass) agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
await agent.async_prepare("not-a-language") await agent.async_prepare("not-a-language")
@ -676,7 +697,7 @@ async def test_reload_on_new_component(hass):
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = await conversation._get_agent(hass) agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
await agent.async_prepare() await agent.async_prepare()
@ -700,7 +721,7 @@ async def test_non_default_response(hass, init_components):
hass.states.async_set("cover.front_door", "closed") hass.states.async_set("cover.front_door", "closed")
async_mock_service(hass, "cover", SERVICE_OPEN_COVER) async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = await conversation._get_agent(hass) agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
result = await agent.async_process( result = await agent.async_process(
@ -826,3 +847,30 @@ async def test_light_area_same_name(hass, init_components):
assert call.domain == HASS_DOMAIN assert call.domain == HASS_DOMAIN
assert call.service == "turn_on" assert call.service == "turn_on"
assert call.data == {"entity_id": kitchen_light.entity_id} assert call.data == {"entity_id": kitchen_light.entity_id}
async def test_agent_id_validator_invalid_agent(hass):
"""Test validating agent id."""
with pytest.raises(vol.Invalid):
conversation.agent_id_validator("invalid_agent")
conversation.agent_id_validator(conversation.AgentManager.HOME_ASSISTANT_AGENT)
async def test_get_agent_list(hass, init_components, mock_agent, hass_ws_client):
"""Test getting agent info."""
client = await hass_ws_client(hass)
await client.send_json({"id": 5, "type": "conversation/agent/list"})
msg = await client.receive_json()
assert msg["id"] == 5
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == {
"agents": [
{"id": "homeassistant", "name": "Home Assistant"},
{"id": "mock-entry", "name": "Mock Title"},
],
"default_agent": "mock-entry",
}