Use HassKey in conversation (#126332)

* Use HassKey in conversation

* Adjust tests
This commit is contained in:
epenet 2024-09-22 17:54:14 +02:00 committed by GitHub
parent f9e7721653
commit f8a53aea09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 35 additions and 33 deletions

View File

@ -35,6 +35,7 @@ from .const import (
ATTR_CONVERSATION_ID,
ATTR_LANGUAGE,
ATTR_TEXT,
DATA_DEFAULT_ENTITY,
DOMAIN,
DOMAIN_DATA,
HOME_ASSISTANT_AGENT,
@ -43,7 +44,7 @@ from .const import (
SERVICE_RELOAD,
ConversationEntityFeature,
)
from .default_agent import async_get_default_agent, async_setup_default_agent
from .default_agent import async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
@ -247,8 +248,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def handle_reload(service: ServiceCall) -> None:
"""Reload intents."""
agent = async_get_default_agent(hass)
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
await hass.data[DATA_DEFAULT_ENTITY].async_reload(
language=service.data.get(ATTR_LANGUAGE)
)
hass.services.async_register(
DOMAIN,

View File

@ -11,8 +11,12 @@ import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton
from .const import DOMAIN_DATA, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
from .default_agent import async_get_default_agent
from .const import (
DATA_DEFAULT_ENTITY,
DOMAIN_DATA,
HOME_ASSISTANT_AGENT,
OLD_HOME_ASSISTANT_AGENT,
)
from .entity import ConversationEntity
from .models import (
AbstractConversationAgent,
@ -50,7 +54,7 @@ def async_get_agent(
) -> AbstractConversationAgent | ConversationEntity | None:
"""Get specified agent."""
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
return async_get_default_agent(hass)
return hass.data[DATA_DEFAULT_ENTITY]
if "." in agent_id:
return hass.data[DOMAIN_DATA].get_entity(agent_id)

View File

@ -10,6 +10,7 @@ from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent
from .default_agent import DefaultAgent
from .entity import ConversationEntity
DOMAIN = "conversation"
@ -26,6 +27,7 @@ SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload"
DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
DATA_DEFAULT_ENTITY: HassKey[DefaultAgent] = HassKey(f"{DOMAIN}_default_entity")
class ConversationEntityFeature(IntFlag):

View File

@ -44,7 +44,12 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
from .const import (
DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES,
DOMAIN,
ConversationEntityFeature,
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .trace import ConversationTraceEventType, async_conversation_trace_append
@ -60,16 +65,9 @@ TRIGGER_CALLBACK_TYPE = Callable[
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
DATA_DEFAULT_ENTITY = "conversation_default_entity"
ERROR_SENTINEL = object()
@core.callback
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
"""Get the default agent."""
return hass.data[DATA_DEFAULT_ENTITY]
def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents."""
return json_loads_object(fp.read())

View File

@ -27,13 +27,11 @@ from .agent_manager import (
async_get_agent,
get_agent_manager,
)
from .const import DOMAIN_DATA
from .const import DATA_DEFAULT_ENTITY, DOMAIN_DATA
from .default_agent import (
METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE,
DefaultAgent,
SentenceTriggerResult,
async_get_default_agent,
)
from .entity import ConversationEntity
from .models import ConversationInput
@ -173,10 +171,8 @@ async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Return intents that would be matched by the default agent for a list of sentences."""
agent = async_get_default_agent(hass)
assert isinstance(agent, DefaultAgent)
results = [
await agent.async_recognize(
await hass.data[DATA_DEFAULT_ENTITY].async_recognize(
ConversationInput(
text=sentence,
context=connection.context(msg),

View File

@ -14,8 +14,7 @@ from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .const import DOMAIN
from .default_agent import DefaultAgent, async_get_default_agent
from .const import DATA_DEFAULT_ENTITY, DOMAIN
def has_no_punctuation(value: list[str]) -> list[str]:
@ -110,7 +109,4 @@ async def async_attach_trigger(
# two trigger copies for who will provide a response.
return None
default_agent = async_get_default_agent(hass)
assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action)
return hass.data[DATA_DEFAULT_ENTITY].register_trigger(sentences, call_action)

View File

@ -13,6 +13,7 @@ import yaml
from homeassistant.components import conversation, cover, media_player
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.components.homeassistant.exposed_entities import (
@ -203,7 +204,7 @@ async def test_exposed_areas(
@pytest.mark.usefixtures("init_components")
async def test_conversation_agent(hass: HomeAssistant) -> None:
"""Test DefaultAgent."""
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
with patch(
"homeassistant.components.conversation.default_agent.get_languages",
return_value=["dwarvish", "elvish", "entish"],
@ -380,7 +381,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!"
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
callback = AsyncMock(return_value=trigger_response)
@ -1905,7 +1906,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
hass.states.async_set("cover.front_door", "closed")
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(

View File

@ -8,6 +8,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant
@ -214,7 +215,7 @@ async def test_ws_prepare(
hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id
) -> None:
"""Test the Websocket prepare conversation API."""
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
# No intents should be loaded yet

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -143,7 +144,7 @@ async def test_prepare_reload(hass: HomeAssistant, init_components) -> None:
language = hass.config.language
# Load intents
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare(language)
@ -171,7 +172,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare("not-a-language")

View File

@ -6,6 +6,7 @@ import pytest
import voluptuous as vol
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.helpers import trigger
@ -550,7 +551,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
},
)
agent = default_agent.async_get_default_agent(hass)
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(