Add conversation entity (#114518)

* Default agent as entity

* Migrate constant to point at new value

* Fix tests

* Fix more tests

* Move assist pipeline back to cloud after dependenceis
This commit is contained in:
Paulus Schoutsen 2024-04-01 21:34:25 -04:00 committed by GitHub
parent b1af590eed
commit d2e4f5f36e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 566 additions and 177 deletions

View File

@ -376,6 +376,10 @@ class Pipeline:
This function was added in HA Core 2023.10, previous versions will raise This function was added in HA Core 2023.10, previous versions will raise
if there are unexpected items in the serialized data. if there are unexpected items in the serialized data.
""" """
# Migrate to new value for conversation agent
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
return cls( return cls(
conversation_engine=data["conversation_engine"], conversation_engine=data["conversation_engine"],
conversation_language=data["conversation_language"], conversation_language=data["conversation_language"],

View File

@ -223,7 +223,10 @@ class CloudLoginView(HomeAssistantView):
cloud: Cloud[CloudClient] = hass.data[DOMAIN] cloud: Cloud[CloudClient] = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"]) await cloud.login(data["email"], data["password"])
if "assist_pipeline" in hass.config.components:
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass) new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
else:
new_cloud_pipeline_id = None
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id}) return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})

View File

@ -24,6 +24,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient from .client import CloudClient
@ -86,10 +87,19 @@ class CloudProviderEntity(SpeechToTextEntity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass.""" """Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_engine(
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
"""When assist_pipeline is set up."""
assert self.platform.config_entry
self.platform.config_entry.async_create_task(
hass,
async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.STT, engine_id=self.entity_id self.hass, platform=Platform.STT, engine_id=self.entity_id
),
) )
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
async def async_process_audio_stream( async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult: ) -> SpeechResult:

View File

@ -27,6 +27,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient from .client import CloudClient
@ -156,9 +157,19 @@ class CloudTTSEntity(TextToSpeechEntity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Handle entity which will be added.""" """Handle entity which will be added."""
await super().async_added_to_hass() await super().async_added_to_hass()
await async_migrate_cloud_pipeline_engine(
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
"""When assist_pipeline is set up."""
assert self.platform.config_entry
self.platform.config_entry.async_create_task(
hass,
async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.TTS, engine_id=self.entity_id self.hass, platform=Platform.TTS, engine_id=self.entity_id
),
) )
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
self.async_on_remove( self.async_on_remove(
self.cloud.client.prefs.async_listen_updates(self._sync_prefs) self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
) )

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable
import logging import logging
import re import re
from typing import Literal from typing import Literal
@ -20,6 +19,7 @@ from homeassistant.core import (
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -27,15 +27,19 @@ from .agent_manager import (
AgentInfo, AgentInfo,
agent_id_validator, agent_id_validator,
async_converse, async_converse,
async_get_agent,
get_agent_manager, get_agent_manager,
) )
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT from .const import HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
from .default_agent import async_get_default_agent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"HOME_ASSISTANT_AGENT", "HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",
"async_set_agent", "async_set_agent",
@ -122,16 +126,26 @@ async def async_get_conversation_languages(
all conversation agents. all conversation agents.
""" """
agent_manager = get_agent_manager(hass) agent_manager = get_agent_manager(hass)
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
languages: set[str] = set() languages: set[str] = set()
agents: list[ConversationEntity | AbstractConversationAgent]
if agent_id:
agent = async_get_agent(hass, agent_id)
if agent is None:
raise ValueError(f"Agent {agent_id} not found")
agents = [agent]
agent_ids: Iterable[str]
if agent_id is None:
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
else: else:
agent_ids = (agent_id,) agents = list(entity_component.entities)
for info in agent_manager.async_get_agent_info():
agent = agent_manager.async_get_agent(info.id)
assert agent is not None
agents.append(agent)
for _agent_id in agent_ids: for agent in agents:
agent = await agent_manager.async_get_agent(_agent_id)
if agent.supported_languages == MATCH_ALL: if agent.supported_languages == MATCH_ALL:
return MATCH_ALL return MATCH_ALL
for language_tag in agent.supported_languages: for language_tag in agent.supported_languages:
@ -146,10 +160,18 @@ def async_get_agent_info(
agent_id: str | None = None, agent_id: str | None = None,
) -> AgentInfo | None: ) -> AgentInfo | None:
"""Get information on the agent or None if not found.""" """Get information on the agent or None if not found."""
manager = get_agent_manager(hass) agent = async_get_agent(hass, agent_id)
if agent_id is None: if agent is None:
agent_id = manager.default_agent return None
if isinstance(agent, ConversationEntity):
name = agent.name
if not isinstance(name, str):
name = agent.entity_id
return AgentInfo(id=agent.entity_id, name=name)
manager = get_agent_manager(hass)
for agent_info in manager.async_get_agent_info(): for agent_info in manager.async_get_agent_info():
if agent_info.id == agent_id: if agent_info.id == agent_id:
@ -160,10 +182,11 @@ def async_get_agent_info(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
agent_manager = get_agent_manager(hass) entity_component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
if config_intents := config.get(DOMAIN, {}).get("intents"): await async_setup_default_agent(
hass.data[DATA_CONFIG] = config_intents hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
)
async def handle_process(service: ServiceCall) -> ServiceResponse: async def handle_process(service: ServiceCall) -> ServiceResponse:
"""Parse text into commands.""" """Parse text into commands."""
@ -188,7 +211,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def handle_reload(service: ServiceCall) -> None: async def handle_reload(service: ServiceCall) -> None:
"""Reload intents.""" """Reload intents."""
agent = await agent_manager.async_get_agent() agent = async_get_default_agent(hass)
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
hass.services.async_register( hass.services.async_register(

View File

@ -2,8 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from dataclasses import dataclass
import logging import logging
from typing import Any from typing import Any
@ -11,10 +9,17 @@ import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton from homeassistant.helpers import config_validation as cv, singleton
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT from .const import DOMAIN, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent from .default_agent import async_get_default_agent
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .entity import ConversationEntity
from .models import (
AbstractConversationAgent,
AgentInfo,
ConversationInput,
ConversationResult,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -23,20 +28,37 @@ _LOGGER = logging.getLogger(__name__)
@callback @callback
def get_agent_manager(hass: HomeAssistant) -> AgentManager: def get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent.""" """Get the active agent."""
manager = AgentManager(hass) return AgentManager(hass)
manager.async_setup()
return manager
def agent_id_validator(value: Any) -> str: def agent_id_validator(value: Any) -> str:
"""Validate agent ID.""" """Validate agent ID."""
hass = async_get_hass() hass = async_get_hass()
manager = get_agent_manager(hass) if async_get_agent(hass, cv.string(value)) is None:
if not manager.async_is_valid_agent_id(cv.string(value)):
raise vol.Invalid("invalid agent ID") raise vol.Invalid("invalid agent ID")
return value return value
@callback
def async_get_agent(
hass: HomeAssistant, agent_id: str | None = None
) -> AbstractConversationAgent | ConversationEntity | None:
"""Get specified agent."""
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
return async_get_default_agent(hass)
if "." in agent_id:
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return entity_component.get_entity(agent_id)
manager = get_agent_manager(hass)
if not manager.async_is_valid_agent_id(agent_id):
return None
return manager.async_get_agent(agent_id)
async def async_converse( async def async_converse(
hass: HomeAssistant, hass: HomeAssistant,
text: str, text: str,
@ -47,13 +69,22 @@ async def async_converse(
device_id: str | None = None, device_id: str | None = None,
) -> ConversationResult: ) -> ConversationResult:
"""Process text and get intent.""" """Process text and get intent."""
agent = await get_agent_manager(hass).async_get_agent(agent_id) agent = async_get_agent(hass, agent_id)
if agent is None:
raise ValueError(f"Agent {agent_id} not found")
if isinstance(agent, ConversationEntity):
agent.async_set_context(context)
method = agent.internal_async_process
else:
method = agent.async_process
if language is None: if language is None:
language = hass.config.language language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text) _LOGGER.debug("Processing in %s: %s", language, text)
result = await agent.async_process( result = await method(
ConversationInput( ConversationInput(
text=text, text=text,
context=context, context=context,
@ -65,52 +96,17 @@ async def async_converse(
return result return result
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
class AgentManager: class AgentManager:
"""Class to manage conversation agents.""" """Class to manage conversation agents."""
default_agent: str = HOME_ASSISTANT_AGENT
_builtin_agent: AbstractConversationAgent | None = None
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the conversation agents.""" """Initialize the conversation agents."""
self.hass = hass self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {} self._agents: dict[str, AbstractConversationAgent] = {}
self._builtin_agent_init_lock = asyncio.Lock()
def async_setup(self) -> None: @callback
"""Set up the conversation agents.""" def async_get_agent(self, agent_id: str) -> AbstractConversationAgent | None:
async_setup_default_agent(self.hass)
async def async_get_agent(
self, agent_id: str | None = None
) -> AbstractConversationAgent:
"""Get the agent.""" """Get the agent."""
if agent_id is None:
agent_id = self.default_agent
if agent_id == HOME_ASSISTANT_AGENT:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._builtin_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent
self._builtin_agent = DefaultAgent(self.hass)
await self._builtin_agent.async_initialize(
self.hass.data.get(DATA_CONFIG)
)
return self._builtin_agent
if agent_id not in self._agents: if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found") raise ValueError(f"Agent {agent_id} not found")
@ -119,12 +115,7 @@ class AgentManager:
@callback @callback
def async_get_agent_info(self) -> list[AgentInfo]: def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents.""" """List all agents."""
agents: list[AgentInfo] = [ agents: list[AgentInfo] = []
AgentInfo(
id=HOME_ASSISTANT_AGENT,
name="Home Assistant",
)
]
for agent_id, agent in self._agents.items(): for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id) config_entry = self.hass.config_entries.async_get_entry(agent_id)
@ -148,7 +139,7 @@ class AgentManager:
@callback @callback
def async_is_valid_agent_id(self, agent_id: str) -> bool: def async_is_valid_agent_id(self, agent_id: str) -> bool:
"""Check if the agent id is valid.""" """Check if the agent id is valid."""
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT return agent_id in self._agents
@callback @callback
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:

View File

@ -2,5 +2,5 @@
DOMAIN = "conversation" DOMAIN = "conversation"
DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"} DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
HOME_ASSISTANT_AGENT = "homeassistant" HOME_ASSISTANT_AGENT = "conversation.home_assistant"
DATA_CONFIG = "conversation_config" OLD_HOME_ASSISTANT_AGENT = "homeassistant"

View File

@ -24,7 +24,7 @@ from hassil.util import merge_dict
from home_assistant_intents import ErrorKey, get_intents, get_languages from home_assistant_intents import ErrorKey, get_intents, get_languages
import yaml import yaml
from homeassistant import core, setup from homeassistant import core
from homeassistant.components.homeassistant.exposed_entities import ( from homeassistant.components.homeassistant.exposed_entities import (
async_listen_entity_updates, async_listen_entity_updates,
async_should_expose, async_should_expose,
@ -40,6 +40,7 @@ from homeassistant.helpers import (
template, template,
translation, translation,
) )
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
EventStateChangedData, EventStateChangedData,
async_track_state_added_domain, async_track_state_added_domain,
@ -47,7 +48,8 @@ from homeassistant.helpers.event import (
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@ -60,6 +62,14 @@ TRIGGER_CALLBACK_TYPE = Callable[
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence" METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file" METADATA_CUSTOM_FILE = "hass_custom_file"
DATA_DEFAULT_ENTITY = "conversation_default_entity"
@core.callback
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
"""Get the default agent."""
return hass.data[DATA_DEFAULT_ENTITY]
def json_load(fp: IO[str]) -> JsonObjectType: def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents.""" """Wrap json_loads for get_intents."""
@ -109,9 +119,16 @@ def _get_language_variations(language: str) -> Iterable[str]:
yield lang yield lang
@core.callback async def async_setup_default_agent(
def async_setup(hass: core.HomeAssistant) -> None: hass: core.HomeAssistant,
entity_component: EntityComponent[ConversationEntity],
config_intents: dict[str, Any],
) -> None:
"""Set up entity registry listener for the default agent.""" """Set up entity registry listener for the default agent."""
entity = DefaultAgent(hass, config_intents)
await entity_component.async_add_entities([entity])
hass.data[DATA_DEFAULT_ENTITY] = entity
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
for entity_id in entity_registry.entities: for entity_id in entity_registry.entities:
async_should_expose(hass, DOMAIN, entity_id) async_should_expose(hass, DOMAIN, entity_id)
@ -131,17 +148,21 @@ def async_setup(hass: core.HomeAssistant) -> None:
start.async_at_started(hass, async_hass_started) start.async_at_started(hass, async_hass_started)
class DefaultAgent(AbstractConversationAgent): class DefaultAgent(ConversationEntity):
"""Default agent for conversation agent.""" """Default agent for conversation agent."""
def __init__(self, hass: core.HomeAssistant) -> None: _attr_name = "Home Assistant"
def __init__(
self, hass: core.HomeAssistant, config_intents: dict[str, Any]
) -> None:
"""Initialize the default agent.""" """Initialize the default agent."""
self.hass = hass self.hass = hass
self._lang_intents: dict[str, LanguageIntents] = {} self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# intent -> [sentences] # intent -> [sentences]
self._config_intents: dict[str, Any] = {} self._config_intents: dict[str, Any] = config_intents
self._slot_lists: dict[str, SlotList] | None = None self._slot_lists: dict[str, SlotList] | None = None
# Sentences that will trigger a callback (skipping intent recognition) # Sentences that will trigger a callback (skipping intent recognition)
@ -154,15 +175,6 @@ class DefaultAgent(AbstractConversationAgent):
"""Return a list of supported languages.""" """Return a list of supported languages."""
return get_languages() return get_languages()
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
"""Initialize the default agent."""
if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {})
# Intents from config may only contains sentences for HA config's language
if config_intents:
self._config_intents = config_intents
@core.callback @core.callback
def _filter_entity_registry_changes(self, event_data: dict[str, Any]) -> bool: def _filter_entity_registry_changes(self, event_data: dict[str, Any]) -> bool:
"""Filter entity registry changed events.""" """Filter entity registry changed events."""

View File

@ -0,0 +1,57 @@
"""Entity for conversation integration."""
from abc import abstractmethod
from typing import Literal, final
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .models import ConversationInput, ConversationResult
class ConversationEntity(RestoreEntity):
"""Entity that supports conversations."""
_attr_should_poll = False
__last_activity: str | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_activity is None:
return None
return self.__last_activity
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_activity = state.state
@final
async def internal_async_process(
self, user_input: ConversationInput
) -> ConversationResult:
"""Process a sentence."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_process(user_input)
@property
@abstractmethod
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
@abstractmethod
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language."""

View File

@ -19,16 +19,24 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import language as language_util from homeassistant.util import language as language_util
from .agent_manager import agent_id_validator, async_converse, get_agent_manager from .agent_manager import (
from .const import HOME_ASSISTANT_AGENT agent_id_validator,
async_converse,
async_get_agent,
get_agent_manager,
)
from .const import DOMAIN
from .default_agent import ( from .default_agent import (
METADATA_CUSTOM_FILE, METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE, METADATA_CUSTOM_SENTENCE,
DefaultAgent, DefaultAgent,
SentenceTriggerResult, SentenceTriggerResult,
async_get_default_agent,
) )
from .entity import ConversationEntity
from .models import ConversationInput from .models import ConversationInput
@ -83,8 +91,14 @@ async def websocket_prepare(
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Reload intents.""" """Reload intents."""
manager = get_agent_manager(hass) agent = async_get_agent(hass, msg.get("agent_id"))
agent = await manager.async_get_agent(msg.get("agent_id"))
if agent is None:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_FOUND, "Agent not found"
)
return
await agent.async_prepare(msg.get("language")) await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"]) connection.send_result(msg["id"])
@ -101,14 +115,32 @@ async def websocket_list_agents(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None: ) -> None:
"""List conversation agents and, optionally, if they support a given language.""" """List conversation agents and, optionally, if they support a given language."""
manager = get_agent_manager(hass) entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
country = msg.get("country") country = msg.get("country")
language = msg.get("language") language = msg.get("language")
agents = [] agents = []
for entity in entity_component.entities:
supported_languages = entity.supported_languages
if language and supported_languages != MATCH_ALL:
supported_languages = language_util.matches(
language, supported_languages, country
)
agents.append(
{
"id": entity.entity_id,
"name": entity.name or entity.entity_id,
"supported_languages": supported_languages,
}
)
manager = get_agent_manager(hass)
for agent_info in manager.async_get_agent_info(): for agent_info in manager.async_get_agent_info():
agent = await manager.async_get_agent(agent_info.id) agent = manager.async_get_agent(agent_info.id)
assert agent is not None
supported_languages = agent.supported_languages supported_languages = agent.supported_languages
if language and supported_languages != MATCH_ALL: if language and supported_languages != MATCH_ALL:
@ -139,7 +171,7 @@ async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None: ) -> None:
"""Return intents that would be matched by the default agent for a list of sentences.""" """Return intents that would be matched by the default agent for a list of sentences."""
agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) agent = async_get_default_agent(hass)
assert isinstance(agent, DefaultAgent) assert isinstance(agent, DefaultAgent)
results = [ results = [
await agent.async_recognize( await agent.async_recognize(

View File

@ -2,7 +2,7 @@
"domain": "conversation", "domain": "conversation",
"name": "Conversation", "name": "Conversation",
"codeowners": ["@home-assistant/core", "@synesthesiam"], "codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["http"], "dependencies": ["http", "intent"],
"documentation": "https://www.home-assistant.io/integrations/conversation", "documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "system", "integration_type": "system",
"iot_class": "local_push", "iot_class": "local_push",

View File

@ -10,6 +10,14 @@ from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
@dataclass(slots=True) @dataclass(slots=True)
class ConversationInput: class ConversationInput:
"""User input to be processed.""" """User input to be processed."""

View File

@ -14,9 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .agent_manager import get_agent_manager from .const import DOMAIN
from .const import DOMAIN, HOME_ASSISTANT_AGENT from .default_agent import DefaultAgent, async_get_default_agent
from .default_agent import DefaultAgent
def has_no_punctuation(value: list[str]) -> list[str]: def has_no_punctuation(value: list[str]) -> list[str]:
@ -111,7 +110,7 @@ async def async_attach_trigger(
# two trigger copies for who will provide a response. # two trigger copies for who will provide a response.
return None return None
default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) default_agent = async_get_default_agent(hass)
assert isinstance(default_agent, DefaultAgent) assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action) return default_agent.register_trigger(sentences, call_action)

View File

@ -34,7 +34,7 @@
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}), }),
@ -123,7 +123,7 @@
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en-US', 'language': 'en-US',
}), }),
@ -212,7 +212,7 @@
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en-US', 'language': 'en-US',
}), }),
@ -325,7 +325,7 @@
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}), }),

View File

@ -33,7 +33,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}) })
@ -114,7 +114,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}) })
@ -207,7 +207,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}) })
@ -409,7 +409,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
}) })
@ -615,7 +615,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',
}) })
@ -637,7 +637,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',
}) })
@ -665,7 +665,7 @@
dict({ dict({
'conversation_id': None, 'conversation_id': None,
'device_id': None, 'device_id': None,
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'never mind', 'intent_input': 'never mind',
'language': 'en', 'language': 'en',
}) })
@ -799,7 +799,7 @@
dict({ dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id', 'device_id': 'mock-device-id',
'engine': 'homeassistant', 'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',
}) })

View File

@ -6,6 +6,7 @@ from unittest.mock import ANY, patch
import pytest import pytest
from homeassistant.components import conversation
from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import ( from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY, STORAGE_KEY,
@ -117,6 +118,7 @@ async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any] hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None: ) -> None:
"""Test loading stored pipelines on start.""" """Test loading stored pipelines on start."""
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
hass_storage[STORAGE_KEY] = { hass_storage[STORAGE_KEY] = {
"version": STORAGE_VERSION, "version": STORAGE_VERSION,
"minor_version": STORAGE_VERSION_MINOR, "minor_version": STORAGE_VERSION_MINOR,
@ -124,9 +126,9 @@ async def test_loading_pipelines_from_storage(
"data": { "data": {
"items": [ "items": [
{ {
"conversation_engine": "conversation_engine_1", "conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT,
"conversation_language": "language_1", "conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", "id": id_1,
"language": "language_1", "language": "language_1",
"name": "name_1", "name": "name_1",
"stt_engine": "stt_engine_1", "stt_engine": "stt_engine_1",
@ -166,7 +168,7 @@ async def test_loading_pipelines_from_storage(
"wake_word_id": "wakeword_id_3", "wake_word_id": "wakeword_id_3",
}, },
], ],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", "preferred_item": id_1,
}, },
} }
@ -175,7 +177,8 @@ async def test_loading_pipelines_from_storage(
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store store = pipeline_data.pipeline_store
assert len(store.data) == 3 assert len(store.data) == 3
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" assert store.async_get_preferred_item() == id_1
assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT
async def test_migrate_pipeline_store( async def test_migrate_pipeline_store(
@ -262,7 +265,7 @@ async def test_create_default_pipeline(
tts_engine_id="test", tts_engine_id="test",
pipeline_name="Test pipeline", pipeline_name="Test pipeline",
) == Pipeline( ) == Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language="en", conversation_language="en",
id=ANY, id=ANY,
language="en", language="en",
@ -304,7 +307,7 @@ async def test_get_pipelines(hass: HomeAssistant) -> None:
pipelines = async_get_pipelines(hass) pipelines = async_get_pipelines(hass)
assert list(pipelines) == [ assert list(pipelines) == [
Pipeline( Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language="en", conversation_language="en",
id=ANY, id=ANY,
language="en", language="en",
@ -351,7 +354,7 @@ async def test_default_pipeline_no_stt_tts(
# Check the default pipeline # Check the default pipeline
pipeline = async_get_pipeline(hass, None) pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language=conv_language, conversation_language=conv_language,
id=pipeline.id, id=pipeline.id,
language=pipeline_language, language=pipeline_language,
@ -414,7 +417,7 @@ async def test_default_pipeline(
# Check the default pipeline # Check the default pipeline
pipeline = async_get_pipeline(hass, None) pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language=conv_language, conversation_language=conv_language,
id=pipeline.id, id=pipeline.id,
language=pipeline_language, language=pipeline_language,
@ -445,7 +448,7 @@ async def test_default_pipeline_unsupported_stt_language(
# Check the default pipeline # Check the default pipeline
pipeline = async_get_pipeline(hass, None) pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language="en", conversation_language="en",
id=pipeline.id, id=pipeline.id,
language="en", language="en",
@ -476,7 +479,7 @@ async def test_default_pipeline_unsupported_tts_language(
# Check the default pipeline # Check the default pipeline
pipeline = async_get_pipeline(hass, None) pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language="en", conversation_language="en",
id=pipeline.id, id=pipeline.id,
language="en", language="en",
@ -502,7 +505,7 @@ async def test_update_pipeline(
pipelines = list(pipelines) pipelines = list(pipelines)
assert pipelines == [ assert pipelines == [
Pipeline( Pipeline(
conversation_engine="homeassistant", conversation_engine="conversation.home_assistant",
conversation_language="en", conversation_language="en",
id=ANY, id=ANY,
language="en", language="en",

View File

@ -1166,7 +1166,7 @@ async def test_get_pipeline(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"conversation_engine": "homeassistant", "conversation_engine": "conversation.home_assistant",
"conversation_language": "en", "conversation_language": "en",
"id": ANY, "id": ANY,
"language": "en", "language": "en",
@ -1250,7 +1250,7 @@ async def test_list_pipelines(
assert msg["result"] == { assert msg["result"] == {
"pipelines": [ "pipelines": [
{ {
"conversation_engine": "homeassistant", "conversation_engine": "conversation.home_assistant",
"conversation_language": "en", "conversation_language": "en",
"id": ANY, "id": ANY,
"language": "en", "language": "en",
@ -2012,7 +2012,7 @@ async def test_wake_word_cooldown_different_entities(
await client_pipeline.send_json_auto_id( await client_pipeline.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline/create", "type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant", "conversation_engine": "conversation.home_assistant",
"conversation_language": "en-US", "conversation_language": "en-US",
"language": "en", "language": "en",
"name": "pipeline_with_wake_word_1", "name": "pipeline_with_wake_word_1",
@ -2032,7 +2032,7 @@ async def test_wake_word_cooldown_different_entities(
await client_pipeline.send_json_auto_id( await client_pipeline.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline/create", "type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant", "conversation_engine": "conversation.home_assistant",
"conversation_language": "en-US", "conversation_language": "en-US",
"language": "en", "language": "en",
"name": "pipeline_with_wake_word_2", "name": "pipeline_with_wake_word_2",

View File

@ -7,10 +7,12 @@ from homeassistant.components.cloud.assist_pipeline import (
) )
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None: async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
"""Test migrate pipeline with invalid platform.""" """Test migrate pipeline with invalid platform."""
await async_setup_component(hass, "assist_pipeline", {})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await async_migrate_cloud_pipeline_engine( await async_migrate_cloud_pipeline_engine(
hass, Platform.BINARY_SENSOR, "test-engine-id" hass, Platform.BINARY_SENSOR, "test-engine-id"

View File

@ -231,6 +231,7 @@ async def test_login_view_create_pipeline(
} }
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done() await hass.async_block_till_done()
@ -270,6 +271,7 @@ async def test_login_view_create_pipeline_fail(
} }
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -7,6 +7,8 @@ import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.shopping_list import intent as sl_intent from homeassistant.components.shopping_list import intent as sl_intent
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MockAgent from . import MockAgent
@ -14,7 +16,7 @@ from tests.common import MockConfigEntry
@pytest.fixture @pytest.fixture
def mock_agent_support_all(hass): def mock_agent_support_all(hass: HomeAssistant):
"""Mock agent that supports all languages.""" """Mock agent that supports all languages."""
entry = MockConfigEntry(entry_id="mock-entry-support-all") entry = MockConfigEntry(entry_id="mock-entry-support-all")
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -34,7 +36,7 @@ def mock_shopping_list_io():
@pytest.fixture @pytest.fixture
async def sl_setup(hass): async def sl_setup(hass: HomeAssistant):
"""Set up the shopping list.""" """Set up the shopping list."""
entry = MockConfigEntry(domain="shopping_list") entry = MockConfigEntry(domain="shopping_list")
@ -43,3 +45,10 @@ async def sl_setup(hass):
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await sl_intent.async_setup_intents(hass) await sl_intent.async_setup_intents(hass)
@pytest.fixture
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})

View File

@ -101,7 +101,7 @@
# --- # ---
# name: test_get_agent_info # name: test_get_agent_info
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
}) })
# --- # ---
@ -113,7 +113,7 @@
# --- # ---
# name: test_get_agent_info.2 # name: test_get_agent_info.2
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
}) })
# --- # ---
@ -127,7 +127,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
'af', 'af',
@ -207,7 +207,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
]), ]),
@ -231,7 +231,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
'en', 'en',
@ -255,7 +255,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
'en', 'en',
@ -279,7 +279,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
'de', 'de',
@ -304,7 +304,7 @@
dict({ dict({
'agents': list([ 'agents': list([
dict({ dict({
'id': 'homeassistant', 'id': 'conversation.home_assistant',
'name': 'Home Assistant', 'name': 'Home Assistant',
'supported_languages': list([ 'supported_languages': list([
'de-CH', 'de-CH',
@ -415,6 +415,36 @@
}), }),
}) })
# --- # ---
# name: test_http_processing_intent[conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': 'entity',
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_http_processing_intent[homeassistant] # name: test_http_processing_intent[homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': None,
@ -1035,6 +1065,36 @@
}), }),
}) })
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[None-turn kitchen on-homeassistant] # name: test_turn_on_intent[None-turn kitchen on-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': None,
@ -1095,6 +1155,36 @@
}), }),
}) })
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[None-turn on kitchen-homeassistant] # name: test_turn_on_intent[None-turn on kitchen-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': None,
@ -1155,6 +1245,36 @@
}), }),
}) })
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': None,
@ -1215,6 +1335,36 @@
}), }),
}) })
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': None,

View File

@ -7,7 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult
import pytest import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.homeassistant.exposed_entities import ( from homeassistant.components.homeassistant.exposed_entities import (
async_get_assistant_settings, async_get_assistant_settings,
) )
@ -152,9 +152,7 @@ async def test_conversation_agent(
init_components, init_components,
) -> None: ) -> None:
"""Test DefaultAgent.""" """Test DefaultAgent."""
agent = await agent_manager.get_agent_manager(hass).async_get_agent( agent = default_agent.async_get_default_agent(hass)
conversation.HOME_ASSISTANT_AGENT
)
with patch( with patch(
"homeassistant.components.conversation.default_agent.get_languages", "homeassistant.components.conversation.default_agent.get_languages",
return_value=["dwarvish", "elvish", "entish"], return_value=["dwarvish", "elvish", "entish"],
@ -181,6 +179,7 @@ async def test_expose_flag_automatically_set(
# After setting up conversation, the expose flag should now be set on all entities # After setting up conversation, the expose flag should now be set on all entities
assert async_get_assistant_settings(hass, conversation.DOMAIN) == { assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
"conversation.home_assistant": {"should_expose": False},
light.entity_id: {"should_expose": True}, light.entity_id: {"should_expose": True},
test.entity_id: {"should_expose": False}, test.entity_id: {"should_expose": False},
} }
@ -190,6 +189,7 @@ async def test_expose_flag_automatically_set(
hass.states.async_set(new_light, "test") hass.states.async_set(new_light, "test")
await hass.async_block_till_done() await hass.async_block_till_done()
assert async_get_assistant_settings(hass, conversation.DOMAIN) == { assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
"conversation.home_assistant": {"should_expose": False},
light.entity_id: {"should_expose": True}, light.entity_id: {"should_expose": True},
new_light: {"should_expose": True}, new_light: {"should_expose": True},
test.entity_id: {"should_expose": False}, test.entity_id: {"should_expose": False},
@ -254,9 +254,7 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
trigger_sentences = ["It's party time", "It is time to party"] trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!" trigger_response = "Cowabunga!"
agent = await agent_manager.get_agent_manager(hass).async_get_agent( agent = default_agent.async_get_default_agent(hass)
conversation.HOME_ASSISTANT_AGENT
)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
callback = AsyncMock(return_value=trigger_response) callback = AsyncMock(return_value=trigger_response)

View File

@ -0,0 +1,47 @@
"""Tests for conversation entity."""
from unittest.mock import patch
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from tests.common import mock_restore_cache
async def test_state_set_and_restore(hass: HomeAssistant) -> None:
"""Test we set and restore state in the integration."""
entity_id = "conversation.home_assistant"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
await async_setup_component(hass, "homeassistant", {})
await async_setup_component(hass, "conversation", {})
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp
now = dt_util.utcnow()
context = Context()
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process"
) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now),
):
await hass.services.async_call(
"conversation",
"process",
{"text": "Hello"},
context=context,
blocking=True,
)
assert len(mock_process.mock_calls) == 1
state = hass.states.get(entity_id)
assert state
assert state.state == now.isoformat()
assert state.context is context

View File

@ -9,7 +9,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.conversation.models import ConversationInput
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
@ -35,7 +35,13 @@ from tests.common import (
from tests.components.light.common import MockLight from tests.components.light.common import MockLight
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
AGENT_ID_OPTIONS = [None, conversation.HOME_ASSISTANT_AGENT] AGENT_ID_OPTIONS = [
None,
# Old value of conversation.HOME_ASSISTANT_AGENT,
"homeassistant",
# Current value of conversation.HOME_ASSISTANT_AGENT,
"conversation.home_assistant",
]
class OrderBeerIntentHandler(intent.IntentHandler): class OrderBeerIntentHandler(intent.IntentHandler):
@ -51,14 +57,6 @@ class OrderBeerIntentHandler(intent.IntentHandler):
return response return response
@pytest.fixture
async def init_components(hass):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) @pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
async def test_http_processing_intent( async def test_http_processing_intent(
hass: HomeAssistant, hass: HomeAssistant,
@ -752,7 +750,7 @@ async def test_ws_prepare(
"""Test the Websocket prepare conversation API.""" """Test the Websocket prepare conversation API."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
agent = await agent_manager.get_agent_manager(hass).async_get_agent() agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
# No intents should be loaded yet # No intents should be loaded yet
@ -854,7 +852,7 @@ async def test_prepare_reload(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = await agent_manager.get_agent_manager(hass).async_get_agent() agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare(language) await agent.async_prepare(language)
@ -882,7 +880,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = await agent_manager.get_agent_manager(hass).async_get_agent() agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare("not-a-language") await agent.async_prepare("not-a-language")
@ -919,7 +917,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
hass.states.async_set("cover.front_door", "closed") hass.states.async_set("cover.front_door", "closed")
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER) calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = await agent_manager.get_agent_manager(hass).async_get_agent() agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process( result = await agent.async_process(
@ -1063,12 +1061,15 @@ async def test_light_area_same_name(
assert call.data == {"entity_id": [kitchen_light.entity_id]} assert call.data == {"entity_id": [kitchen_light.entity_id]}
async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None: async def test_agent_id_validator_invalid_agent(
hass: HomeAssistant, init_components
) -> None:
"""Test validating agent id.""" """Test validating agent id."""
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
conversation.agent_id_validator("invalid_agent") conversation.agent_id_validator("invalid_agent")
conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT) conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT)
conversation.agent_id_validator("conversation.home_assistant")
async def test_get_agent_list( async def test_get_agent_list(

View File

@ -5,7 +5,7 @@ import logging
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.conversation import agent_manager, default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.conversation.models import ConversationInput
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import trigger from homeassistant.helpers import trigger
@ -515,7 +515,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
}, },
) )
agent = await agent_manager.get_agent_manager(hass).async_get_agent() agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process( result = await agent.async_process(

View File

@ -327,6 +327,7 @@ async def test_conversation_agent(
"""Test GoogleAssistantConversationAgent.""" """Test GoogleAssistantConversationAgent."""
await setup_integration() await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN) entries = hass.config_entries.async_entries(DOMAIN)
@ -334,7 +335,7 @@ async def test_conversation_agent(
entry = entries[0] entry = entries[0]
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) agent = conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
text1 = "tell me a joke" text1 = "tell me a joke"
@ -365,6 +366,7 @@ async def test_conversation_agent_refresh_token(
"""Test GoogleAssistantConversationAgent when token is expired.""" """Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration() await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN) entries = hass.config_entries.async_entries(DOMAIN)
@ -416,6 +418,7 @@ async def test_conversation_agent_language_changed(
"""Test GoogleAssistantConversationAgent when language is changed.""" """Test GoogleAssistantConversationAgent when language is changed."""
await setup_integration() await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN) entries = hass.config_entries.async_entries(DOMAIN)

View File

@ -4,6 +4,8 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -23,10 +25,17 @@ def mock_config_entry(hass):
@pytest.fixture @pytest.fixture
async def mock_init_component(hass, mock_config_entry): async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
"""Initialize integration.""" """Initialize integration."""
assert await async_setup_component(hass, "homeassistant", {})
with patch("google.generativeai.get_model"): with patch("google.generativeai.get_model"):
assert await async_setup_component( assert await async_setup_component(
hass, "google_generative_ai_conversation", {} hass, "google_generative_ai_conversation", {}
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View File

@ -10,6 +10,7 @@ from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -124,6 +125,7 @@ async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test that template error handling works.""" """Test that template error handling works."""
assert await async_setup_component(hass, "homeassistant", {})
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
mock_config_entry, mock_config_entry,
options={ options={
@ -152,7 +154,7 @@ async def test_conversation_agent(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test GoogleGenerativeAIAgent.""" """Test GoogleGenerativeAIAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert agent.supported_languages == "*" assert agent.supported_languages == "*"

View File

@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process(
webhook_client.server.app.router._frozen = False webhook_client.server.app.router._frozen = False
with patch( with patch(
"homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent", "homeassistant.components.conversation.agent_manager.async_get_agent",
return_value=mock_conversation_agent, return_value=mock_conversation_agent,
): ):
resp = await webhook_client.post( resp = await webhook_client.post(

View File

@ -35,3 +35,9 @@ async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfig
): ):
assert await async_setup_component(hass, ollama.DOMAIN, {}) assert await async_setup_component(hass, ollama.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View File

@ -229,7 +229,7 @@ async def test_message_history_pruning(
assert isinstance(result.conversation_id, str) assert isinstance(result.conversation_id, str)
conversation_ids.append(result.conversation_id) conversation_ids.append(result.conversation_id)
agent = await conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert isinstance(agent, ollama.OllamaAgent) assert isinstance(agent, ollama.OllamaAgent)
@ -284,7 +284,7 @@ async def test_message_history_unlimited(
result.response.response_type == intent.IntentResponseType.ACTION_DONE result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result ), result
agent = await conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert isinstance(agent, ollama.OllamaAgent) assert isinstance(agent, ollama.OllamaAgent)
@ -340,7 +340,7 @@ async def test_conversation_agent(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test OllamaAgent.""" """Test OllamaAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert agent.supported_languages == MATCH_ALL assert agent.supported_languages == MATCH_ALL

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -30,3 +31,9 @@ async def mock_init_component(hass, mock_config_entry):
): ):
assert await async_setup_component(hass, "openai_conversation", {}) assert await async_setup_component(hass, "openai_conversation", {})
await hass.async_block_till_done() await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View File

@ -194,7 +194,7 @@ async def test_conversation_agent(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test OpenAIAgent.""" """Test OpenAIAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert agent.supported_languages == "*" assert agent.supported_languages == "*"