diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index dd8fb967824..a0717ddaa58 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -2,43 +2,36 @@ from __future__ import annotations -import asyncio from collections.abc import Iterable -from dataclasses import dataclass import logging import re -from typing import Any, Literal +from typing import Literal -from aiohttp import web -from hassil.recognize import ( - MISSING_ENTITY, - RecognizeResult, - UnmatchedRangeEntity, - UnmatchedTextEntity, -) import voluptuous as vol -from homeassistant import core -from homeassistant.components import http, websocket_api -from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL -from homeassistant.core import HomeAssistant +from homeassistant.core import ( + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, + callback, +) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import config_validation as cv, intent, singleton +from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass -from homeassistant.util import language as language_util -from .agent import AbstractConversationAgent, ConversationInput, ConversationResult -from .const import HOME_ASSISTANT_AGENT -from .default_agent import ( - METADATA_CUSTOM_FILE, - METADATA_CUSTOM_SENTENCE, - DefaultAgent, - SentenceTriggerResult, - async_setup as async_setup_default_agent, +from .agent_manager import ( + AgentInfo, + agent_id_validator, + async_converse, + get_agent_manager, ) +from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT +from .http import async_setup as async_setup_conversation_http +from .models import AbstractConversationAgent, ConversationInput, ConversationResult __all__ = [ "DOMAIN", @@ -48,6 +41,8 @@ __all__ = [ "async_set_agent", "async_unset_agent", "async_setup", + "ConversationInput", + "ConversationResult", ] _LOGGER = logging.getLogger(__name__) @@ -60,21 +55,11 @@ ATTR_CONVERSATION_ID = "conversation_id" DOMAIN = "conversation" REGEX_TYPE = type(re.compile("")) -DATA_CONFIG = "conversation_config" SERVICE_PROCESS = "process" SERVICE_RELOAD = "reload" -def agent_id_validator(value: Any) -> str: - """Validate agent ID.""" - hass = core.async_get_hass() - manager = _get_agent_manager(hass) - if not manager.async_is_valid_agent_id(cv.string(value)): - raise vol.Invalid("invalid agent ID") - return value - - SERVICE_PROCESS_SCHEMA = vol.Schema( { vol.Required(ATTR_TEXT): cv.string, @@ -106,34 +91,25 @@ CONFIG_SCHEMA = vol.Schema( ) -@singleton.singleton("conversation_agent") -@core.callback -def _get_agent_manager(hass: HomeAssistant) -> AgentManager: - """Get the active agent.""" - manager = AgentManager(hass) - manager.async_setup() - return manager - - -@core.callback +@callback @bind_hass def async_set_agent( - hass: core.HomeAssistant, + hass: HomeAssistant, config_entry: ConfigEntry, agent: AbstractConversationAgent, ) -> None: """Set the agent to handle the conversations.""" - _get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent) + get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent) -@core.callback +@callback @bind_hass def async_unset_agent( - hass: core.HomeAssistant, + hass: HomeAssistant, config_entry: ConfigEntry, ) -> None: """Set the agent to handle the conversations.""" - _get_agent_manager(hass).async_unset_agent(config_entry.entry_id) + get_agent_manager(hass).async_unset_agent(config_entry.entry_id) async def async_get_conversation_languages( @@ -145,7 +121,7 @@ async def async_get_conversation_languages( If no agent is specified, return a set with the union of languages supported by all conversation agents. """ - agent_manager = _get_agent_manager(hass) + agent_manager = get_agent_manager(hass) languages: set[str] = set() agent_ids: Iterable[str] @@ -164,14 +140,32 @@ async def async_get_conversation_languages( return languages +@callback +def async_get_agent_info( + hass: HomeAssistant, + agent_id: str | None = None, +) -> AgentInfo | None: + """Get information on the agent or None if not found.""" + manager = get_agent_manager(hass) + + if agent_id is None: + agent_id = manager.default_agent + + for agent_info in manager.async_get_agent_info(): + if agent_info.id == agent_id: + return agent_info + + return None + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" - agent_manager = _get_agent_manager(hass) + agent_manager = get_agent_manager(hass) if config_intents := config.get(DOMAIN, {}).get("intents"): hass.data[DATA_CONFIG] = config_intents - async def handle_process(service: core.ServiceCall) -> core.ServiceResponse: + async def handle_process(service: ServiceCall) -> ServiceResponse: """Parse text into commands.""" text = service.data[ATTR_TEXT] _LOGGER.debug("Processing: <%s>", text) @@ -192,7 +186,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return None - async def handle_reload(service: core.ServiceCall) -> None: + async def handle_reload(service: ServiceCall) -> None: """Reload intents.""" agent = await agent_manager.async_get_agent() await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) @@ -202,440 +196,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: SERVICE_PROCESS, handle_process, schema=SERVICE_PROCESS_SCHEMA, - supports_response=core.SupportsResponse.OPTIONAL, + supports_response=SupportsResponse.OPTIONAL, ) hass.services.async_register( DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA ) - hass.http.register_view(ConversationProcessView()) - websocket_api.async_register_command(hass, websocket_process) - websocket_api.async_register_command(hass, websocket_prepare) - websocket_api.async_register_command(hass, websocket_list_agents) - websocket_api.async_register_command(hass, websocket_hass_agent_debug) + async_setup_conversation_http(hass) return True - - -@websocket_api.websocket_command( - { - vol.Required("type"): "conversation/process", - vol.Required("text"): str, - vol.Optional("conversation_id"): vol.Any(str, None), - vol.Optional("language"): str, - vol.Optional("agent_id"): agent_id_validator, - } -) -@websocket_api.async_response -async def websocket_process( - hass: HomeAssistant, - connection: websocket_api.ActiveConnection, - msg: dict[str, Any], -) -> None: - """Process text.""" - result = await async_converse( - hass=hass, - text=msg["text"], - conversation_id=msg.get("conversation_id"), - context=connection.context(msg), - language=msg.get("language"), - agent_id=msg.get("agent_id"), - ) - connection.send_result(msg["id"], result.as_dict()) - - -@websocket_api.websocket_command( - { - "type": "conversation/prepare", - vol.Optional("language"): str, - vol.Optional("agent_id"): agent_id_validator, - } -) -@websocket_api.async_response -async def websocket_prepare( - hass: HomeAssistant, - connection: websocket_api.ActiveConnection, - msg: dict[str, Any], -) -> None: - """Reload intents.""" - manager = _get_agent_manager(hass) - agent = await manager.async_get_agent(msg.get("agent_id")) - await agent.async_prepare(msg.get("language")) - connection.send_result(msg["id"]) - - -@websocket_api.websocket_command( - { - vol.Required("type"): "conversation/agent/list", - vol.Optional("language"): str, - vol.Optional("country"): str, - } -) -@websocket_api.async_response -async def websocket_list_agents( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict -) -> None: - """List conversation agents and, optionally, if they support a given language.""" - manager = _get_agent_manager(hass) - - country = msg.get("country") - language = msg.get("language") - agents = [] - - for agent_info in manager.async_get_agent_info(): - agent = await manager.async_get_agent(agent_info.id) - - supported_languages = agent.supported_languages - if language and supported_languages != MATCH_ALL: - supported_languages = language_util.matches( - language, supported_languages, country - ) - - agent_dict: dict[str, Any] = { - "id": agent_info.id, - "name": agent_info.name, - "supported_languages": supported_languages, - } - agents.append(agent_dict) - - connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents})) - - -@websocket_api.websocket_command( - { - vol.Required("type"): "conversation/agent/homeassistant/debug", - vol.Required("sentences"): [str], - vol.Optional("language"): str, - vol.Optional("device_id"): vol.Any(str, None), - } -) -@websocket_api.async_response -async def websocket_hass_agent_debug( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict -) -> None: - """Return intents that would be matched by the default agent for a list of sentences.""" - agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) - assert isinstance(agent, DefaultAgent) - results = [ - await agent.async_recognize( - ConversationInput( - text=sentence, - context=connection.context(msg), - conversation_id=None, - device_id=msg.get("device_id"), - language=msg.get("language", hass.config.language), - ) - ) - for sentence in msg["sentences"] - ] - - # Return results for each sentence in the same order as the input. - result_dicts: list[dict[str, Any] | None] = [] - for result in results: - result_dict: dict[str, Any] | None = None - if isinstance(result, SentenceTriggerResult): - result_dict = { - # Matched a user-defined sentence trigger. - # We can't provide the response here without executing the - # trigger. - "match": True, - "source": "trigger", - "sentence_template": result.sentence_template or "", - } - elif isinstance(result, RecognizeResult): - successful_match = not result.unmatched_entities - result_dict = { - # Name of the matching intent (or the closest) - "intent": { - "name": result.intent.name, - }, - # Slot values that would be received by the intent - "slots": { # direct access to values - entity_key: entity.text or entity.value - for entity_key, entity in result.entities.items() - }, - # Extra slot details, such as the originally matched text - "details": { - entity_key: { - "name": entity.name, - "value": entity.value, - "text": entity.text, - } - for entity_key, entity in result.entities.items() - }, - # Entities/areas/etc. that would be targeted - "targets": {}, - # True if match was successful - "match": successful_match, - # Text of the sentence template that matched (or was closest) - "sentence_template": "", - # When match is incomplete, this will contain the best slot guesses - "unmatched_slots": _get_unmatched_slots(result), - } - - if successful_match: - result_dict["targets"] = { - state.entity_id: {"matched": is_matched} - for state, is_matched in _get_debug_targets(hass, result) - } - - if result.intent_sentence is not None: - result_dict["sentence_template"] = result.intent_sentence.text - - # Inspect metadata to determine if this matched a custom sentence - if result.intent_metadata and result.intent_metadata.get( - METADATA_CUSTOM_SENTENCE - ): - result_dict["source"] = "custom" - result_dict["file"] = result.intent_metadata.get(METADATA_CUSTOM_FILE) - else: - result_dict["source"] = "builtin" - - result_dicts.append(result_dict) - - connection.send_result(msg["id"], {"results": result_dicts}) - - -def _get_debug_targets( - hass: HomeAssistant, - result: RecognizeResult, -) -> Iterable[tuple[core.State, bool]]: - """Yield state/is_matched pairs for a hassil recognition.""" - entities = result.entities - - name: str | None = None - area_name: str | None = None - domains: set[str] | None = None - device_classes: set[str] | None = None - state_names: set[str] | None = None - - if "name" in entities: - name = str(entities["name"].value) - - if "area" in entities: - area_name = str(entities["area"].value) - - if "domain" in entities: - domains = set(cv.ensure_list(entities["domain"].value)) - - if "device_class" in entities: - device_classes = set(cv.ensure_list(entities["device_class"].value)) - - if "state" in entities: - # HassGetState only - state_names = set(cv.ensure_list(entities["state"].value)) - - if ( - (name is None) - and (area_name is None) - and (not domains) - and (not device_classes) - and (not state_names) - ): - # Avoid "matching" all entities when there is no filter - return - - states = intent.async_match_states( - hass, - name=name, - area_name=area_name, - domains=domains, - device_classes=device_classes, - ) - - for state in states: - # For queries, a target is "matched" based on its state - is_matched = (state_names is None) or (state.state in state_names) - yield state, is_matched - - -def _get_unmatched_slots( - result: RecognizeResult, -) -> dict[str, str | int]: - """Return a dict of unmatched text/range slot entities.""" - unmatched_slots: dict[str, str | int] = {} - for entity in result.unmatched_entities_list: - if isinstance(entity, UnmatchedTextEntity): - if entity.text == MISSING_ENTITY: - # Don't report since these are just missing context - # slots. - continue - - unmatched_slots[entity.name] = entity.text - elif isinstance(entity, UnmatchedRangeEntity): - unmatched_slots[entity.name] = entity.value - - return unmatched_slots - - -class ConversationProcessView(http.HomeAssistantView): - """View to process text.""" - - url = "/api/conversation/process" - name = "api:conversation:process" - - @RequestDataValidator( - vol.Schema( - { - vol.Required("text"): str, - vol.Optional("conversation_id"): str, - vol.Optional("language"): str, - vol.Optional("agent_id"): agent_id_validator, - } - ) - ) - async def post(self, request: web.Request, data: dict[str, str]) -> web.Response: - """Send a request for processing.""" - hass = request.app[http.KEY_HASS] - - result = await async_converse( - hass, - text=data["text"], - conversation_id=data.get("conversation_id"), - context=self.context(request), - language=data.get("language"), - agent_id=data.get("agent_id"), - ) - - return self.json(result.as_dict()) - - -@dataclass(frozen=True) -class AgentInfo: - """Container for conversation agent info.""" - - id: str - name: str - - -@core.callback -def async_get_agent_info( - hass: core.HomeAssistant, - agent_id: str | None = None, -) -> AgentInfo | None: - """Get information on the agent or None if not found.""" - manager = _get_agent_manager(hass) - - if agent_id is None: - agent_id = manager.default_agent - - for agent_info in manager.async_get_agent_info(): - if agent_info.id == agent_id: - return agent_info - - return None - - -async def async_converse( - hass: core.HomeAssistant, - text: str, - conversation_id: str | None, - context: core.Context, - language: str | None = None, - agent_id: str | None = None, - device_id: str | None = None, -) -> ConversationResult: - """Process text and get intent.""" - agent = await _get_agent_manager(hass).async_get_agent(agent_id) - - if language is None: - language = hass.config.language - - _LOGGER.debug("Processing in %s: %s", language, text) - result = await agent.async_process( - ConversationInput( - text=text, - context=context, - conversation_id=conversation_id, - device_id=device_id, - language=language, - ) - ) - return result - - -class AgentManager: - """Class to manage conversation agents.""" - - default_agent: str = HOME_ASSISTANT_AGENT - _builtin_agent: AbstractConversationAgent | None = None - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the conversation agents.""" - self.hass = hass - self._agents: dict[str, AbstractConversationAgent] = {} - self._builtin_agent_init_lock = asyncio.Lock() - - def async_setup(self) -> None: - """Set up the conversation agents.""" - async_setup_default_agent(self.hass) - - async def async_get_agent( - self, agent_id: str | None = None - ) -> AbstractConversationAgent: - """Get the agent.""" - if agent_id is None: - agent_id = self.default_agent - - if agent_id == HOME_ASSISTANT_AGENT: - if self._builtin_agent is not None: - return self._builtin_agent - - async with self._builtin_agent_init_lock: - if self._builtin_agent is not None: - return self._builtin_agent - - self._builtin_agent = DefaultAgent(self.hass) - await self._builtin_agent.async_initialize( - self.hass.data.get(DATA_CONFIG) - ) - - return self._builtin_agent - - if agent_id not in self._agents: - raise ValueError(f"Agent {agent_id} not found") - - return self._agents[agent_id] - - @core.callback - def async_get_agent_info(self) -> list[AgentInfo]: - """List all agents.""" - agents: list[AgentInfo] = [ - AgentInfo( - id=HOME_ASSISTANT_AGENT, - name="Home Assistant", - ) - ] - for agent_id, agent in self._agents.items(): - config_entry = self.hass.config_entries.async_get_entry(agent_id) - - # Guard against potential bugs in conversation agents where the agent is not - # removed from the manager when the config entry is removed - if config_entry is None: - _LOGGER.warning( - "Conversation agent %s is still loaded after config entry removal", - agent, - ) - continue - - agents.append( - AgentInfo( - id=agent_id, - name=config_entry.title or config_entry.domain, - ) - ) - return agents - - @core.callback - def async_is_valid_agent_id(self, agent_id: str) -> bool: - """Check if the agent id is valid.""" - return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT - - @core.callback - def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: - """Set the agent.""" - self._agents[agent_id] = agent - - @core.callback - def async_unset_agent(self, agent_id: str) -> None: - """Unset the agent.""" - self._agents.pop(agent_id, None) diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py new file mode 100644 index 00000000000..f34ecfaecc9 --- /dev/null +++ b/homeassistant/components/conversation/agent_manager.py @@ -0,0 +1,161 @@ +"""Agent foundation for conversation integration.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +import logging +from typing import Any + +import voluptuous as vol + +from homeassistant.core import Context, HomeAssistant, async_get_hass, callback +from homeassistant.helpers import config_validation as cv, singleton + +from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT +from .default_agent import DefaultAgent, async_setup as async_setup_default_agent +from .models import AbstractConversationAgent, ConversationInput, ConversationResult + +_LOGGER = logging.getLogger(__name__) + + +@singleton.singleton("conversation_agent") +@callback +def get_agent_manager(hass: HomeAssistant) -> AgentManager: + """Get the active agent.""" + manager = AgentManager(hass) + manager.async_setup() + return manager + + +def agent_id_validator(value: Any) -> str: + """Validate agent ID.""" + hass = async_get_hass() + manager = get_agent_manager(hass) + if not manager.async_is_valid_agent_id(cv.string(value)): + raise vol.Invalid("invalid agent ID") + return value + + +async def async_converse( + hass: HomeAssistant, + text: str, + conversation_id: str | None, + context: Context, + language: str | None = None, + agent_id: str | None = None, + device_id: str | None = None, +) -> ConversationResult: + """Process text and get intent.""" + agent = await get_agent_manager(hass).async_get_agent(agent_id) + + if language is None: + language = hass.config.language + + _LOGGER.debug("Processing in %s: %s", language, text) + result = await agent.async_process( + ConversationInput( + text=text, + context=context, + conversation_id=conversation_id, + device_id=device_id, + language=language, + ) + ) + return result + + +@dataclass(frozen=True) +class AgentInfo: + """Container for conversation agent info.""" + + id: str + name: str + + +class AgentManager: + """Class to manage conversation agents.""" + + default_agent: str = HOME_ASSISTANT_AGENT + _builtin_agent: AbstractConversationAgent | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the conversation agents.""" + self.hass = hass + self._agents: dict[str, AbstractConversationAgent] = {} + self._builtin_agent_init_lock = asyncio.Lock() + + def async_setup(self) -> None: + """Set up the conversation agents.""" + async_setup_default_agent(self.hass) + + async def async_get_agent( + self, agent_id: str | None = None + ) -> AbstractConversationAgent: + """Get the agent.""" + if agent_id is None: + agent_id = self.default_agent + + if agent_id == HOME_ASSISTANT_AGENT: + if self._builtin_agent is not None: + return self._builtin_agent + + async with self._builtin_agent_init_lock: + if self._builtin_agent is not None: + return self._builtin_agent + + self._builtin_agent = DefaultAgent(self.hass) + await self._builtin_agent.async_initialize( + self.hass.data.get(DATA_CONFIG) + ) + + return self._builtin_agent + + if agent_id not in self._agents: + raise ValueError(f"Agent {agent_id} not found") + + return self._agents[agent_id] + + @callback + def async_get_agent_info(self) -> list[AgentInfo]: + """List all agents.""" + agents: list[AgentInfo] = [ + AgentInfo( + id=HOME_ASSISTANT_AGENT, + name="Home Assistant", + ) + ] + for agent_id, agent in self._agents.items(): + config_entry = self.hass.config_entries.async_get_entry(agent_id) + + # Guard against potential bugs in conversation agents where the agent is not + # removed from the manager when the config entry is removed + if config_entry is None: + _LOGGER.warning( + "Conversation agent %s is still loaded after config entry removal", + agent, + ) + continue + + agents.append( + AgentInfo( + id=agent_id, + name=config_entry.title or config_entry.domain, + ) + ) + return agents + + @callback + def async_is_valid_agent_id(self, agent_id: str) -> bool: + """Check if the agent id is valid.""" + return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT + + @callback + def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: + """Set the agent.""" + self._agents[agent_id] = agent + + @callback + def async_unset_agent(self, agent_id: str) -> None: + """Unset the agent.""" + self._agents.pop(agent_id, None) diff --git a/homeassistant/components/conversation/const.py b/homeassistant/components/conversation/const.py index a8828fcc0e9..5cb5ca3bdea 100644 --- a/homeassistant/components/conversation/const.py +++ b/homeassistant/components/conversation/const.py @@ -3,3 +3,4 @@ DOMAIN = "conversation" DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"} HOME_ASSISTANT_AGENT = "homeassistant" +DATA_CONFIG = "conversation_config" diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 29a06d44c5f..5a8d7b64eec 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -46,8 +46,8 @@ from homeassistant.helpers.event import ( ) from homeassistant.util.json import JsonObjectType, json_loads_object -from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN +from .models import AbstractConversationAgent, ConversationInput, ConversationResult _LOGGER = logging.getLogger(__name__) _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" diff --git a/homeassistant/components/conversation/http.py b/homeassistant/components/conversation/http.py new file mode 100644 index 00000000000..fb67d686b23 --- /dev/null +++ b/homeassistant/components/conversation/http.py @@ -0,0 +1,325 @@ +"""HTTP endpoints for conversation integration.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from aiohttp import web +from hassil.recognize import ( + MISSING_ENTITY, + RecognizeResult, + UnmatchedRangeEntity, + UnmatchedTextEntity, +) +import voluptuous as vol + +from homeassistant.components import http, websocket_api +from homeassistant.components.http.data_validator import RequestDataValidator +from homeassistant.const import MATCH_ALL +from homeassistant.core import HomeAssistant, State, callback +from homeassistant.helpers import config_validation as cv, intent +from homeassistant.util import language as language_util + +from .agent_manager import agent_id_validator, async_converse, get_agent_manager +from .const import HOME_ASSISTANT_AGENT +from .default_agent import ( + METADATA_CUSTOM_FILE, + METADATA_CUSTOM_SENTENCE, + DefaultAgent, + SentenceTriggerResult, +) +from .models import ConversationInput + + +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up the HTTP API for the conversation integration.""" + hass.http.register_view(ConversationProcessView()) + websocket_api.async_register_command(hass, websocket_process) + websocket_api.async_register_command(hass, websocket_prepare) + websocket_api.async_register_command(hass, websocket_list_agents) + websocket_api.async_register_command(hass, websocket_hass_agent_debug) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/process", + vol.Required("text"): str, + vol.Optional("conversation_id"): vol.Any(str, None), + vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, + } +) +@websocket_api.async_response +async def websocket_process( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Process text.""" + result = await async_converse( + hass=hass, + text=msg["text"], + conversation_id=msg.get("conversation_id"), + context=connection.context(msg), + language=msg.get("language"), + agent_id=msg.get("agent_id"), + ) + connection.send_result(msg["id"], result.as_dict()) + + +@websocket_api.websocket_command( + { + "type": "conversation/prepare", + vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, + } +) +@websocket_api.async_response +async def websocket_prepare( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Reload intents.""" + manager = get_agent_manager(hass) + agent = await manager.async_get_agent(msg.get("agent_id")) + await agent.async_prepare(msg.get("language")) + connection.send_result(msg["id"]) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/agent/list", + vol.Optional("language"): str, + vol.Optional("country"): str, + } +) +@websocket_api.async_response +async def websocket_list_agents( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """List conversation agents and, optionally, if they support a given language.""" + manager = get_agent_manager(hass) + + country = msg.get("country") + language = msg.get("language") + agents = [] + + for agent_info in manager.async_get_agent_info(): + agent = await manager.async_get_agent(agent_info.id) + + supported_languages = agent.supported_languages + if language and supported_languages != MATCH_ALL: + supported_languages = language_util.matches( + language, supported_languages, country + ) + + agent_dict: dict[str, Any] = { + "id": agent_info.id, + "name": agent_info.name, + "supported_languages": supported_languages, + } + agents.append(agent_dict) + + connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents})) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/agent/homeassistant/debug", + vol.Required("sentences"): [str], + vol.Optional("language"): str, + vol.Optional("device_id"): vol.Any(str, None), + } +) +@websocket_api.async_response +async def websocket_hass_agent_debug( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Return intents that would be matched by the default agent for a list of sentences.""" + agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) + assert isinstance(agent, DefaultAgent) + results = [ + await agent.async_recognize( + ConversationInput( + text=sentence, + context=connection.context(msg), + conversation_id=None, + device_id=msg.get("device_id"), + language=msg.get("language", hass.config.language), + ) + ) + for sentence in msg["sentences"] + ] + + # Return results for each sentence in the same order as the input. + result_dicts: list[dict[str, Any] | None] = [] + for result in results: + result_dict: dict[str, Any] | None = None + if isinstance(result, SentenceTriggerResult): + result_dict = { + # Matched a user-defined sentence trigger. + # We can't provide the response here without executing the + # trigger. + "match": True, + "source": "trigger", + "sentence_template": result.sentence_template or "", + } + elif isinstance(result, RecognizeResult): + successful_match = not result.unmatched_entities + result_dict = { + # Name of the matching intent (or the closest) + "intent": { + "name": result.intent.name, + }, + # Slot values that would be received by the intent + "slots": { # direct access to values + entity_key: entity.text or entity.value + for entity_key, entity in result.entities.items() + }, + # Extra slot details, such as the originally matched text + "details": { + entity_key: { + "name": entity.name, + "value": entity.value, + "text": entity.text, + } + for entity_key, entity in result.entities.items() + }, + # Entities/areas/etc. that would be targeted + "targets": {}, + # True if match was successful + "match": successful_match, + # Text of the sentence template that matched (or was closest) + "sentence_template": "", + # When match is incomplete, this will contain the best slot guesses + "unmatched_slots": _get_unmatched_slots(result), + } + + if successful_match: + result_dict["targets"] = { + state.entity_id: {"matched": is_matched} + for state, is_matched in _get_debug_targets(hass, result) + } + + if result.intent_sentence is not None: + result_dict["sentence_template"] = result.intent_sentence.text + + # Inspect metadata to determine if this matched a custom sentence + if result.intent_metadata and result.intent_metadata.get( + METADATA_CUSTOM_SENTENCE + ): + result_dict["source"] = "custom" + result_dict["file"] = result.intent_metadata.get(METADATA_CUSTOM_FILE) + else: + result_dict["source"] = "builtin" + + result_dicts.append(result_dict) + + connection.send_result(msg["id"], {"results": result_dicts}) + + +def _get_debug_targets( + hass: HomeAssistant, + result: RecognizeResult, +) -> Iterable[tuple[State, bool]]: + """Yield state/is_matched pairs for a hassil recognition.""" + entities = result.entities + + name: str | None = None + area_name: str | None = None + domains: set[str] | None = None + device_classes: set[str] | None = None + state_names: set[str] | None = None + + if "name" in entities: + name = str(entities["name"].value) + + if "area" in entities: + area_name = str(entities["area"].value) + + if "domain" in entities: + domains = set(cv.ensure_list(entities["domain"].value)) + + if "device_class" in entities: + device_classes = set(cv.ensure_list(entities["device_class"].value)) + + if "state" in entities: + # HassGetState only + state_names = set(cv.ensure_list(entities["state"].value)) + + if ( + (name is None) + and (area_name is None) + and (not domains) + and (not device_classes) + and (not state_names) + ): + # Avoid "matching" all entities when there is no filter + return + + states = intent.async_match_states( + hass, + name=name, + area_name=area_name, + domains=domains, + device_classes=device_classes, + ) + + for state in states: + # For queries, a target is "matched" based on its state + is_matched = (state_names is None) or (state.state in state_names) + yield state, is_matched + + +def _get_unmatched_slots( + result: RecognizeResult, +) -> dict[str, str | int]: + """Return a dict of unmatched text/range slot entities.""" + unmatched_slots: dict[str, str | int] = {} + for entity in result.unmatched_entities_list: + if isinstance(entity, UnmatchedTextEntity): + if entity.text == MISSING_ENTITY: + # Don't report since these are just missing context + # slots. + continue + + unmatched_slots[entity.name] = entity.text + elif isinstance(entity, UnmatchedRangeEntity): + unmatched_slots[entity.name] = entity.value + + return unmatched_slots + + +class ConversationProcessView(http.HomeAssistantView): + """View to process text.""" + + url = "/api/conversation/process" + name = "api:conversation:process" + + @RequestDataValidator( + vol.Schema( + { + vol.Required("text"): str, + vol.Optional("conversation_id"): str, + vol.Optional("language"): str, + vol.Optional("agent_id"): agent_id_validator, + } + ) + ) + async def post(self, request: web.Request, data: dict[str, str]) -> web.Response: + """Send a request for processing.""" + hass = request.app[http.KEY_HASS] + + result = await async_converse( + hass, + text=data["text"], + conversation_id=data.get("conversation_id"), + context=self.context(request), + language=data.get("language"), + agent_id=data.get("agent_id"), + ) + + return self.json(result.as_dict()) diff --git a/homeassistant/components/conversation/agent.py b/homeassistant/components/conversation/models.py similarity index 100% rename from homeassistant/components/conversation/agent.py rename to homeassistant/components/conversation/models.py diff --git a/homeassistant/components/conversation/trigger.py b/homeassistant/components/conversation/trigger.py index 0fadc458352..05fea054bca 100644 --- a/homeassistant/components/conversation/trigger.py +++ b/homeassistant/components/conversation/trigger.py @@ -14,8 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import UNDEFINED, ConfigType -from . import HOME_ASSISTANT_AGENT, _get_agent_manager -from .const import DOMAIN +from .agent_manager import get_agent_manager +from .const import DOMAIN, HOME_ASSISTANT_AGENT from .default_agent import DefaultAgent @@ -111,7 +111,7 @@ async def async_attach_trigger( # two trigger copies for who will provide a response. return None - default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) + default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) assert isinstance(default_agent, DefaultAgent) return default_agent.register_trigger(sentences, call_action) diff --git a/tests/components/cast/test_media_player.py b/tests/components/cast/test_media_player.py index 8381f27398a..d75aebe4ded 100644 --- a/tests/components/cast/test_media_player.py +++ b/tests/components/cast/test_media_player.py @@ -453,11 +453,13 @@ async def test_stop_discovery_called_on_stop( """Test pychromecast.stop_discovery called on shutdown.""" # start_discovery should be called with empty config await async_setup_cast(hass, {}) + await hass.async_block_till_done() assert castbrowser_mock.return_value.start_discovery.call_count == 1 # stop discovery should be called on shutdown hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() + await hass.async_block_till_done() assert castbrowser_mock.return_value.stop_discovery.call_count == 1 diff --git a/tests/components/conversation/__init__.py b/tests/components/conversation/__init__.py index 7209148e21f..fb9bcab7498 100644 --- a/tests/components/conversation/__init__.py +++ b/tests/components/conversation/__init__.py @@ -5,11 +5,16 @@ from __future__ import annotations from typing import Literal from homeassistant.components import conversation +from homeassistant.components.conversation.models import ( + ConversationInput, + ConversationResult, +) from homeassistant.components.homeassistant.exposed_entities import ( DATA_EXPOSED_ENTITIES, ExposedEntities, async_expose_entity, ) +from homeassistant.core import HomeAssistant from homeassistant.helpers import intent @@ -30,24 +35,22 @@ class MockAgent(conversation.AbstractConversationAgent): """Return a list of supported languages.""" return self._supported_languages - async def async_process( - self, user_input: conversation.ConversationInput - ) -> conversation.ConversationResult: + async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process some text.""" self.calls.append(user_input) response = intent.IntentResponse(language=user_input.language) response.async_set_speech(self.response) - return conversation.ConversationResult( + return ConversationResult( response=response, conversation_id=user_input.conversation_id ) -def expose_new(hass, expose_new): +def expose_new(hass: HomeAssistant, expose_new: bool): """Enable exposing new entities to the default agent.""" exposed_entities: ExposedEntities = hass.data[DATA_EXPOSED_ENTITIES] exposed_entities.async_set_expose_new_entities(conversation.DOMAIN, expose_new) -def expose_entity(hass, entity_id, should_expose): +def expose_entity(hass: HomeAssistant, entity_id: str, should_expose: bool): """Expose an entity to the default agent.""" async_expose_entity(hass, conversation.DOMAIN, entity_id, should_expose) diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 8f38459a8da..c600c71711e 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -7,6 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult import pytest from homeassistant.components import conversation +from homeassistant.components.conversation import agent_manager, default_agent from homeassistant.components.homeassistant.exposed_entities import ( async_get_assistant_settings, ) @@ -151,7 +152,7 @@ async def test_conversation_agent( init_components, ) -> None: """Test DefaultAgent.""" - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await agent_manager.get_agent_manager(hass).async_get_agent( conversation.HOME_ASSISTANT_AGENT ) with patch( @@ -253,10 +254,10 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None: trigger_sentences = ["It's party time", "It is time to party"] trigger_response = "Cowabunga!" - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await agent_manager.get_agent_manager(hass).async_get_agent( conversation.HOME_ASSISTANT_AGENT ) - assert isinstance(agent, conversation.DefaultAgent) + assert isinstance(agent, default_agent.DefaultAgent) callback = AsyncMock(return_value=trigger_response) unregister = agent.register_trigger(trigger_sentences, callback) @@ -850,7 +851,7 @@ async def test_empty_aliases( ) with patch( - "homeassistant.components.conversation.DefaultAgent._recognize", + "homeassistant.components.conversation.default_agent.DefaultAgent._recognize", return_value=None, ) as mock_recognize_all: await conversation.async_converse( diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 7b2c44a755d..62f67548ece 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -9,6 +9,8 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation +from homeassistant.components.conversation import agent_manager, default_agent +from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.const import ATTR_FRIENDLY_NAME @@ -750,8 +752,8 @@ async def test_ws_prepare( """Test the Websocket prepare conversation API.""" assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) - agent = await conversation._get_agent_manager(hass).async_get_agent() - assert isinstance(agent, conversation.DefaultAgent) + agent = await agent_manager.get_agent_manager(hass).async_get_agent() + assert isinstance(agent, default_agent.DefaultAgent) # No intents should be loaded yet assert not agent._lang_intents.get(hass.config.language) @@ -852,8 +854,8 @@ async def test_prepare_reload(hass: HomeAssistant) -> None: assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await conversation._get_agent_manager(hass).async_get_agent() - assert isinstance(agent, conversation.DefaultAgent) + agent = await agent_manager.get_agent_manager(hass).async_get_agent() + assert isinstance(agent, default_agent.DefaultAgent) await agent.async_prepare(language) # Confirm intents are loaded @@ -880,8 +882,8 @@ async def test_prepare_fail(hass: HomeAssistant) -> None: assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await conversation._get_agent_manager(hass).async_get_agent() - assert isinstance(agent, conversation.DefaultAgent) + agent = await agent_manager.get_agent_manager(hass).async_get_agent() + assert isinstance(agent, default_agent.DefaultAgent) await agent.async_prepare("not-a-language") # Confirm no intents were loaded @@ -917,11 +919,11 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non hass.states.async_set("cover.front_door", "closed") calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER) - agent = await conversation._get_agent_manager(hass).async_get_agent() - assert isinstance(agent, conversation.DefaultAgent) + agent = await agent_manager.get_agent_manager(hass).async_get_agent() + assert isinstance(agent, default_agent.DefaultAgent) result = await agent.async_process( - conversation.ConversationInput( + ConversationInput( text="open the front door", context=Context(), conversation_id=None, diff --git a/tests/components/conversation/test_trigger.py b/tests/components/conversation/test_trigger.py index 221789b49e0..33ad8efdd2e 100644 --- a/tests/components/conversation/test_trigger.py +++ b/tests/components/conversation/test_trigger.py @@ -5,7 +5,8 @@ import logging import pytest import voluptuous as vol -from homeassistant.components import conversation +from homeassistant.components.conversation import agent_manager, default_agent +from homeassistant.components.conversation.models import ConversationInput from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import trigger from homeassistant.setup import async_setup_component @@ -514,11 +515,11 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None: }, ) - agent = await conversation._get_agent_manager(hass).async_get_agent() - assert isinstance(agent, conversation.DefaultAgent) + agent = await agent_manager.get_agent_manager(hass).async_get_agent() + assert isinstance(agent, default_agent.DefaultAgent) result = await agent.async_process( - conversation.ConversationInput( + ConversationInput( text="test sentence", context=Context(), conversation_id=None, diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index 2d930599c24..7c2fc8291d4 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -334,7 +334,7 @@ async def test_conversation_agent( entry = entries[0] assert entry.state is ConfigEntryState.LOADED - agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id) + agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES text1 = "tell me a joke" diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index b77fa14b4cf..92e84b1fd39 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -152,7 +152,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test GoogleGenerativeAIAgent.""" - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" diff --git a/tests/components/mobile_app/test_webhook.py b/tests/components/mobile_app/test_webhook.py index dfab474f127..9d941685c09 100644 --- a/tests/components/mobile_app/test_webhook.py +++ b/tests/components/mobile_app/test_webhook.py @@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process( webhook_client.server.app.router._frozen = False with patch( - "homeassistant.components.conversation.AgentManager.async_get_agent", + "homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent", return_value=mock_conversation_agent, ): resp = await webhook_client.post( diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index ffe69ca4628..6dd9dc73973 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -229,7 +229,7 @@ async def test_message_history_pruning( assert isinstance(result.conversation_id, str) conversation_ids.append(result.conversation_id) - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert isinstance(agent, ollama.OllamaAgent) @@ -284,7 +284,7 @@ async def test_message_history_unlimited( result.response.response_type == intent.IntentResponseType.ACTION_DONE ), result - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert isinstance(agent, ollama.OllamaAgent) @@ -340,7 +340,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test OllamaAgent.""" - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == MATCH_ALL diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 3a8db2a71c0..c94fdcebcde 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -194,7 +194,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test OpenAIAgent.""" - agent = await conversation._get_agent_manager(hass).async_get_agent( + agent = await conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*"