Handle sentence triggers and local intents before pipeline agent (#129058)

* Handle sentence triggers and registered intents in Assist LLM API

* Remove from LLM

* Check sentence triggers and local intents first

* Fix type

* Fix type again

* Use pipeline language

* Fix cloud test

* Clean up and fix translation key

* Refactor async_recognize
This commit is contained in:
Michael Hansen 2024-11-14 10:50:50 -06:00 committed by GitHub
parent df55d198c8
commit 5fa9a945d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 492 additions and 98 deletions

View File

@ -31,6 +31,7 @@ from homeassistant.components.tts import (
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.helpers.collection import (
CHANGE_UPDATED,
CollectionError,
@ -109,6 +110,7 @@ PIPELINE_FIELDS: VolDictType = {
vol.Required("tts_voice"): vol.Any(str, None),
vol.Required("wake_word_entity"): vol.Any(str, None),
vol.Required("wake_word_id"): vol.Any(str, None),
vol.Optional("prefer_local_intents"): bool,
}
STORED_PIPELINE_RUNS = 10
@ -322,6 +324,7 @@ async def async_update_pipeline(
tts_voice: str | None | UndefinedType = UNDEFINED,
wake_word_entity: str | None | UndefinedType = UNDEFINED,
wake_word_id: str | None | UndefinedType = UNDEFINED,
prefer_local_intents: bool | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
@ -345,6 +348,7 @@ async def async_update_pipeline(
("tts_voice", tts_voice),
("wake_word_entity", wake_word_entity),
("wake_word_id", wake_word_id),
("prefer_local_intents", prefer_local_intents),
)
if val is not UNDEFINED
}
@ -398,6 +402,7 @@ class Pipeline:
tts_voice: str | None
wake_word_entity: str | None
wake_word_id: str | None
prefer_local_intents: bool = False
id: str = field(default_factory=ulid_util.ulid_now)
@ -421,6 +426,7 @@ class Pipeline:
tts_voice=data["tts_voice"],
wake_word_entity=data["wake_word_entity"],
wake_word_id=data["wake_word_id"],
prefer_local_intents=data.get("prefer_local_intents", False),
)
def to_json(self) -> dict[str, Any]:
@ -438,6 +444,7 @@ class Pipeline:
"tts_voice": self.tts_voice,
"wake_word_entity": self.wake_word_entity,
"wake_word_id": self.wake_word_id,
"prefer_local_intents": self.prefer_local_intents,
}
@ -1016,15 +1023,58 @@ class PipelineRun:
)
try:
conversation_result = await conversation.async_converse(
hass=self.hass,
user_input = conversation.ConversationInput(
text=intent_input,
context=self.context,
conversation_id=conversation_id,
device_id=device_id,
context=self.context,
language=self.pipeline.conversation_language,
language=self.pipeline.language,
agent_id=self.intent_agent,
)
# Sentence triggers override conversation agent
if (
trigger_response_text
:= await conversation.async_handle_sentence_triggers(
self.hass, user_input
)
):
# Sentence trigger matched
trigger_response = intent.IntentResponse(
self.pipeline.conversation_language
)
trigger_response.async_set_speech(trigger_response_text)
conversation_result = conversation.ConversationResult(
response=trigger_response,
conversation_id=user_input.conversation_id,
)
# Try local intents first, if preferred.
# Skip this step if the default agent is already used.
elif (
self.pipeline.prefer_local_intents
and (user_input.agent_id != conversation.HOME_ASSISTANT_AGENT)
and (
intent_response := await conversation.async_handle_intents(
self.hass, user_input
)
)
):
# Local intent matched
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
else:
# Fall back to pipeline conversation agent
conversation_result = await conversation.async_converse(
hass=self.hass,
text=user_input.text,
conversation_id=user_input.conversation_id,
device_id=user_input.device_id,
context=user_input.context,
language=user_input.language,
agent_id=user_input.agent_id,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError(

View File

@ -1,6 +1,7 @@
"""Handle Cloud assist pipelines."""
import asyncio
from typing import Any
from homeassistant.components.assist_pipeline import (
async_create_default_pipeline,
@ -98,7 +99,7 @@ async def async_migrate_cloud_pipeline_engine(
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
kwargs: dict[str, Any] = {pipeline_attribute: engine_id}
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if getattr(pipeline, pipeline_attribute) == DOMAIN:

View File

@ -44,7 +44,7 @@ from .const import (
SERVICE_RELOAD,
ConversationEntityFeature,
)
from .default_agent import async_setup_default_agent
from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
@ -207,6 +207,32 @@ async def async_prepare_agent(
await agent.async_prepare(language)
async def async_handle_sentence_triggers(
hass: HomeAssistant, user_input: ConversationInput
) -> str | None:
"""Try to match input against sentence triggers and return response text.
Returns None if no match occurred.
"""
default_agent = async_get_agent(hass)
assert isinstance(default_agent, DefaultAgent)
return await default_agent.async_handle_sentence_triggers(user_input)
async def async_handle_intents(
hass: HomeAssistant, user_input: ConversationInput
) -> intent.IntentResponse | None:
"""Try to match input against registered intents and return response.
Returns None if no match occurred.
"""
default_agent = async_get_agent(hass)
assert isinstance(default_agent, DefaultAgent)
return await default_agent.async_handle_intents(user_input)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)

View File

@ -213,13 +213,10 @@ class DefaultAgent(ConversationEntity):
async_listen_entity_updates(self.hass, DOMAIN, self._async_clear_slot_list),
]
async def async_recognize(
self, user_input: ConversationInput
) -> RecognizeResult | SentenceTriggerResult | None:
async def async_recognize_intent(
self, user_input: ConversationInput, strict_intents_only: bool = False
) -> RecognizeResult | None:
"""Recognize intent from user input."""
if trigger_result := await self._match_triggers(user_input.text):
return trigger_result
language = user_input.language or self.hass.config.language
lang_intents = await self.async_get_or_load_intents(language)
@ -240,6 +237,7 @@ class DefaultAgent(ConversationEntity):
slot_lists,
intent_context,
language,
strict_intents_only,
)
_LOGGER.debug(
@ -251,56 +249,36 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
language = user_input.language or self.hass.config.language
conversation_id = None # Not supported
result = await self.async_recognize(user_input)
# Check if a trigger matched
if isinstance(result, SentenceTriggerResult):
# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
result.sentence, trigger_result, user_input.device_id
)
for trigger_id, trigger_result in result.matched_triggers.items()
]
# Use first non-empty result as response.
#
# There may be multiple copies of a trigger running when editing in
# the UI, so it's critical that we filter out empty responses here.
response_text: str | None = None
response_set_by_trigger = False
for trigger_future in asyncio.as_completed(trigger_callbacks):
trigger_response = await trigger_future
if trigger_response is None:
continue
response_text = trigger_response
response_set_by_trigger = True
break
if trigger_result := await self.async_recognize_sentence_trigger(user_input):
# Process callbacks and get response
response_text = await self._handle_trigger_result(
trigger_result, user_input
)
# Convert to conversation result
response = intent.IntentResponse(language=language)
response = intent.IntentResponse(
language=user_input.language or self.hass.config.language
)
response.response_type = intent.IntentResponseType.ACTION_DONE
if response_set_by_trigger:
# Response was explicitly set to empty
response_text = response_text or ""
elif not response_text:
# Use translated acknowledgment for pipeline language
translations = await translation.async_get_translations(
self.hass, language, DOMAIN, [DOMAIN]
)
response_text = translations.get(
f"component.{DOMAIN}.conversation.agent.done", "Done"
)
response.async_set_speech(response_text)
return ConversationResult(response=response)
# Match intents
intent_result = await self.async_recognize_intent(user_input)
return await self._async_process_intent_result(intent_result, user_input)
async def _async_process_intent_result(
self,
result: RecognizeResult | None,
user_input: ConversationInput,
) -> ConversationResult:
"""Process user input with intents."""
language = user_input.language or self.hass.config.language
conversation_id = None # Not supported
# Intent match or failure
lang_intents = await self.async_get_or_load_intents(language)
@ -436,6 +414,7 @@ class DefaultAgent(ConversationEntity):
slot_lists: dict[str, SlotList],
intent_context: dict[str, Any] | None,
language: str,
strict_intents_only: bool,
) -> RecognizeResult | None:
"""Search intents for a match to user input."""
strict_result = self._recognize_strict(
@ -446,6 +425,9 @@ class DefaultAgent(ConversationEntity):
# Successful strict match
return strict_result
if strict_intents_only:
return None
# Try again with all entities (including unexposed)
entity_registry = er.async_get(self.hass)
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
@ -1056,7 +1038,9 @@ class DefaultAgent(ConversationEntity):
# Force rebuild on next use
self._trigger_intents = None
async def _match_triggers(self, sentence: str) -> SentenceTriggerResult | None:
async def async_recognize_sentence_trigger(
self, user_input: ConversationInput
) -> SentenceTriggerResult | None:
"""Try to match sentence against registered trigger sentences.
Calls the registered callbacks if there's a match and returns a sentence
@ -1074,7 +1058,7 @@ class DefaultAgent(ConversationEntity):
matched_triggers: dict[int, RecognizeResult] = {}
matched_template: str | None = None
for result in recognize_all(sentence, self._trigger_intents):
for result in recognize_all(user_input.text, self._trigger_intents):
if result.intent_sentence is not None:
matched_template = result.intent_sentence.text
@ -1091,12 +1075,88 @@ class DefaultAgent(ConversationEntity):
_LOGGER.debug(
"'%s' matched %s trigger(s): %s",
sentence,
user_input.text,
len(matched_triggers),
list(matched_triggers),
)
return SentenceTriggerResult(sentence, matched_template, matched_triggers)
return SentenceTriggerResult(
user_input.text, matched_template, matched_triggers
)
async def _handle_trigger_result(
self, result: SentenceTriggerResult, user_input: ConversationInput
) -> str:
"""Run sentence trigger callbacks and return response text."""
# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
user_input.text, trigger_result, user_input.device_id
)
for trigger_id, trigger_result in result.matched_triggers.items()
]
# Use first non-empty result as response.
#
# There may be multiple copies of a trigger running when editing in
# the UI, so it's critical that we filter out empty responses here.
response_text = ""
response_set_by_trigger = False
for trigger_future in asyncio.as_completed(trigger_callbacks):
trigger_response = await trigger_future
if trigger_response is None:
continue
response_text = trigger_response
response_set_by_trigger = True
break
if response_set_by_trigger:
# Response was explicitly set to empty
response_text = response_text or ""
elif not response_text:
# Use translated acknowledgment for pipeline language
language = user_input.language or self.hass.config.language
translations = await translation.async_get_translations(
self.hass, language, DOMAIN, [DOMAIN]
)
response_text = translations.get(
f"component.{DOMAIN}.conversation.agent.done", "Done"
)
return response_text
async def async_handle_sentence_triggers(
self, user_input: ConversationInput
) -> str | None:
"""Try to input sentence against sentence triggers and return response text.
Returns None if no match occurred.
"""
if trigger_result := await self.async_recognize_sentence_trigger(user_input):
return await self._handle_trigger_result(trigger_result, user_input)
return None
async def async_handle_intents(
self,
user_input: ConversationInput,
) -> intent.IntentResponse | None:
"""Try to match sentence against registered intents and return response.
Only performs strict matching with exposed entities and exact wording.
Returns None if no match occurred.
"""
result = await self.async_recognize_intent(user_input, strict_intents_only=True)
if not isinstance(result, RecognizeResult):
# No error message on failed match
return None
conversation_result = await self._async_process_intent_result(
result, user_input
)
return conversation_result.response
def _make_error_result(
@ -1108,7 +1168,6 @@ def _make_error_result(
"""Create conversation result with error code and text."""
response = intent.IntentResponse(language=language)
response.async_set_error(error_code, response_text)
return ConversationResult(response, conversation_id)

View File

@ -24,11 +24,7 @@ from .agent_manager import (
get_agent_manager,
)
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY
from .default_agent import (
METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE,
SentenceTriggerResult,
)
from .default_agent import METADATA_CUSTOM_FILE, METADATA_CUSTOM_SENTENCE, DefaultAgent
from .entity import ConversationEntity
from .models import ConversationInput
@ -167,44 +163,42 @@ async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Return intents that would be matched by the default agent for a list of sentences."""
results = [
await hass.data[DATA_DEFAULT_ENTITY].async_recognize(
ConversationInput(
text=sentence,
context=connection.context(msg),
conversation_id=None,
device_id=msg.get("device_id"),
language=msg.get("language", hass.config.language),
agent_id=None,
)
)
for sentence in msg["sentences"]
]
agent = hass.data.get(DATA_DEFAULT_ENTITY)
assert isinstance(agent, DefaultAgent)
# Return results for each sentence in the same order as the input.
result_dicts: list[dict[str, Any] | None] = []
for result in results:
for sentence in msg["sentences"]:
user_input = ConversationInput(
text=sentence,
context=connection.context(msg),
conversation_id=None,
device_id=msg.get("device_id"),
language=msg.get("language", hass.config.language),
agent_id=None,
)
result_dict: dict[str, Any] | None = None
if isinstance(result, SentenceTriggerResult):
if trigger_result := await agent.async_recognize_sentence_trigger(user_input):
result_dict = {
# Matched a user-defined sentence trigger.
# We can't provide the response here without executing the
# trigger.
"match": True,
"source": "trigger",
"sentence_template": result.sentence_template or "",
"sentence_template": trigger_result.sentence_template or "",
}
elif isinstance(result, RecognizeResult):
successful_match = not result.unmatched_entities
elif intent_result := await agent.async_recognize_intent(user_input):
successful_match = not intent_result.unmatched_entities
result_dict = {
# Name of the matching intent (or the closest)
"intent": {
"name": result.intent.name,
"name": intent_result.intent.name,
},
# Slot values that would be received by the intent
"slots": { # direct access to values
entity_key: entity.text or entity.value
for entity_key, entity in result.entities.items()
for entity_key, entity in intent_result.entities.items()
},
# Extra slot details, such as the originally matched text
"details": {
@ -213,7 +207,7 @@ async def websocket_hass_agent_debug(
"value": entity.value,
"text": entity.text,
}
for entity_key, entity in result.entities.items()
for entity_key, entity in intent_result.entities.items()
},
# Entities/areas/etc. that would be targeted
"targets": {},
@ -222,24 +216,26 @@ async def websocket_hass_agent_debug(
# Text of the sentence template that matched (or was closest)
"sentence_template": "",
# When match is incomplete, this will contain the best slot guesses
"unmatched_slots": _get_unmatched_slots(result),
"unmatched_slots": _get_unmatched_slots(intent_result),
}
if successful_match:
result_dict["targets"] = {
state.entity_id: {"matched": is_matched}
for state, is_matched in _get_debug_targets(hass, result)
for state, is_matched in _get_debug_targets(hass, intent_result)
}
if result.intent_sentence is not None:
result_dict["sentence_template"] = result.intent_sentence.text
if intent_result.intent_sentence is not None:
result_dict["sentence_template"] = intent_result.intent_sentence.text
# Inspect metadata to determine if this matched a custom sentence
if result.intent_metadata and result.intent_metadata.get(
if intent_result.intent_metadata and intent_result.intent_metadata.get(
METADATA_CUSTOM_SENTENCE
):
result_dict["source"] = "custom"
result_dict["file"] = result.intent_metadata.get(METADATA_CUSTOM_FILE)
result_dict["file"] = intent_result.intent_metadata.get(
METADATA_CUSTOM_FILE
)
else:
result_dict["source"] = "builtin"

View File

@ -139,7 +139,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@ -228,7 +228,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({

View File

@ -11,13 +11,20 @@ import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, media_source, stt, tts
from homeassistant.components import (
assist_pipeline,
conversation,
media_source,
stt,
tts,
)
from homeassistant.components.assist_pipeline.const import (
BYTES_PER_CHUNK,
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from .conftest import (
@ -927,3 +934,148 @@ async def test_tts_dict_preferred_format(
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_sentence_trigger_overrides_conversation_agent(
hass: HomeAssistant,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that sentence triggers are checked before the conversation agent."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": [
"test trigger sentence",
],
},
"action": {
"set_conversation_response": "test trigger response",
},
}
},
)
events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
await pipeline_input.validate()
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
) as mock_async_converse:
await pipeline_input.execute()
# Sentence trigger should have been handled
mock_async_converse.assert_not_called()
# Verify sentence trigger response
intent_end_event = next(
(
e
for e in events
if e.type == assist_pipeline.PipelineEventType.INTENT_END
),
None,
)
assert (intent_end_event is not None) and intent_end_event.data
assert (
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
"speech"
]
== "test trigger response"
)
async def test_prefer_local_intents(
hass: HomeAssistant,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that the default agent is checked first when local intents are preferred."""
events: list[assist_pipeline.PipelineEvent] = []
# Reuse custom sentences in test config
class OrderBeerIntentHandler(intent.IntentHandler):
intent_type = "OrderBeer"
async def async_handle(
self, intent_obj: intent.Intent
) -> intent.IntentResponse:
response = intent_obj.create_response()
response.async_set_speech("Order confirmed")
return response
handler = OrderBeerIntentHandler()
intent.async_register(hass, handler)
# Fake a test agent and prefer local intents
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
await assist_pipeline.pipeline.async_update_pipeline(
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
)
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
):
await pipeline_input.validate()
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
) as mock_async_converse:
await pipeline_input.execute()
# Test agent should not have been called
mock_async_converse.assert_not_called()
# Verify local intent response
intent_end_event = next(
(
e
for e in events
if e.type == assist_pipeline.PipelineEventType.INTENT_END
),
None,
)
assert (intent_end_event is not None) and intent_end_event.data
assert (
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
"speech"
]
== "Order confirmed"
)

View File

@ -574,6 +574,7 @@ async def test_update_pipeline(
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
"prefer_local_intents": False,
}
await async_update_pipeline(
@ -617,6 +618,7 @@ async def test_update_pipeline(
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
"prefer_local_intents": False,
}

View File

@ -974,6 +974,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": True,
}
)
msg = await client.receive_json()
@ -991,6 +992,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": True,
}
assert len(pipeline_store.data) == 2
@ -1008,6 +1010,7 @@ async def test_add_pipeline(
tts_voice="Arnold Schwarzenegger",
wake_word_entity="wakeword_entity_1",
wake_word_id="wakeword_id_1",
prefer_local_intents=True,
)
await client.send_json_auto_id(
@ -1195,6 +1198,7 @@ async def test_get_pipeline(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
# Get conversation agent as pipeline
@ -1220,6 +1224,7 @@ async def test_get_pipeline(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
await client.send_json_auto_id(
@ -1249,6 +1254,7 @@ async def test_get_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": False,
}
)
msg = await client.receive_json()
@ -1277,6 +1283,7 @@ async def test_get_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": False,
}
@ -1304,6 +1311,7 @@ async def test_list_pipelines(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
],
"preferred_pipeline": ANY,
@ -1395,6 +1403,7 @@ async def test_update_pipeline(
"tts_voice": "new_tts_voice",
"wake_word_entity": "new_wakeword_entity",
"wake_word_id": "new_wakeword_id",
"prefer_local_intents": False,
}
assert len(pipeline_store.data) == 2
@ -1446,6 +1455,7 @@ async def test_update_pipeline(
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
pipeline = pipeline_store.data[pipeline_id]

View File

@ -35,6 +35,7 @@ PIPELINE_DATA = {
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
{
"conversation_engine": "conversation_engine_2",
@ -49,6 +50,7 @@ PIPELINE_DATA = {
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
{
"conversation_engine": "conversation_engine_3",
@ -63,6 +65,7 @@ PIPELINE_DATA = {
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",

View File

@ -355,15 +355,15 @@ async def test_ws_hass_agent_debug_null_result(
"""Test homeassistant agent debug websocket command with a null result."""
client = await hass_ws_client(hass)
async def async_recognize(self, user_input, *args, **kwargs):
async def async_recognize_intent(self, user_input, *args, **kwargs):
if user_input.text == "bad sentence":
return None
return await self.async_recognize(user_input, *args, **kwargs)
with patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize",
async_recognize,
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent",
async_recognize_intent,
):
await client.send_json_auto_id(
{

View File

@ -8,10 +8,15 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation import (
ConversationInput,
async_handle_intents,
async_handle_sentence_triggers,
default_agent,
)
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
@ -229,3 +234,93 @@ async def test_prepare_agent(
await conversation.async_prepare_agent(hass, agent_id, "en")
assert len(mock_prepare.mock_calls) == 1
async def test_async_handle_sentence_triggers(hass: HomeAssistant) -> None:
"""Test handling sentence triggers with async_handle_sentence_triggers."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
response_template = "response {{ trigger.device_id }}"
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["my trigger"],
},
"action": {
"set_conversation_response": response_template,
},
}
},
)
# Device id will be available in response template
device_id = "1234"
expected_response = f"response {device_id}"
actual_response = await async_handle_sentence_triggers(
hass,
ConversationInput(
text="my trigger",
context=Context(),
conversation_id=None,
device_id=device_id,
language=hass.config.language,
),
)
assert actual_response == expected_response
async def test_async_handle_intents(hass: HomeAssistant) -> None:
"""Test handling registered intents with async_handle_intents."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
# Reuse custom sentences in test config to trigger default agent.
class OrderBeerIntentHandler(intent.IntentHandler):
intent_type = "OrderBeer"
def __init__(self) -> None:
super().__init__()
self.was_handled = False
async def async_handle(
self, intent_obj: intent.Intent
) -> intent.IntentResponse:
self.was_handled = True
return intent_obj.create_response()
handler = OrderBeerIntentHandler()
intent.async_register(hass, handler)
# Registered intent will be handled
result = await async_handle_intents(
hass,
ConversationInput(
text="I'd like to order a stout",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
),
)
assert result is not None
assert result.intent is not None
assert result.intent.intent_type == handler.intent_type
assert handler.was_handled
# No error messages, just None as a result
result = await async_handle_intents(
hass,
ConversationInput(
text="this sentence does not exist",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
),
)
assert result is None