mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 00:07:10 +00:00
Use HassKey in conversation (#126332)
* Use HassKey in conversation * Adjust tests
This commit is contained in:
parent
f9e7721653
commit
f8a53aea09
@ -35,6 +35,7 @@ from .const import (
|
|||||||
ATTR_CONVERSATION_ID,
|
ATTR_CONVERSATION_ID,
|
||||||
ATTR_LANGUAGE,
|
ATTR_LANGUAGE,
|
||||||
ATTR_TEXT,
|
ATTR_TEXT,
|
||||||
|
DATA_DEFAULT_ENTITY,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
DOMAIN_DATA,
|
DOMAIN_DATA,
|
||||||
HOME_ASSISTANT_AGENT,
|
HOME_ASSISTANT_AGENT,
|
||||||
@ -43,7 +44,7 @@ from .const import (
|
|||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
ConversationEntityFeature,
|
ConversationEntityFeature,
|
||||||
)
|
)
|
||||||
from .default_agent import async_get_default_agent, async_setup_default_agent
|
from .default_agent import async_setup_default_agent
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
@ -247,8 +248,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
|
|
||||||
async def handle_reload(service: ServiceCall) -> None:
|
async def handle_reload(service: ServiceCall) -> None:
|
||||||
"""Reload intents."""
|
"""Reload intents."""
|
||||||
agent = async_get_default_agent(hass)
|
await hass.data[DATA_DEFAULT_ENTITY].async_reload(
|
||||||
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
language=service.data.get(ATTR_LANGUAGE)
|
||||||
|
)
|
||||||
|
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
@ -11,8 +11,12 @@ import voluptuous as vol
|
|||||||
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
|
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
|
||||||
from homeassistant.helpers import config_validation as cv, singleton
|
from homeassistant.helpers import config_validation as cv, singleton
|
||||||
|
|
||||||
from .const import DOMAIN_DATA, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
|
from .const import (
|
||||||
from .default_agent import async_get_default_agent
|
DATA_DEFAULT_ENTITY,
|
||||||
|
DOMAIN_DATA,
|
||||||
|
HOME_ASSISTANT_AGENT,
|
||||||
|
OLD_HOME_ASSISTANT_AGENT,
|
||||||
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import (
|
from .models import (
|
||||||
AbstractConversationAgent,
|
AbstractConversationAgent,
|
||||||
@ -50,7 +54,7 @@ def async_get_agent(
|
|||||||
) -> AbstractConversationAgent | ConversationEntity | None:
|
) -> AbstractConversationAgent | ConversationEntity | None:
|
||||||
"""Get specified agent."""
|
"""Get specified agent."""
|
||||||
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
|
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
|
||||||
return async_get_default_agent(hass)
|
return hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
|
||||||
if "." in agent_id:
|
if "." in agent_id:
|
||||||
return hass.data[DOMAIN_DATA].get_entity(agent_id)
|
return hass.data[DOMAIN_DATA].get_entity(agent_id)
|
||||||
|
@ -10,6 +10,7 @@ from homeassistant.util.hass_dict import HassKey
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
from .default_agent import DefaultAgent
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
|
|
||||||
DOMAIN = "conversation"
|
DOMAIN = "conversation"
|
||||||
@ -26,6 +27,7 @@ SERVICE_PROCESS = "process"
|
|||||||
SERVICE_RELOAD = "reload"
|
SERVICE_RELOAD = "reload"
|
||||||
|
|
||||||
DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
|
DOMAIN_DATA: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
|
||||||
|
DATA_DEFAULT_ENTITY: HassKey[DefaultAgent] = HassKey(f"{DOMAIN}_default_entity")
|
||||||
|
|
||||||
|
|
||||||
class ConversationEntityFeature(IntFlag):
|
class ConversationEntityFeature(IntFlag):
|
||||||
|
@ -44,7 +44,12 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||||||
from homeassistant.helpers.event import async_track_state_added_domain
|
from homeassistant.helpers.event import async_track_state_added_domain
|
||||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||||
|
|
||||||
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
|
from .const import (
|
||||||
|
DATA_DEFAULT_ENTITY,
|
||||||
|
DEFAULT_EXPOSED_ATTRIBUTES,
|
||||||
|
DOMAIN,
|
||||||
|
ConversationEntityFeature,
|
||||||
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
@ -60,16 +65,9 @@ TRIGGER_CALLBACK_TYPE = Callable[
|
|||||||
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
||||||
METADATA_CUSTOM_FILE = "hass_custom_file"
|
METADATA_CUSTOM_FILE = "hass_custom_file"
|
||||||
|
|
||||||
DATA_DEFAULT_ENTITY = "conversation_default_entity"
|
|
||||||
ERROR_SENTINEL = object()
|
ERROR_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
@core.callback
|
|
||||||
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
|
|
||||||
"""Get the default agent."""
|
|
||||||
return hass.data[DATA_DEFAULT_ENTITY]
|
|
||||||
|
|
||||||
|
|
||||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||||
"""Wrap json_loads for get_intents."""
|
"""Wrap json_loads for get_intents."""
|
||||||
return json_loads_object(fp.read())
|
return json_loads_object(fp.read())
|
||||||
|
@ -27,13 +27,11 @@ from .agent_manager import (
|
|||||||
async_get_agent,
|
async_get_agent,
|
||||||
get_agent_manager,
|
get_agent_manager,
|
||||||
)
|
)
|
||||||
from .const import DOMAIN_DATA
|
from .const import DATA_DEFAULT_ENTITY, DOMAIN_DATA
|
||||||
from .default_agent import (
|
from .default_agent import (
|
||||||
METADATA_CUSTOM_FILE,
|
METADATA_CUSTOM_FILE,
|
||||||
METADATA_CUSTOM_SENTENCE,
|
METADATA_CUSTOM_SENTENCE,
|
||||||
DefaultAgent,
|
|
||||||
SentenceTriggerResult,
|
SentenceTriggerResult,
|
||||||
async_get_default_agent,
|
|
||||||
)
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput
|
from .models import ConversationInput
|
||||||
@ -173,10 +171,8 @@ async def websocket_hass_agent_debug(
|
|||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Return intents that would be matched by the default agent for a list of sentences."""
|
"""Return intents that would be matched by the default agent for a list of sentences."""
|
||||||
agent = async_get_default_agent(hass)
|
|
||||||
assert isinstance(agent, DefaultAgent)
|
|
||||||
results = [
|
results = [
|
||||||
await agent.async_recognize(
|
await hass.data[DATA_DEFAULT_ENTITY].async_recognize(
|
||||||
ConversationInput(
|
ConversationInput(
|
||||||
text=sentence,
|
text=sentence,
|
||||||
context=connection.context(msg),
|
context=connection.context(msg),
|
||||||
|
@ -14,8 +14,7 @@ from homeassistant.helpers.script import ScriptRunResult
|
|||||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DATA_DEFAULT_ENTITY, DOMAIN
|
||||||
from .default_agent import DefaultAgent, async_get_default_agent
|
|
||||||
|
|
||||||
|
|
||||||
def has_no_punctuation(value: list[str]) -> list[str]:
|
def has_no_punctuation(value: list[str]) -> list[str]:
|
||||||
@ -110,7 +109,4 @@ async def async_attach_trigger(
|
|||||||
# two trigger copies for who will provide a response.
|
# two trigger copies for who will provide a response.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
default_agent = async_get_default_agent(hass)
|
return hass.data[DATA_DEFAULT_ENTITY].register_trigger(sentences, call_action)
|
||||||
assert isinstance(default_agent, DefaultAgent)
|
|
||||||
|
|
||||||
return default_agent.register_trigger(sentences, call_action)
|
|
||||||
|
@ -13,6 +13,7 @@ import yaml
|
|||||||
|
|
||||||
from homeassistant.components import conversation, cover, media_player
|
from homeassistant.components import conversation, cover, media_player
|
||||||
from homeassistant.components.conversation import default_agent
|
from homeassistant.components.conversation import default_agent
|
||||||
|
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||||
from homeassistant.components.conversation.models import ConversationInput
|
from homeassistant.components.conversation.models import ConversationInput
|
||||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||||
from homeassistant.components.homeassistant.exposed_entities import (
|
from homeassistant.components.homeassistant.exposed_entities import (
|
||||||
@ -203,7 +204,7 @@ async def test_exposed_areas(
|
|||||||
@pytest.mark.usefixtures("init_components")
|
@pytest.mark.usefixtures("init_components")
|
||||||
async def test_conversation_agent(hass: HomeAssistant) -> None:
|
async def test_conversation_agent(hass: HomeAssistant) -> None:
|
||||||
"""Test DefaultAgent."""
|
"""Test DefaultAgent."""
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.conversation.default_agent.get_languages",
|
"homeassistant.components.conversation.default_agent.get_languages",
|
||||||
return_value=["dwarvish", "elvish", "entish"],
|
return_value=["dwarvish", "elvish", "entish"],
|
||||||
@ -380,7 +381,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
|
|||||||
trigger_sentences = ["It's party time", "It is time to party"]
|
trigger_sentences = ["It's party time", "It is time to party"]
|
||||||
trigger_response = "Cowabunga!"
|
trigger_response = "Cowabunga!"
|
||||||
|
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
callback = AsyncMock(return_value=trigger_response)
|
callback = AsyncMock(return_value=trigger_response)
|
||||||
@ -1905,7 +1906,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
|
|||||||
hass.states.async_set("cover.front_door", "closed")
|
hass.states.async_set("cover.front_door", "closed")
|
||||||
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
||||||
|
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
result = await agent.async_process(
|
result = await agent.async_process(
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.conversation import default_agent
|
from homeassistant.components.conversation import default_agent
|
||||||
|
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -214,7 +215,7 @@ async def test_ws_prepare(
|
|||||||
hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id
|
hass: HomeAssistant, init_components, hass_ws_client: WebSocketGenerator, agent_id
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the Websocket prepare conversation API."""
|
"""Test the Websocket prepare conversation API."""
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
# No intents should be loaded yet
|
# No intents should be loaded yet
|
||||||
|
@ -9,6 +9,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import default_agent
|
from homeassistant.components.conversation import default_agent
|
||||||
|
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@ -143,7 +144,7 @@ async def test_prepare_reload(hass: HomeAssistant, init_components) -> None:
|
|||||||
language = hass.config.language
|
language = hass.config.language
|
||||||
|
|
||||||
# Load intents
|
# Load intents
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
await agent.async_prepare(language)
|
await agent.async_prepare(language)
|
||||||
|
|
||||||
@ -171,7 +172,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
|
|||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
# Load intents
|
# Load intents
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
await agent.async_prepare("not-a-language")
|
await agent.async_prepare("not-a-language")
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import pytest
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.conversation import default_agent
|
from homeassistant.components.conversation import default_agent
|
||||||
|
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||||
from homeassistant.components.conversation.models import ConversationInput
|
from homeassistant.components.conversation.models import ConversationInput
|
||||||
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
||||||
from homeassistant.helpers import trigger
|
from homeassistant.helpers import trigger
|
||||||
@ -550,7 +551,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = default_agent.async_get_default_agent(hass)
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
assert isinstance(agent, default_agent.DefaultAgent)
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
result = await agent.async_process(
|
result = await agent.async_process(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user