Make the full conversation input available to sentence triggers (#131982)

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2024-11-30 23:04:29 -05:00 committed by GitHub
parent ffeefd4856
commit 6103cea3f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 146 additions and 25 deletions

View File

@ -70,7 +70,7 @@ _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[ TRIGGER_CALLBACK_TYPE = Callable[
[str, RecognizeResult, str | None], Awaitable[str | None] [ConversationInput, RecognizeResult], Awaitable[str | None]
] ]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence" METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file" METADATA_CUSTOM_FILE = "hass_custom_file"
@ -1286,9 +1286,7 @@ class DefaultAgent(ConversationEntity):
# Gather callback responses in parallel # Gather callback responses in parallel
trigger_callbacks = [ trigger_callbacks = [
self._trigger_sentences[trigger_id].callback( self._trigger_sentences[trigger_id].callback(user_input, trigger_result)
user_input.text, trigger_result, user_input.device_id
)
for trigger_id, trigger_result in result.matched_triggers.items() for trigger_id, trigger_result in result.matched_triggers.items()
] ]

View File

@ -40,6 +40,17 @@ class ConversationInput:
agent_id: str | None = None agent_id: str | None = None
"""Agent to use for processing.""" """Agent to use for processing."""
def as_dict(self) -> dict[str, Any]:
"""Return input as a dict."""
return {
"text": self.text,
"context": self.context.as_dict(),
"conversation_id": self.conversation_id,
"device_id": self.device_id,
"language": self.language,
"agent_id": self.agent_id,
}
@dataclass(slots=True) @dataclass(slots=True)
class ConversationResult: class ConversationResult:

View File

@ -16,6 +16,7 @@ from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .const import DATA_DEFAULT_ENTITY, DOMAIN from .const import DATA_DEFAULT_ENTITY, DOMAIN
from .models import ConversationInput
def has_no_punctuation(value: list[str]) -> list[str]: def has_no_punctuation(value: list[str]) -> list[str]:
@ -62,7 +63,7 @@ async def async_attach_trigger(
job = HassJob(action) job = HassJob(action)
async def call_action( async def call_action(
sentence: str, result: RecognizeResult, device_id: str | None user_input: ConversationInput, result: RecognizeResult
) -> str | None: ) -> str | None:
"""Call action with right context.""" """Call action with right context."""
@ -83,12 +84,13 @@ async def async_attach_trigger(
trigger_input: dict[str, Any] = { # Satisfy type checker trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data, **trigger_data,
"platform": DOMAIN, "platform": DOMAIN,
"sentence": sentence, "sentence": user_input.text,
"details": details, "details": details,
"slots": { # direct access to values "slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items() entity_name: entity["value"] for entity_name, entity in details.items()
}, },
"device_id": device_id, "device_id": user_input.device_id,
"user_input": user_input.as_dict(),
} }
# Wait for the automation to complete # Wait for the automation to complete

View File

@ -397,7 +397,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
callback.reset_mock() callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context()) result = await conversation.async_converse(hass, sentence, None, Context())
assert callback.call_count == 1 assert callback.call_count == 1
assert callback.call_args[0][0] == sentence assert callback.call_args[0][0].text == sentence
assert ( assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence ), sentence

View File

@ -40,18 +40,31 @@ async def test_if_fires_on_event(
}, },
"action": { "action": {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
} }
}, },
) )
context = Context()
service_response = await hass.services.async_call( service_response = await hass.services.async_call(
"conversation", "conversation",
"process", "process",
{"text": "Ha ha ha"}, {"text": "Ha ha ha"},
blocking=True, blocking=True,
return_response=True, return_response=True,
context=context,
) )
assert service_response["response"]["speech"]["plain"]["speech"] == "Done" assert service_response["response"]["speech"]["plain"]["speech"] == "Done"
@ -61,13 +74,21 @@ async def test_if_fires_on_event(
assert service_calls[1].service == "automation" assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == { assert service_calls[1].data["data"] == {
"alias": None, "alias": None,
"id": "0", "id": 0,
"idx": "0", "idx": 0,
"platform": "conversation", "platform": "conversation",
"sentence": "Ha ha ha", "sentence": "Ha ha ha",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None, "device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "Ha ha ha",
},
} }
@ -152,7 +173,19 @@ async def test_response_same_sentence(
{"delay": "0:0:0.100"}, {"delay": "0:0:0.100"},
{ {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
{"set_conversation_response": "response 2"}, {"set_conversation_response": "response 2"},
], ],
@ -168,13 +201,14 @@ async def test_response_same_sentence(
] ]
}, },
) )
context = Context()
service_response = await hass.services.async_call( service_response = await hass.services.async_call(
"conversation", "conversation",
"process", "process",
{"text": "test sentence"}, {"text": "test sentence"},
blocking=True, blocking=True,
return_response=True, return_response=True,
context=context,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -188,12 +222,20 @@ async def test_response_same_sentence(
assert service_calls[1].data["data"] == { assert service_calls[1].data["data"] == {
"alias": None, "alias": None,
"id": "trigger1", "id": "trigger1",
"idx": "0", "idx": 0,
"platform": "conversation", "platform": "conversation",
"sentence": "test sentence", "sentence": "test sentence",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None, "device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "test sentence",
},
} }
@ -231,13 +273,14 @@ async def test_response_same_sentence_with_error(
] ]
}, },
) )
context = Context()
service_response = await hass.services.async_call( service_response = await hass.services.async_call(
"conversation", "conversation",
"process", "process",
{"text": "test sentence"}, {"text": "test sentence"},
blocking=True, blocking=True,
return_response=True, return_response=True,
context=context,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -320,12 +363,24 @@ async def test_same_trigger_multiple_sentences(
}, },
"action": { "action": {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
} }
}, },
) )
context = Context()
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",
"process", "process",
@ -333,6 +388,7 @@ async def test_same_trigger_multiple_sentences(
"text": "hello", "text": "hello",
}, },
blocking=True, blocking=True,
context=context,
) )
# Only triggers once # Only triggers once
@ -342,13 +398,21 @@ async def test_same_trigger_multiple_sentences(
assert service_calls[1].service == "automation" assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == { assert service_calls[1].data["data"] == {
"alias": None, "alias": None,
"id": "0", "id": 0,
"idx": "0", "idx": 0,
"platform": "conversation", "platform": "conversation",
"sentence": "hello", "sentence": "hello",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None, "device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "hello",
},
} }
@ -371,7 +435,19 @@ async def test_same_sentence_multiple_triggers(
}, },
"action": { "action": {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
}, },
{ {
@ -384,7 +460,19 @@ async def test_same_sentence_multiple_triggers(
}, },
"action": { "action": {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
}, },
], ],
@ -488,12 +576,25 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
}, },
"action": { "action": {
"service": "test.automation", "service": "test.automation",
"data_template": {"data": "{{ trigger }}"}, "data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
}, },
} }
}, },
) )
context = Context()
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",
"process", "process",
@ -501,6 +602,7 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
"text": "play the white album by the beatles", "text": "play the white album by the beatles",
}, },
blocking=True, blocking=True,
context=context,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -509,8 +611,8 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
assert service_calls[1].service == "automation" assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == { assert service_calls[1].data["data"] == {
"alias": None, "alias": None,
"id": "0", "id": 0,
"idx": "0", "idx": 0,
"platform": "conversation", "platform": "conversation",
"sentence": "play the white album by the beatles", "sentence": "play the white album by the beatles",
"slots": { "slots": {
@ -530,6 +632,14 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
}, },
}, },
"device_id": None, "device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "play the white album by the beatles",
},
} }