mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Add support for multiple conversation agents (#87337)
* Add support for multiple conversation agents * Lock initializing default agent * Allow unsetting agent when never set
This commit is contained in:
parent
3f992ed31d
commit
fc38b4327f
@ -1,6 +1,7 @@
|
|||||||
"""Support for functionality to have conversations with Home Assistant."""
|
"""Support for functionality to have conversations with Home Assistant."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -12,7 +13,7 @@ from homeassistant.components import http, websocket_api
|
|||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv, intent
|
from homeassistant.helpers import config_validation as cv, intent, singleton
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
@ -23,20 +24,31 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
ATTR_TEXT = "text"
|
ATTR_TEXT = "text"
|
||||||
ATTR_LANGUAGE = "language"
|
ATTR_LANGUAGE = "language"
|
||||||
|
ATTR_AGENT_ID = "agent_id"
|
||||||
|
|
||||||
DOMAIN = "conversation"
|
DOMAIN = "conversation"
|
||||||
|
|
||||||
REGEX_TYPE = type(re.compile(""))
|
REGEX_TYPE = type(re.compile(""))
|
||||||
DATA_AGENT = "conversation_agent"
|
|
||||||
DATA_CONFIG = "conversation_config"
|
DATA_CONFIG = "conversation_config"
|
||||||
|
|
||||||
SERVICE_PROCESS = "process"
|
SERVICE_PROCESS = "process"
|
||||||
SERVICE_RELOAD = "reload"
|
SERVICE_RELOAD = "reload"
|
||||||
|
|
||||||
|
|
||||||
|
def agent_id_validator(value: Any) -> str:
|
||||||
|
"""Validate agent ID."""
|
||||||
|
hass = core.async_get_hass()
|
||||||
|
manager = _get_agent_manager(hass)
|
||||||
|
if not manager.async_is_valid_agent_id(cv.string(value)):
|
||||||
|
raise vol.Invalid("invalid agent ID")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
SERVICE_PROCESS_SCHEMA = vol.Schema(
|
SERVICE_PROCESS_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required(ATTR_TEXT): cv.string,
|
vol.Required(ATTR_TEXT): cv.string,
|
||||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||||
|
vol.Optional(ATTR_AGENT_ID): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -44,6 +56,7 @@ SERVICE_PROCESS_SCHEMA = vol.Schema(
|
|||||||
SERVICE_RELOAD_SCHEMA = vol.Schema(
|
SERVICE_RELOAD_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||||
|
vol.Optional(ATTR_AGENT_ID): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,6 +74,13 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@singleton.singleton("conversation_agent")
|
||||||
|
@core.callback
|
||||||
|
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
|
||||||
|
"""Get the active agent."""
|
||||||
|
return AgentManager(hass)
|
||||||
|
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_set_agent(
|
def async_set_agent(
|
||||||
@ -69,7 +89,7 @@ def async_set_agent(
|
|||||||
agent: AbstractConversationAgent,
|
agent: AbstractConversationAgent,
|
||||||
):
|
):
|
||||||
"""Set the agent to handle the conversations."""
|
"""Set the agent to handle the conversations."""
|
||||||
hass.data[DATA_AGENT] = agent
|
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
|
||||||
|
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
@ -79,11 +99,13 @@ def async_unset_agent(
|
|||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
):
|
):
|
||||||
"""Set the agent to handle the conversations."""
|
"""Set the agent to handle the conversations."""
|
||||||
hass.data[DATA_AGENT] = None
|
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
|
agent_manager = _get_agent_manager(hass)
|
||||||
|
|
||||||
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
||||||
hass.data[DATA_CONFIG] = config_intents
|
hass.data[DATA_CONFIG] = config_intents
|
||||||
|
|
||||||
@ -91,22 +113,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
"""Parse text into commands."""
|
"""Parse text into commands."""
|
||||||
text = service.data[ATTR_TEXT]
|
text = service.data[ATTR_TEXT]
|
||||||
_LOGGER.debug("Processing: <%s>", text)
|
_LOGGER.debug("Processing: <%s>", text)
|
||||||
agent = await _get_agent(hass)
|
|
||||||
try:
|
try:
|
||||||
await agent.async_process(
|
await async_converse(
|
||||||
ConversationInput(
|
hass=hass,
|
||||||
text=text,
|
text=text,
|
||||||
context=service.context,
|
conversation_id=None,
|
||||||
conversation_id=None,
|
context=service.context,
|
||||||
language=service.data.get(ATTR_LANGUAGE, hass.config.language),
|
language=service.data.get(ATTR_LANGUAGE),
|
||||||
)
|
agent_id=service.data.get(ATTR_AGENT_ID),
|
||||||
)
|
)
|
||||||
except intent.IntentHandleError as err:
|
except intent.IntentHandleError as err:
|
||||||
_LOGGER.error("Error processing %s: %s", text, err)
|
_LOGGER.error("Error processing %s: %s", text, err)
|
||||||
|
|
||||||
async def handle_reload(service: core.ServiceCall) -> None:
|
async def handle_reload(service: core.ServiceCall) -> None:
|
||||||
"""Reload intents."""
|
"""Reload intents."""
|
||||||
agent = await _get_agent(hass)
|
agent = await agent_manager.async_get_agent()
|
||||||
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
||||||
|
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
@ -119,6 +140,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
websocket_api.async_register_command(hass, websocket_process)
|
websocket_api.async_register_command(hass, websocket_process)
|
||||||
websocket_api.async_register_command(hass, websocket_prepare)
|
websocket_api.async_register_command(hass, websocket_prepare)
|
||||||
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
||||||
|
websocket_api.async_register_command(hass, websocket_list_agents)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -129,6 +151,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
vol.Required("text"): str,
|
vol.Required("text"): str,
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
vol.Optional("language"): str,
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("agent_id"): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
@ -139,11 +162,12 @@ async def websocket_process(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Process text."""
|
"""Process text."""
|
||||||
result = await async_converse(
|
result = await async_converse(
|
||||||
hass,
|
hass=hass,
|
||||||
msg["text"],
|
text=msg["text"],
|
||||||
msg.get("conversation_id"),
|
conversation_id=msg.get("conversation_id"),
|
||||||
connection.context(msg),
|
context=connection.context(msg),
|
||||||
msg.get("language"),
|
language=msg.get("language"),
|
||||||
|
agent_id=msg.get("agent_id"),
|
||||||
)
|
)
|
||||||
connection.send_result(msg["id"], result.as_dict())
|
connection.send_result(msg["id"], result.as_dict())
|
||||||
|
|
||||||
@ -152,6 +176,7 @@ async def websocket_process(
|
|||||||
{
|
{
|
||||||
"type": "conversation/prepare",
|
"type": "conversation/prepare",
|
||||||
vol.Optional("language"): str,
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("agent_id"): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
@ -161,7 +186,8 @@ async def websocket_prepare(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Reload intents."""
|
"""Reload intents."""
|
||||||
agent = await _get_agent(hass)
|
manager = _get_agent_manager(hass)
|
||||||
|
agent = await manager.async_get_agent(msg.get("agent_id"))
|
||||||
await agent.async_prepare(msg.get("language"))
|
await agent.async_prepare(msg.get("language"))
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
@ -169,6 +195,7 @@ async def websocket_prepare(
|
|||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "conversation/agent/info",
|
vol.Required("type"): "conversation/agent/info",
|
||||||
|
vol.Optional("agent_id"): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
@ -178,7 +205,7 @@ async def websocket_get_agent_info(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Info about the agent in use."""
|
"""Info about the agent in use."""
|
||||||
agent = await _get_agent(hass)
|
agent = await _get_agent_manager(hass).async_get_agent(msg.get("agent_id"))
|
||||||
|
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
@ -188,6 +215,29 @@ async def websocket_get_agent_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "conversation/agent/list",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@core.callback
|
||||||
|
def websocket_list_agents(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""List available agents."""
|
||||||
|
manager = _get_agent_manager(hass)
|
||||||
|
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"],
|
||||||
|
{
|
||||||
|
"default_agent": manager.default_agent,
|
||||||
|
"agents": manager.async_get_agent_info(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessView(http.HomeAssistantView):
|
class ConversationProcessView(http.HomeAssistantView):
|
||||||
"""View to process text."""
|
"""View to process text."""
|
||||||
|
|
||||||
@ -200,43 +250,41 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||||||
vol.Required("text"): str,
|
vol.Required("text"): str,
|
||||||
vol.Optional("conversation_id"): str,
|
vol.Optional("conversation_id"): str,
|
||||||
vol.Optional("language"): str,
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("agent_id"): agent_id_validator,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async def post(self, request, data):
|
async def post(self, request, data):
|
||||||
"""Send a request for processing."""
|
"""Send a request for processing."""
|
||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
|
|
||||||
result = await async_converse(
|
result = await async_converse(
|
||||||
hass,
|
hass,
|
||||||
text=data["text"],
|
text=data["text"],
|
||||||
conversation_id=data.get("conversation_id"),
|
conversation_id=data.get("conversation_id"),
|
||||||
context=self.context(request),
|
context=self.context(request),
|
||||||
language=data.get("language"),
|
language=data.get("language"),
|
||||||
|
agent_id=data.get("agent_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.json(result.as_dict())
|
return self.json(result.as_dict())
|
||||||
|
|
||||||
|
|
||||||
async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
|
||||||
"""Get the active conversation agent."""
|
|
||||||
if (agent := hass.data.get(DATA_AGENT)) is None:
|
|
||||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
|
||||||
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
async def async_converse(
|
async def async_converse(
|
||||||
hass: core.HomeAssistant,
|
hass: core.HomeAssistant,
|
||||||
text: str,
|
text: str,
|
||||||
conversation_id: str | None,
|
conversation_id: str | None,
|
||||||
context: core.Context,
|
context: core.Context,
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
) -> ConversationResult:
|
) -> ConversationResult:
|
||||||
"""Process text and get intent."""
|
"""Process text and get intent."""
|
||||||
agent = await _get_agent(hass)
|
agent = await _get_agent_manager(hass).async_get_agent(agent_id)
|
||||||
|
|
||||||
if language is None:
|
if language is None:
|
||||||
language = hass.config.language
|
language = hass.config.language
|
||||||
|
|
||||||
|
_LOGGER.debug("Processing in %s: %s", language, text)
|
||||||
result = await agent.async_process(
|
result = await agent.async_process(
|
||||||
ConversationInput(
|
ConversationInput(
|
||||||
text=text,
|
text=text,
|
||||||
@ -246,3 +294,88 @@ async def async_converse(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AgentManager:
|
||||||
|
"""Class to manage conversation agents."""
|
||||||
|
|
||||||
|
HOME_ASSISTANT_AGENT = "homeassistant"
|
||||||
|
|
||||||
|
default_agent: str = HOME_ASSISTANT_AGENT
|
||||||
|
_builtin_agent: AbstractConversationAgent | None = None
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the conversation agents."""
|
||||||
|
self.hass = hass
|
||||||
|
self._agents: dict[str, AbstractConversationAgent] = {}
|
||||||
|
self._default_agent_init_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def async_get_agent(
|
||||||
|
self, agent_id: str | None = None
|
||||||
|
) -> AbstractConversationAgent:
|
||||||
|
"""Get the agent."""
|
||||||
|
if agent_id is None:
|
||||||
|
agent_id = self.default_agent
|
||||||
|
|
||||||
|
if agent_id == AgentManager.HOME_ASSISTANT_AGENT:
|
||||||
|
if self._builtin_agent is not None:
|
||||||
|
return self._builtin_agent
|
||||||
|
|
||||||
|
async with self._default_agent_init_lock:
|
||||||
|
if self._builtin_agent is not None:
|
||||||
|
return self._builtin_agent
|
||||||
|
|
||||||
|
self._builtin_agent = DefaultAgent(self.hass)
|
||||||
|
await self._builtin_agent.async_initialize(
|
||||||
|
self.hass.data.get(DATA_CONFIG)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._builtin_agent
|
||||||
|
|
||||||
|
return self._agents[agent_id]
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def async_get_agent_info(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all agents."""
|
||||||
|
agents = [
|
||||||
|
{
|
||||||
|
"id": AgentManager.HOME_ASSISTANT_AGENT,
|
||||||
|
"name": "Home Assistant",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
for agent_id, agent in self._agents.items():
|
||||||
|
config_entry = self.hass.config_entries.async_get_entry(agent_id)
|
||||||
|
|
||||||
|
# This is a bug, agent should have been unset when config entry was unloaded
|
||||||
|
if config_entry is None:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Agent was still loaded while config entry is gone: %s", agent
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
agents.append(
|
||||||
|
{
|
||||||
|
"id": agent_id,
|
||||||
|
"name": config_entry.title,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return agents
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def async_is_valid_agent_id(self, agent_id: str) -> bool:
|
||||||
|
"""Check if the agent id is valid."""
|
||||||
|
return agent_id in self._agents or agent_id == AgentManager.HOME_ASSISTANT_AGENT
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
||||||
|
"""Set the agent."""
|
||||||
|
self._agents[agent_id] = agent
|
||||||
|
if self.default_agent == AgentManager.HOME_ASSISTANT_AGENT:
|
||||||
|
self.default_agent = agent_id
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def async_unset_agent(self, agent_id: str) -> None:
|
||||||
|
"""Unset the agent."""
|
||||||
|
if self.default_agent == agent_id:
|
||||||
|
self.default_agent = AgentManager.HOME_ASSISTANT_AGENT
|
||||||
|
self._agents.pop(agent_id, None)
|
||||||
|
@ -9,3 +9,15 @@ process:
|
|||||||
example: Turn all lights on
|
example: Turn all lights on
|
||||||
selector:
|
selector:
|
||||||
text:
|
text:
|
||||||
|
language:
|
||||||
|
name: Language
|
||||||
|
description: Language of text. Defaults to server language
|
||||||
|
example: NL
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
agent_id:
|
||||||
|
name: Agent
|
||||||
|
description: Assist engine to process your request
|
||||||
|
example: homeassistant
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
@ -8,8 +8,9 @@ from homeassistant.helpers import intent
|
|||||||
class MockAgent(conversation.AbstractConversationAgent):
|
class MockAgent(conversation.AbstractConversationAgent):
|
||||||
"""Test Agent."""
|
"""Test Agent."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, agent_id: str) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
|
self.agent_id = agent_id
|
||||||
self.calls = []
|
self.calls = []
|
||||||
self.response = "Test response"
|
self.response = "Test response"
|
||||||
|
|
||||||
|
@ -6,10 +6,14 @@ from homeassistant.components import conversation
|
|||||||
|
|
||||||
from . import MockAgent
|
from . import MockAgent
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent(hass):
|
def mock_agent(hass):
|
||||||
"""Mock agent."""
|
"""Mock agent."""
|
||||||
agent = MockAgent()
|
entry = MockConfigEntry(entry_id="mock-entry")
|
||||||
conversation.async_set_agent(hass, None, agent)
|
entry.add_to_hass(hass)
|
||||||
|
agent = MockAgent(entry.entry_id)
|
||||||
|
conversation.async_set_agent(hass, entry, agent)
|
||||||
return agent
|
return agent
|
||||||
|
@ -3,6 +3,7 @@ from http import HTTPStatus
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||||
@ -18,6 +19,8 @@ from homeassistant.setup import async_setup_component
|
|||||||
|
|
||||||
from tests.common import MockConfigEntry, async_mock_service
|
from tests.common import MockConfigEntry, async_mock_service
|
||||||
|
|
||||||
|
AGENT_ID_OPTIONS = [None, conversation.AgentManager.HOME_ASSISTANT_AGENT]
|
||||||
|
|
||||||
|
|
||||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
"""Handle OrderBeer intent."""
|
"""Handle OrderBeer intent."""
|
||||||
@ -40,8 +43,9 @@ async def init_components(hass):
|
|||||||
assert await async_setup_component(hass, "intent", {})
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||||
async def test_http_processing_intent(
|
async def test_http_processing_intent(
|
||||||
hass, init_components, hass_client, hass_admin_user
|
hass, init_components, hass_client, hass_admin_user, agent_id
|
||||||
):
|
):
|
||||||
"""Test processing intent via HTTP API."""
|
"""Test processing intent via HTTP API."""
|
||||||
# Add an alias
|
# Add an alias
|
||||||
@ -50,9 +54,52 @@ async def test_http_processing_intent(
|
|||||||
entities.async_update_entity("light.kitchen", aliases={"my cool light"})
|
entities.async_update_entity("light.kitchen", aliases={"my cool light"})
|
||||||
hass.states.async_set("light.kitchen", "off")
|
hass.states.async_set("light.kitchen", "off")
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
data = {"text": "turn on my cool light"}
|
||||||
|
if agent_id:
|
||||||
|
data["agent_id"] = agent_id
|
||||||
|
resp = await client.post("/api/conversation/process", json=data)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
assert data == {
|
||||||
|
"response": {
|
||||||
|
"response_type": "action_done",
|
||||||
|
"card": {},
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Turned on my cool light",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"data": {
|
||||||
|
"targets": [],
|
||||||
|
"success": [
|
||||||
|
{"id": "light.kitchen", "name": "kitchen", "type": "entity"}
|
||||||
|
],
|
||||||
|
"failed": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_http_processing_intent_target_ha_agent(
|
||||||
|
hass, init_components, hass_client, hass_admin_user, mock_agent
|
||||||
|
):
|
||||||
|
"""Test processing intent can be processed via HTTP API with picking agent."""
|
||||||
|
# Add an alias
|
||||||
|
entities = entity_registry.async_get(hass)
|
||||||
|
entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
|
||||||
|
entities.async_update_entity("light.kitchen", aliases={"my cool light"})
|
||||||
|
hass.states.async_set("light.kitchen", "off")
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/api/conversation/process", json={"text": "turn on my cool light"}
|
"/api/conversation/process",
|
||||||
|
json={"text": "turn on my cool light", "agent_id": "homeassistant"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
@ -218,15 +265,17 @@ async def test_http_processing_intent_entity_added(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||||
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
|
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
|
||||||
async def test_turn_on_intent(hass, init_components, sentence):
|
async def test_turn_on_intent(hass, init_components, sentence, agent_id):
|
||||||
"""Test calling the turn on intent."""
|
"""Test calling the turn on intent."""
|
||||||
hass.states.async_set("light.kitchen", "off")
|
hass.states.async_set("light.kitchen", "off")
|
||||||
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||||
|
|
||||||
await hass.services.async_call(
|
data = {conversation.ATTR_TEXT: sentence}
|
||||||
"conversation", "process", {conversation.ATTR_TEXT: sentence}
|
if agent_id is not None:
|
||||||
)
|
data[conversation.ATTR_AGENT_ID] = agent_id
|
||||||
|
await hass.services.async_call("conversation", "process", data)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
@ -254,46 +303,6 @@ async def test_turn_off_intent(hass, init_components, sentence):
|
|||||||
assert call.data == {"entity_id": "light.kitchen"}
|
assert call.data == {"entity_id": "light.kitchen"}
|
||||||
|
|
||||||
|
|
||||||
async def test_http_api(hass, init_components, hass_client):
|
|
||||||
"""Test the HTTP conversation API."""
|
|
||||||
client = await hass_client()
|
|
||||||
hass.states.async_set("light.kitchen", "off")
|
|
||||||
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
|
||||||
|
|
||||||
resp = await client.post(
|
|
||||||
"/api/conversation/process", json={"text": "Turn the kitchen on"}
|
|
||||||
)
|
|
||||||
assert resp.status == HTTPStatus.OK
|
|
||||||
data = await resp.json()
|
|
||||||
|
|
||||||
assert data == {
|
|
||||||
"response": {
|
|
||||||
"card": {},
|
|
||||||
"speech": {"plain": {"extra_data": None, "speech": "Turned on kitchen"}},
|
|
||||||
"language": hass.config.language,
|
|
||||||
"response_type": "action_done",
|
|
||||||
"data": {
|
|
||||||
"targets": [],
|
|
||||||
"success": [
|
|
||||||
{
|
|
||||||
"type": "entity",
|
|
||||||
"name": "kitchen",
|
|
||||||
"id": "light.kitchen",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"failed": [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"conversation_id": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert len(calls) == 1
|
|
||||||
call = calls[0]
|
|
||||||
assert call.domain == HASS_DOMAIN
|
|
||||||
assert call.service == "turn_on"
|
|
||||||
assert call.data == {"entity_id": "light.kitchen"}
|
|
||||||
|
|
||||||
|
|
||||||
async def test_http_api_no_match(hass, init_components, hass_client):
|
async def test_http_api_no_match(hass, init_components, hass_client):
|
||||||
"""Test the HTTP conversation API with an intent match failure."""
|
"""Test the HTTP conversation API with an intent match failure."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -406,20 +415,22 @@ async def test_http_api_wrong_data(hass, init_components, hass_client):
|
|||||||
assert resp.status == HTTPStatus.BAD_REQUEST
|
assert resp.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
|
@pytest.mark.parametrize("agent_id", (None, "mock-entry"))
|
||||||
|
async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent, agent_id):
|
||||||
"""Test a custom conversation agent."""
|
"""Test a custom conversation agent."""
|
||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
resp = await client.post(
|
data = {
|
||||||
"/api/conversation/process",
|
"text": "Test Text",
|
||||||
json={
|
"conversation_id": "test-conv-id",
|
||||||
"text": "Test Text",
|
"language": "test-language",
|
||||||
"conversation_id": "test-conv-id",
|
}
|
||||||
"language": "test-language",
|
if agent_id is not None:
|
||||||
},
|
data["agent_id"] = agent_id
|
||||||
)
|
|
||||||
|
resp = await client.post("/api/conversation/process", json=data)
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
assert await resp.json() == {
|
assert await resp.json() == {
|
||||||
"response": {
|
"response": {
|
||||||
@ -443,6 +454,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
|
|||||||
assert mock_agent.calls[0].conversation_id == "test-conv-id"
|
assert mock_agent.calls[0].conversation_id == "test-conv-id"
|
||||||
assert mock_agent.calls[0].language == "test-language"
|
assert mock_agent.calls[0].language == "test-language"
|
||||||
|
|
||||||
|
conversation.async_unset_agent(
|
||||||
|
hass, hass.config_entries.async_get_entry(mock_agent.agent_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"payload",
|
"payload",
|
||||||
@ -467,6 +482,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
|
|||||||
"conversation_id": "test-conv-id",
|
"conversation_id": "test-conv-id",
|
||||||
"language": "test-language",
|
"language": "test-language",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"text": "Test Text",
|
||||||
|
"agent_id": "homeassistant",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_ws_api(hass, hass_ws_client, payload):
|
async def test_ws_api(hass, hass_ws_client, payload):
|
||||||
@ -496,10 +515,11 @@ async def test_ws_api(hass, hass_ws_client, payload):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_prepare(hass, hass_ws_client):
|
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||||
|
async def test_ws_prepare(hass, hass_ws_client, agent_id):
|
||||||
"""Test the Websocket prepare conversation API."""
|
"""Test the Websocket prepare conversation API."""
|
||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
agent = await conversation._get_agent(hass)
|
agent = await conversation._get_agent_manager(hass).async_get_agent()
|
||||||
assert isinstance(agent, conversation.DefaultAgent)
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
|
|
||||||
# No intents should be loaded yet
|
# No intents should be loaded yet
|
||||||
@ -507,12 +527,13 @@ async def test_ws_prepare(hass, hass_ws_client):
|
|||||||
|
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json(
|
msg = {
|
||||||
{
|
"id": 5,
|
||||||
"id": 5,
|
"type": "conversation/prepare",
|
||||||
"type": "conversation/prepare",
|
}
|
||||||
}
|
if agent_id is not None:
|
||||||
)
|
msg["agent_id"] = agent_id
|
||||||
|
await client.send_json(msg)
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
|
||||||
@ -618,7 +639,7 @@ async def test_prepare_reload(hass):
|
|||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
# Load intents
|
# Load intents
|
||||||
agent = await conversation._get_agent(hass)
|
agent = await conversation._get_agent_manager(hass).async_get_agent()
|
||||||
assert isinstance(agent, conversation.DefaultAgent)
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
await agent.async_prepare(language)
|
await agent.async_prepare(language)
|
||||||
|
|
||||||
@ -638,7 +659,7 @@ async def test_prepare_fail(hass):
|
|||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
# Load intents
|
# Load intents
|
||||||
agent = await conversation._get_agent(hass)
|
agent = await conversation._get_agent_manager(hass).async_get_agent()
|
||||||
assert isinstance(agent, conversation.DefaultAgent)
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
await agent.async_prepare("not-a-language")
|
await agent.async_prepare("not-a-language")
|
||||||
|
|
||||||
@ -676,7 +697,7 @@ async def test_reload_on_new_component(hass):
|
|||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
# Load intents
|
# Load intents
|
||||||
agent = await conversation._get_agent(hass)
|
agent = await conversation._get_agent_manager(hass).async_get_agent()
|
||||||
assert isinstance(agent, conversation.DefaultAgent)
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
await agent.async_prepare()
|
await agent.async_prepare()
|
||||||
|
|
||||||
@ -700,7 +721,7 @@ async def test_non_default_response(hass, init_components):
|
|||||||
hass.states.async_set("cover.front_door", "closed")
|
hass.states.async_set("cover.front_door", "closed")
|
||||||
async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
||||||
|
|
||||||
agent = await conversation._get_agent(hass)
|
agent = await conversation._get_agent_manager(hass).async_get_agent()
|
||||||
assert isinstance(agent, conversation.DefaultAgent)
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
|
|
||||||
result = await agent.async_process(
|
result = await agent.async_process(
|
||||||
@ -826,3 +847,30 @@ async def test_light_area_same_name(hass, init_components):
|
|||||||
assert call.domain == HASS_DOMAIN
|
assert call.domain == HASS_DOMAIN
|
||||||
assert call.service == "turn_on"
|
assert call.service == "turn_on"
|
||||||
assert call.data == {"entity_id": kitchen_light.entity_id}
|
assert call.data == {"entity_id": kitchen_light.entity_id}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_agent_id_validator_invalid_agent(hass):
|
||||||
|
"""Test validating agent id."""
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
conversation.agent_id_validator("invalid_agent")
|
||||||
|
|
||||||
|
conversation.agent_id_validator(conversation.AgentManager.HOME_ASSISTANT_AGENT)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_agent_list(hass, init_components, mock_agent, hass_ws_client):
|
||||||
|
"""Test getting agent info."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json({"id": 5, "type": "conversation/agent/list"})
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["id"] == 5
|
||||||
|
assert msg["type"] == "result"
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"agents": [
|
||||||
|
{"id": "homeassistant", "name": "Home Assistant"},
|
||||||
|
{"id": "mock-entry", "name": "Mock Title"},
|
||||||
|
],
|
||||||
|
"default_agent": "mock-entry",
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user