Use HassKey in conversation (#126332)

* Use HassKey in conversation

* Adjust tests
This commit is contained in:
epenet 2024-09-22 17:54:14 +02:00 committed by GitHub
parent f9e7721653
commit f8a53aea09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 35 additions and 33 deletions

View File

@ -35,6 +35,7 @@ from .const import (
ATTR_CONVERSATION_ID, ATTR_CONVERSATION_ID,
ATTR_LANGUAGE, ATTR_LANGUAGE,
ATTR_TEXT, ATTR_TEXT,
DATA_DEFAULT_ENTITY,
DOMAIN, DOMAIN,
DOMAIN_DATA, DOMAIN_DATA,
HOME_ASSISTANT_AGENT, HOME_ASSISTANT_AGENT,
@ -43,7 +44,7 @@ from .const import (
SERVICE_RELOAD, SERVICE_RELOAD,
ConversationEntityFeature, ConversationEntityFeature,
) )
from .default_agent import async_get_default_agent, async_setup_default_agent from .default_agent import async_setup_default_agent
from .entity import ConversationEntity from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
@ -247,8 +248,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def handle_reload(service: ServiceCall) -> None: async def handle_reload(service: ServiceCall) -> None:
"""Reload intents.""" """Reload intents."""
agent = async_get_default_agent(hass) await hass.data[DATA_DEFAULT_ENTITY].async_reload(
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) language=service.data.get(ATTR_LANGUAGE)
)
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,

View File

@ -11,8 +11,12 @@ import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton from homeassistant.helpers import config_validation as cv, singleton
from .const import DOMAIN_DATA, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT from .const import (
from .default_agent import async_get_default_agent DATA_DEFAULT_ENTITY,
DOMAIN_DATA,
HOME_ASSISTANT_AGENT,
OLD_HOME_ASSISTANT_AGENT,
)
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ( from .models import (
AbstractConversationAgent, AbstractConversationAgent,
@ -50,7 +54,7 @@ def async_get_agent(
) -> AbstractConversationAgent | ConversationEntity | None: ) -> AbstractConversationAgent | ConversationEntity | None:
"""Get specified agent.""" """Get specified agent."""
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT): if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
return async_get_default_agent(hass) return hass.data[DATA_DEFAULT_ENTITY]
if "." in agent_id: if "." in agent_id:
return hass.data[DOMAIN_DATA].get_entity(agent_id) return hass.data[DOMAIN_DATA].get_entity(agent_id)

View File

@ -10,6 +10,7 @@ from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from .default_agent import DefaultAgent
from .entity import ConversationEntity from .entity import ConversationEntity
DOMAIN = "conversation" DOMAIN = "conversation"
@ -26,6 +27,7 @@ SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload" SERVICE_RELOAD = "reload"
DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN) DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
DATA_DEFAULT_ENTITY: HassKey[DefaultAgent] = HassKey(f"{DOMAIN}_default_entity")
class ConversationEntityFeature(IntFlag): class ConversationEntityFeature(IntFlag):

View File

@ -44,7 +44,12 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_added_domain from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature from .const import (
DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES,
DOMAIN,
ConversationEntityFeature,
)
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
@ -60,16 +65,9 @@ TRIGGER_CALLBACK_TYPE = Callable[
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence" METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file" METADATA_CUSTOM_FILE = "hass_custom_file"
DATA_DEFAULT_ENTITY = "conversation_default_entity"
ERROR_SENTINEL = object() ERROR_SENTINEL = object()
@core.callback
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
"""Get the default agent."""
return hass.data[DATA_DEFAULT_ENTITY]
def json_load(fp: IO[str]) -> JsonObjectType: def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents.""" """Wrap json_loads for get_intents."""
return json_loads_object(fp.read()) return json_loads_object(fp.read())

View File

@ -27,13 +27,11 @@ from .agent_manager import (
async_get_agent, async_get_agent,
get_agent_manager, get_agent_manager,
) )
from .const import DOMAIN_DATA from .const import DATA_DEFAULT_ENTITY, DOMAIN_DATA
from .default_agent import ( from .default_agent import (
METADATA_CUSTOM_FILE, METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE, METADATA_CUSTOM_SENTENCE,
DefaultAgent,
SentenceTriggerResult, SentenceTriggerResult,
async_get_default_agent,
) )
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput from .models import ConversationInput
@ -173,10 +171,8 @@ async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None: ) -> None:
"""Return intents that would be matched by the default agent for a list of sentences.""" """Return intents that would be matched by the default agent for a list of sentences."""
agent = async_get_default_agent(hass)
assert isinstance(agent, DefaultAgent)
results = [ results = [
await agent.async_recognize( await hass.data[DATA_DEFAULT_ENTITY].async_recognize(
ConversationInput( ConversationInput(
text=sentence, text=sentence,
context=connection.context(msg), context=connection.context(msg),

View File

@ -14,8 +14,7 @@ from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .const import DOMAIN from .const import DATA_DEFAULT_ENTITY, DOMAIN
from .default_agent import DefaultAgent, async_get_default_agent
def has_no_punctuation(value: list[str]) -> list[str]: def has_no_punctuation(value: list[str]) -> list[str]:
@ -110,7 +109,4 @@ async def async_attach_trigger(
# two trigger copies for who will provide a response. # two trigger copies for who will provide a response.
return None return None
default_agent = async_get_default_agent(hass) return hass.data[DATA_DEFAULT_ENTITY].register_trigger(sentences, call_action)
assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action)

View File

@ -13,6 +13,7 @@ import yaml
from homeassistant.components import conversation, cover, media_player from homeassistant.components import conversation, cover, media_player
from homeassistant.components.conversation import default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.conversation.models import ConversationInput
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.components.homeassistant.exposed_entities import ( from homeassistant.components.homeassistant.exposed_entities import (
@ -203,7 +204,7 @@ async def test_exposed_areas(
@pytest.mark.usefixtures("init_components") @pytest.mark.usefixtures("init_components")
async def test_conversation_agent(hass: HomeAssistant) -> None: async def test_conversation_agent(hass: HomeAssistant) -> None:
"""Test DefaultAgent.""" """Test DefaultAgent."""
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
with patch( with patch(
"homeassistant.components.conversation.default_agent.get_languages", "homeassistant.components.conversation.default_agent.get_languages",
return_value=["dwarvish", "elvish", "entish"], return_value=["dwarvish", "elvish", "entish"],
@ -380,7 +381,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
trigger_sentences = ["It's party time", "It is time to party"] trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!" trigger_response = "Cowabunga!"
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
callback = AsyncMock(return_value=trigger_response) callback = AsyncMock(return_value=trigger_response)
@ -1905,7 +1906,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
hass.states.async_set("cover.front_door", "closed") hass.states.async_set("cover.front_door", "closed")
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER) calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process( result = await agent.async_process(

View File

@ -8,6 +8,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.conversation import default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -214,7 +215,7 @@ async def test_ws_prepare(
hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id
) -> None: ) -> None:
"""Test the Websocket prepare conversation API.""" """Test the Websocket prepare conversation API."""
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
# No intents should be loaded yet # No intents should be loaded yet

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -143,7 +144,7 @@ async def test_prepare_reload(hass: HomeAssistant, init_components) -> None:
language = hass.config.language language = hass.config.language
# Load intents # Load intents
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare(language) await agent.async_prepare(language)
@ -171,7 +172,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
# Load intents # Load intents
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare("not-a-language") await agent.async_prepare("not-a-language")

View File

@ -6,6 +6,7 @@ import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.conversation import default_agent from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.conversation.models import ConversationInput
from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.helpers import trigger from homeassistant.helpers import trigger
@ -550,7 +551,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
}, },
) )
agent = default_agent.async_get_default_agent(hass) agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent) assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process( result = await agent.async_process(