mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Use HassKey in conversation (#126332)
* Use HassKey in conversation * Adjust tests
This commit is contained in:
parent
f9e7721653
commit
f8a53aea09
@ -35,6 +35,7 @@ from .const import (
|
||||
ATTR_CONVERSATION_ID,
|
||||
ATTR_LANGUAGE,
|
||||
ATTR_TEXT,
|
||||
DATA_DEFAULT_ENTITY,
|
||||
DOMAIN,
|
||||
DOMAIN_DATA,
|
||||
HOME_ASSISTANT_AGENT,
|
||||
@ -43,7 +44,7 @@ from .const import (
|
||||
SERVICE_RELOAD,
|
||||
ConversationEntityFeature,
|
||||
)
|
||||
from .default_agent import async_get_default_agent, async_setup_default_agent
|
||||
from .default_agent import async_setup_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
@ -247,8 +248,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
async def handle_reload(service: ServiceCall) -> None:
|
||||
"""Reload intents."""
|
||||
agent = async_get_default_agent(hass)
|
||||
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
||||
await hass.data[DATA_DEFAULT_ENTITY].async_reload(
|
||||
language=service.data.get(ATTR_LANGUAGE)
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
|
@ -11,8 +11,12 @@ import voluptuous as vol
|
||||
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
|
||||
from homeassistant.helpers import config_validation as cv, singleton
|
||||
|
||||
from .const import DOMAIN_DATA, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
|
||||
from .default_agent import async_get_default_agent
|
||||
from .const import (
|
||||
DATA_DEFAULT_ENTITY,
|
||||
DOMAIN_DATA,
|
||||
HOME_ASSISTANT_AGENT,
|
||||
OLD_HOME_ASSISTANT_AGENT,
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import (
|
||||
AbstractConversationAgent,
|
||||
@ -50,7 +54,7 @@ def async_get_agent(
|
||||
) -> AbstractConversationAgent | ConversationEntity | None:
|
||||
"""Get specified agent."""
|
||||
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
|
||||
return async_get_default_agent(hass)
|
||||
return hass.data[DATA_DEFAULT_ENTITY]
|
||||
|
||||
if "." in agent_id:
|
||||
return hass.data[DOMAIN_DATA].get_entity(agent_id)
|
||||
|
@ -10,6 +10,7 @@ from homeassistant.util.hass_dict import HassKey
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .default_agent import DefaultAgent
|
||||
from .entity import ConversationEntity
|
||||
|
||||
DOMAIN = "conversation"
|
||||
@ -26,6 +27,7 @@ SERVICE_PROCESS = "process"
|
||||
SERVICE_RELOAD = "reload"
|
||||
|
||||
DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
|
||||
DATA_DEFAULT_ENTITY: HassKey[DefaultAgent] = HassKey(f"{DOMAIN}_default_entity")
|
||||
|
||||
|
||||
class ConversationEntityFeature(IntFlag):
|
||||
|
@ -44,7 +44,12 @@ from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_track_state_added_domain
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
|
||||
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
|
||||
from .const import (
|
||||
DATA_DEFAULT_ENTITY,
|
||||
DEFAULT_EXPOSED_ATTRIBUTES,
|
||||
DOMAIN,
|
||||
ConversationEntityFeature,
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
@ -60,16 +65,9 @@ TRIGGER_CALLBACK_TYPE = Callable[
|
||||
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
||||
METADATA_CUSTOM_FILE = "hass_custom_file"
|
||||
|
||||
DATA_DEFAULT_ENTITY = "conversation_default_entity"
|
||||
ERROR_SENTINEL = object()
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
|
||||
"""Get the default agent."""
|
||||
return hass.data[DATA_DEFAULT_ENTITY]
|
||||
|
||||
|
||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||
"""Wrap json_loads for get_intents."""
|
||||
return json_loads_object(fp.read())
|
||||
|
@ -27,13 +27,11 @@ from .agent_manager import (
|
||||
async_get_agent,
|
||||
get_agent_manager,
|
||||
)
|
||||
from .const import DOMAIN_DATA
|
||||
from .const import DATA_DEFAULT_ENTITY, DOMAIN_DATA
|
||||
from .default_agent import (
|
||||
METADATA_CUSTOM_FILE,
|
||||
METADATA_CUSTOM_SENTENCE,
|
||||
DefaultAgent,
|
||||
SentenceTriggerResult,
|
||||
async_get_default_agent,
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput
|
||||
@ -173,10 +171,8 @@ async def websocket_hass_agent_debug(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Return intents that would be matched by the default agent for a list of sentences."""
|
||||
agent = async_get_default_agent(hass)
|
||||
assert isinstance(agent, DefaultAgent)
|
||||
results = [
|
||||
await agent.async_recognize(
|
||||
await hass.data[DATA_DEFAULT_ENTITY].async_recognize(
|
||||
ConversationInput(
|
||||
text=sentence,
|
||||
context=connection.context(msg),
|
||||
|
@ -14,8 +14,7 @@ from homeassistant.helpers.script import ScriptRunResult
|
||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .default_agent import DefaultAgent, async_get_default_agent
|
||||
from .const import DATA_DEFAULT_ENTITY, DOMAIN
|
||||
|
||||
|
||||
def has_no_punctuation(value: list[str]) -> list[str]:
|
||||
@ -110,7 +109,4 @@ async def async_attach_trigger(
|
||||
# two trigger copies for who will provide a response.
|
||||
return None
|
||||
|
||||
default_agent = async_get_default_agent(hass)
|
||||
assert isinstance(default_agent, DefaultAgent)
|
||||
|
||||
return default_agent.register_trigger(sentences, call_action)
|
||||
return hass.data[DATA_DEFAULT_ENTITY].register_trigger(sentences, call_action)
|
||||
|
@ -13,6 +13,7 @@ import yaml
|
||||
|
||||
from homeassistant.components import conversation, cover, media_player
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||
from homeassistant.components.conversation.models import ConversationInput
|
||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||
from homeassistant.components.homeassistant.exposed_entities import (
|
||||
@ -203,7 +204,7 @@ async def test_exposed_areas(
|
||||
@pytest.mark.usefixtures("init_components")
|
||||
async def test_conversation_agent(hass: HomeAssistant) -> None:
|
||||
"""Test DefaultAgent."""
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
with patch(
|
||||
"homeassistant.components.conversation.default_agent.get_languages",
|
||||
return_value=["dwarvish", "elvish", "entish"],
|
||||
@ -380,7 +381,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
|
||||
trigger_sentences = ["It's party time", "It is time to party"]
|
||||
trigger_response = "Cowabunga!"
|
||||
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
callback = AsyncMock(return_value=trigger_response)
|
||||
@ -1905,7 +1906,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
|
||||
hass.states.async_set("cover.front_door", "closed")
|
||||
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
||||
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
result = await agent.async_process(
|
||||
|
@ -8,6 +8,7 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -214,7 +215,7 @@ async def test_ws_prepare(
|
||||
hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id
|
||||
) -> None:
|
||||
"""Test the Websocket prepare conversation API."""
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
# No intents should be loaded yet
|
||||
|
@ -9,6 +9,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -143,7 +144,7 @@ async def test_prepare_reload(hass: HomeAssistant, init_components) -> None:
|
||||
language = hass.config.language
|
||||
|
||||
# Load intents
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
await agent.async_prepare(language)
|
||||
|
||||
@ -171,7 +172,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
# Load intents
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
await agent.async_prepare("not-a-language")
|
||||
|
||||
|
@ -6,6 +6,7 @@ import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||
from homeassistant.components.conversation.models import ConversationInput
|
||||
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
||||
from homeassistant.helpers import trigger
|
||||
@ -550,7 +551,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
|
||||
},
|
||||
)
|
||||
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
result = await agent.async_process(
|
||||
|
Loading…
x
Reference in New Issue
Block a user