mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Add websocket command to test intent recognition for default agent (#94674)
* Add websocket command to test intent recognition for default agent * Return results as a list * Only check intent name/entities in test * Less verbose output in debug API
This commit is contained in:
parent
1459bf4011
commit
38614bc3f0
@ -186,6 +186,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
websocket_api.async_register_command(hass, websocket_prepare)
|
websocket_api.async_register_command(hass, websocket_prepare)
|
||||||
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
||||||
websocket_api.async_register_command(hass, websocket_list_agents)
|
websocket_api.async_register_command(hass, websocket_list_agents)
|
||||||
|
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -297,6 +298,60 @@ async def websocket_list_agents(
|
|||||||
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
|
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "conversation/agent/homeassistant/debug",
|
||||||
|
vol.Required("sentences"): [str],
|
||||||
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("device_id"): vol.Any(str, None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_hass_agent_debug(
|
||||||
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
|
) -> None:
|
||||||
|
"""Return intents that would be matched by the default agent for a list of sentences."""
|
||||||
|
agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
|
||||||
|
assert isinstance(agent, DefaultAgent)
|
||||||
|
results = [
|
||||||
|
await agent.async_recognize(
|
||||||
|
ConversationInput(
|
||||||
|
text=sentence,
|
||||||
|
context=connection.context(msg),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=msg.get("device_id"),
|
||||||
|
language=msg.get("language", hass.config.language),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for sentence in msg["sentences"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Return results for each sentence in the same order as the input.
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"],
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"intent": {
|
||||||
|
"name": result.intent.name,
|
||||||
|
},
|
||||||
|
"entities": {
|
||||||
|
entity_key: {
|
||||||
|
"name": entity.name,
|
||||||
|
"value": entity.value,
|
||||||
|
"text": entity.text,
|
||||||
|
}
|
||||||
|
for entity_key, entity in result.entities.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if result is not None
|
||||||
|
else None
|
||||||
|
for result in results
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessView(http.HomeAssistantView):
|
class ConversationProcessView(http.HomeAssistantView):
|
||||||
"""View to process text."""
|
"""View to process text."""
|
||||||
|
|
||||||
|
@ -143,11 +143,12 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
self.hass, DOMAIN, self._async_exposed_entities_updated
|
self.hass, DOMAIN, self._async_exposed_entities_updated
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_recognize(
|
||||||
"""Process a sentence."""
|
self, user_input: ConversationInput
|
||||||
|
) -> RecognizeResult | None:
|
||||||
|
"""Recognize intent from user input."""
|
||||||
language = user_input.language or self.hass.config.language
|
language = user_input.language or self.hass.config.language
|
||||||
lang_intents = self._lang_intents.get(language)
|
lang_intents = self._lang_intents.get(language)
|
||||||
conversation_id = None # Not supported
|
|
||||||
|
|
||||||
# Reload intents if missing or new components
|
# Reload intents if missing or new components
|
||||||
if lang_intents is None or (
|
if lang_intents is None or (
|
||||||
@ -159,21 +160,26 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
if lang_intents is None:
|
if lang_intents is None:
|
||||||
# No intents loaded
|
# No intents loaded
|
||||||
_LOGGER.warning("No intents were loaded for language: %s", language)
|
_LOGGER.warning("No intents were loaded for language: %s", language)
|
||||||
return _make_error_result(
|
return None
|
||||||
language,
|
|
||||||
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
|
||||||
_DEFAULT_ERROR_TEXT,
|
|
||||||
conversation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
slot_lists = self._make_slot_lists()
|
slot_lists = self._make_slot_lists()
|
||||||
|
|
||||||
result = await self.hass.async_add_executor_job(
|
result = await self.hass.async_add_executor_job(
|
||||||
self._recognize,
|
self._recognize,
|
||||||
user_input,
|
user_input,
|
||||||
lang_intents,
|
lang_intents,
|
||||||
slot_lists,
|
slot_lists,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
|
"""Process a sentence."""
|
||||||
|
language = user_input.language or self.hass.config.language
|
||||||
|
conversation_id = None # Not supported
|
||||||
|
|
||||||
|
result = await self.async_recognize(user_input)
|
||||||
|
lang_intents = self._lang_intents.get(language)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
||||||
return _make_error_result(
|
return _make_error_result(
|
||||||
@ -183,6 +189,10 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Will never happen because result will be None when no intents are
|
||||||
|
# loaded in async_recognize.
|
||||||
|
assert lang_intents is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
intent_response = await intent.async_handle(
|
intent_response = await intent.async_handle(
|
||||||
self.hass,
|
self.hass,
|
||||||
@ -585,9 +595,12 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
return self._slot_lists
|
return self._slot_lists
|
||||||
|
|
||||||
def _get_error_text(
|
def _get_error_text(
|
||||||
self, response_type: ResponseType, lang_intents: LanguageIntents
|
self, response_type: ResponseType, lang_intents: LanguageIntents | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get response error text by type."""
|
"""Get response error text by type."""
|
||||||
|
if lang_intents is None:
|
||||||
|
return _DEFAULT_ERROR_TEXT
|
||||||
|
|
||||||
response_key = response_type.value
|
response_key = response_type.value
|
||||||
response_str = lang_intents.error_responses.get(response_key)
|
response_str = lang_intents.error_responses.get(response_key)
|
||||||
return response_str or _DEFAULT_ERROR_TEXT
|
return response_str or _DEFAULT_ERROR_TEXT
|
||||||
|
@ -249,3 +249,43 @@
|
|||||||
'message': "invalid agent ID for dictionary value @ data['agent_id']. Got 'not_exist'",
|
'message': "invalid agent ID for dictionary value @ data['agent_id']. Got 'not_exist'",
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_ws_hass_agent_debug
|
||||||
|
dict({
|
||||||
|
'results': list([
|
||||||
|
dict({
|
||||||
|
'entities': dict({
|
||||||
|
'name': dict({
|
||||||
|
'name': 'name',
|
||||||
|
'text': 'my cool light',
|
||||||
|
'value': 'my cool light',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'intent': dict({
|
||||||
|
'name': 'HassTurnOn',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'entities': dict({
|
||||||
|
'name': dict({
|
||||||
|
'name': 'name',
|
||||||
|
'text': 'my cool light',
|
||||||
|
'value': 'my cool light',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'intent': dict({
|
||||||
|
'name': 'HassTurnOff',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
None,
|
||||||
|
]),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_ws_hass_agent_debug.1
|
||||||
|
dict({
|
||||||
|
'name': dict({
|
||||||
|
'name': 'name',
|
||||||
|
'text': 'my cool light',
|
||||||
|
'value': 'my cool light',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
@ -1626,3 +1626,43 @@ async def test_ws_get_agent_info(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert not msg["success"]
|
||||||
assert msg["error"] == snapshot
|
assert msg["error"] == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ws_hass_agent_debug(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test homeassistant agent debug websocket command."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"light", "demo", "1234", suggested_object_id="kitchen"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity("light.kitchen", aliases={"my cool light"})
|
||||||
|
hass.states.async_set("light.kitchen", "off")
|
||||||
|
|
||||||
|
on_calls = async_mock_service(hass, LIGHT_DOMAIN, "turn_on")
|
||||||
|
off_calls = async_mock_service(hass, LIGHT_DOMAIN, "turn_off")
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "conversation/agent/homeassistant/debug",
|
||||||
|
"sentences": [
|
||||||
|
"turn on my cool light",
|
||||||
|
"turn my cool light off",
|
||||||
|
"this will not match anything", # null in results
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == snapshot
|
||||||
|
|
||||||
|
# Light state should not have been changed
|
||||||
|
assert len(on_calls) == 0
|
||||||
|
assert len(off_calls) == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user