Make the full conversation input available to sentence triggers (#131982)

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2024-11-30 23:04:29 -05:00 committed by GitHub
parent ffeefd4856
commit 6103cea3f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 146 additions and 25 deletions

View File

@ -70,7 +70,7 @@ _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[
[str, RecognizeResult, str | None], Awaitable[str | None]
[ConversationInput, RecognizeResult], Awaitable[str | None]
]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
@ -1286,9 +1286,7 @@ class DefaultAgent(ConversationEntity):
# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
user_input.text, trigger_result, user_input.device_id
)
self._trigger_sentences[trigger_id].callback(user_input, trigger_result)
for trigger_id, trigger_result in result.matched_triggers.items()
]

View File

@ -40,6 +40,17 @@ class ConversationInput:
agent_id: str | None = None
"""Agent to use for processing."""
def as_dict(self) -> dict[str, Any]:
"""Return input as a dict."""
return {
"text": self.text,
"context": self.context.as_dict(),
"conversation_id": self.conversation_id,
"device_id": self.device_id,
"language": self.language,
"agent_id": self.agent_id,
}
@dataclass(slots=True)
class ConversationResult:

View File

@ -16,6 +16,7 @@ from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .const import DATA_DEFAULT_ENTITY, DOMAIN
from .models import ConversationInput
def has_no_punctuation(value: list[str]) -> list[str]:
@ -62,7 +63,7 @@ async def async_attach_trigger(
job = HassJob(action)
async def call_action(
sentence: str, result: RecognizeResult, device_id: str | None
user_input: ConversationInput, result: RecognizeResult
) -> str | None:
"""Call action with right context."""
@ -83,12 +84,13 @@ async def async_attach_trigger(
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
"sentence": user_input.text,
"details": details,
"slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items()
},
"device_id": device_id,
"device_id": user_input.device_id,
"user_input": user_input.as_dict(),
}
# Wait for the automation to complete

View File

@ -397,7 +397,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
assert callback.call_count == 1
assert callback.call_args[0][0] == sentence
assert callback.call_args[0][0].text == sentence
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence

View File

@ -40,18 +40,31 @@ async def test_if_fires_on_event(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)
context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "Ha ha ha"},
blocking=True,
return_response=True,
context=context,
)
assert service_response["response"]["speech"]["plain"]["speech"] == "Done"
@ -61,13 +74,21 @@ async def test_if_fires_on_event(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "Ha ha ha",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "Ha ha ha",
},
}
@ -152,7 +173,19 @@ async def test_response_same_sentence(
{"delay": "0:0:0.100"},
{
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
{"set_conversation_response": "response 2"},
],
@ -168,13 +201,14 @@ async def test_response_same_sentence(
]
},
)
context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()
@ -188,12 +222,20 @@ async def test_response_same_sentence(
assert service_calls[1].data["data"] == {
"alias": None,
"id": "trigger1",
"idx": "0",
"idx": 0,
"platform": "conversation",
"sentence": "test sentence",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "test sentence",
},
}
@ -231,13 +273,14 @@ async def test_response_same_sentence_with_error(
]
},
)
context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()
@ -320,12 +363,24 @@ async def test_same_trigger_multiple_sentences(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)
context = Context()
await hass.services.async_call(
"conversation",
"process",
@ -333,6 +388,7 @@ async def test_same_trigger_multiple_sentences(
"text": "hello",
},
blocking=True,
context=context,
)
# Only triggers once
@ -342,13 +398,21 @@ async def test_same_trigger_multiple_sentences(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "hello",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "hello",
},
}
@ -371,7 +435,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
{
@ -384,7 +460,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
],
@ -488,12 +576,25 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)
context = Context()
await hass.services.async_call(
"conversation",
"process",
@ -501,6 +602,7 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
"text": "play the white album by the beatles",
},
blocking=True,
context=context,
)
await hass.async_block_till_done()
@ -509,8 +611,8 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "play the white album by the beatles",
"slots": {
@ -530,6 +632,14 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "play the white album by the beatles",
},
}