Add wildcards to sentence triggers (#97236)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Michael Hansen 2023-07-27 13:30:42 -05:00 committed by GitHub
parent af286a8feb
commit 7e3fdd85fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 147 additions and 19 deletions

View File

@ -322,7 +322,11 @@ async def websocket_hass_agent_debug(
"intent": { "intent": {
"name": result.intent.name, "name": result.intent.name,
}, },
"entities": { "slots": { # direct access to values
entity_key: entity.value
for entity_key, entity in result.entities.items()
},
"details": {
entity_key: { entity_key: {
"name": entity.name, "name": entity.name,
"value": entity.value, "value": entity.value,

View File

@ -11,7 +11,14 @@ from pathlib import Path
import re import re
from typing import IO, Any from typing import IO, Any
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList from hassil.expression import Expression, ListReference, Sequence
from hassil.intents import (
Intents,
ResponseType,
SlotList,
TextSlotList,
WildcardSlotList,
)
from hassil.recognize import RecognizeResult, recognize_all from hassil.recognize import RecognizeResult, recognize_all
from hassil.util import merge_dict from hassil.util import merge_dict
from home_assistant_intents import get_domains_and_languages, get_intents from home_assistant_intents import get_domains_and_languages, get_intents
@ -48,7 +55,7 @@ _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
[str], Awaitable[str | None] [str, RecognizeResult], Awaitable[str | None]
] ]
@ -657,6 +664,17 @@ class DefaultAgent(AbstractConversationAgent):
} }
self._trigger_intents = Intents.from_dict(intents_dict) self._trigger_intents = Intents.from_dict(intents_dict)
# Assume slot list references are wildcards
wildcard_names: set[str] = set()
for trigger_intent in self._trigger_intents.intents.values():
for intent_data in trigger_intent.data:
for sentence in intent_data.sentences:
_collect_list_references(sentence, wildcard_names)
for wildcard_name in wildcard_names:
self._trigger_intents.slot_lists[wildcard_name] = WildcardSlotList()
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict) _LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
def _unregister_trigger(self, trigger_data: TriggerData) -> None: def _unregister_trigger(self, trigger_data: TriggerData) -> None:
@ -682,14 +700,14 @@ class DefaultAgent(AbstractConversationAgent):
assert self._trigger_intents is not None assert self._trigger_intents is not None
matched_triggers: set[int] = set() matched_triggers: dict[int, RecognizeResult] = {}
for result in recognize_all(sentence, self._trigger_intents): for result in recognize_all(sentence, self._trigger_intents):
trigger_id = int(result.intent.name) trigger_id = int(result.intent.name)
if trigger_id in matched_triggers: if trigger_id in matched_triggers:
# Already matched a sentence from this trigger # Already matched a sentence from this trigger
break break
matched_triggers.add(trigger_id) matched_triggers[trigger_id] = result
if not matched_triggers: if not matched_triggers:
# Sentence did not match any trigger sentences # Sentence did not match any trigger sentences
@ -699,14 +717,14 @@ class DefaultAgent(AbstractConversationAgent):
"'%s' matched %s trigger(s): %s", "'%s' matched %s trigger(s): %s",
sentence, sentence,
len(matched_triggers), len(matched_triggers),
matched_triggers, list(matched_triggers),
) )
# Gather callback responses in parallel # Gather callback responses in parallel
trigger_responses = await asyncio.gather( trigger_responses = await asyncio.gather(
*( *(
self._trigger_sentences[trigger_id].callback(sentence) self._trigger_sentences[trigger_id].callback(sentence, result)
for trigger_id in matched_triggers for trigger_id, result in matched_triggers.items()
) )
) )
@ -733,3 +751,15 @@ def _make_error_result(
response.async_set_error(error_code, response_text) response.async_set_error(error_code, response_text)
return ConversationResult(response, conversation_id) return ConversationResult(response, conversation_id)
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
"""Collect list reference names recursively."""
if isinstance(expression, Sequence):
seq: Sequence = expression
for item in seq.items:
_collect_list_references(item, list_names)
elif isinstance(expression, ListReference):
# {list}
list_ref: ListReference = expression
list_names.add(list_ref.slot_name)

View File

@ -7,5 +7,5 @@
"integration_type": "system", "integration_type": "system",
"iot_class": "local_push", "iot_class": "local_push",
"quality_scale": "internal", "quality_scale": "internal",
"requirements": ["hassil==1.2.2", "home-assistant-intents==2023.7.25"] "requirements": ["hassil==1.2.5", "home-assistant-intents==2023.7.25"]
} }

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any from typing import Any
from hassil.recognize import PUNCTUATION from hassil.recognize import PUNCTUATION, RecognizeResult
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
@ -49,12 +49,29 @@ async def async_attach_trigger(
job = HassJob(action) job = HassJob(action)
@callback @callback
async def call_action(sentence: str) -> str | None: async def call_action(sentence: str, result: RecognizeResult) -> str | None:
"""Call action with right context.""" """Call action with right context."""
# Add slot values as extra trigger data
details = {
entity_name: {
"name": entity_name,
"text": entity.text.strip(), # remove whitespace
"value": entity.value.strip()
if isinstance(entity.value, str)
else entity.value,
}
for entity_name, entity in result.entities.items()
}
trigger_input: dict[str, Any] = { # Satisfy type checker trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data, **trigger_data,
"platform": DOMAIN, "platform": DOMAIN,
"sentence": sentence, "sentence": sentence,
"details": details,
"slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items()
},
} }
# Wait for the automation to complete # Wait for the automation to complete

View File

@ -20,7 +20,7 @@ dbus-fast==1.87.2
fnv-hash-fast==0.4.0 fnv-hash-fast==0.4.0
ha-av==10.1.0 ha-av==10.1.0
hass-nabucasa==0.69.0 hass-nabucasa==0.69.0
hassil==1.2.2 hassil==1.2.5
home-assistant-bluetooth==1.10.2 home-assistant-bluetooth==1.10.2
home-assistant-frontend==20230725.0 home-assistant-frontend==20230725.0
home-assistant-intents==2023.7.25 home-assistant-intents==2023.7.25

View File

@ -958,7 +958,7 @@ hass-nabucasa==0.69.0
hass-splunk==0.1.1 hass-splunk==0.1.1
# homeassistant.components.conversation # homeassistant.components.conversation
hassil==1.2.2 hassil==1.2.5
# homeassistant.components.jewish_calendar # homeassistant.components.jewish_calendar
hdate==0.10.4 hdate==0.10.4

View File

@ -753,7 +753,7 @@ habitipy==0.2.0
hass-nabucasa==0.69.0 hass-nabucasa==0.69.0
# homeassistant.components.conversation # homeassistant.components.conversation
hassil==1.2.2 hassil==1.2.5
# homeassistant.components.jewish_calendar # homeassistant.components.jewish_calendar
hdate==0.10.4 hdate==0.10.4

View File

@ -372,7 +372,7 @@
dict({ dict({
'results': list([ 'results': list([
dict({ dict({
'entities': dict({ 'details': dict({
'name': dict({ 'name': dict({
'name': 'name', 'name': 'name',
'text': 'my cool light', 'text': 'my cool light',
@ -382,6 +382,9 @@
'intent': dict({ 'intent': dict({
'name': 'HassTurnOn', 'name': 'HassTurnOn',
}), }),
'slots': dict({
'name': 'my cool light',
}),
'targets': dict({ 'targets': dict({
'light.kitchen': dict({ 'light.kitchen': dict({
'matched': True, 'matched': True,
@ -389,7 +392,7 @@
}), }),
}), }),
dict({ dict({
'entities': dict({ 'details': dict({
'name': dict({ 'name': dict({
'name': 'name', 'name': 'name',
'text': 'my cool light', 'text': 'my cool light',
@ -399,6 +402,9 @@
'intent': dict({ 'intent': dict({
'name': 'HassTurnOff', 'name': 'HassTurnOff',
}), }),
'slots': dict({
'name': 'my cool light',
}),
'targets': dict({ 'targets': dict({
'light.kitchen': dict({ 'light.kitchen': dict({
'matched': True, 'matched': True,
@ -406,7 +412,7 @@
}), }),
}), }),
dict({ dict({
'entities': dict({ 'details': dict({
'area': dict({ 'area': dict({
'name': 'area', 'name': 'area',
'text': 'kitchen', 'text': 'kitchen',
@ -421,6 +427,10 @@
'intent': dict({ 'intent': dict({
'name': 'HassTurnOn', 'name': 'HassTurnOn',
}), }),
'slots': dict({
'area': 'kitchen',
'domain': 'light',
}),
'targets': dict({ 'targets': dict({
'light.kitchen': dict({ 'light.kitchen': dict({
'matched': True, 'matched': True,
@ -428,7 +438,7 @@
}), }),
}), }),
dict({ dict({
'entities': dict({ 'details': dict({
'area': dict({ 'area': dict({
'name': 'area', 'name': 'area',
'text': 'kitchen', 'text': 'kitchen',
@ -448,6 +458,11 @@
'intent': dict({ 'intent': dict({
'name': 'HassGetState', 'name': 'HassGetState',
}), }),
'slots': dict({
'area': 'kitchen',
'domain': 'light',
'state': 'on',
}),
'targets': dict({ 'targets': dict({
'light.kitchen': dict({ 'light.kitchen': dict({
'matched': False, 'matched': False,

View File

@ -246,7 +246,8 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
for sentence in test_sentences: for sentence in test_sentences:
callback.reset_mock() callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context()) result = await conversation.async_converse(hass, sentence, None, Context())
callback.assert_called_once_with(sentence) assert callback.call_count == 1
assert callback.call_args[0][0] == sentence
assert ( assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence ), sentence

View File

@ -61,6 +61,8 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
"idx": "0", "idx": "0",
"platform": "conversation", "platform": "conversation",
"sentence": "Ha ha ha", "sentence": "Ha ha ha",
"slots": {},
"details": {},
} }
@ -103,6 +105,8 @@ async def test_same_trigger_multiple_sentences(
"idx": "0", "idx": "0",
"platform": "conversation", "platform": "conversation",
"sentence": "hello", "sentence": "hello",
"slots": {},
"details": {},
} }
@ -188,3 +192,60 @@ async def test_fails_on_punctuation(hass: HomeAssistant, command: str) -> None:
}, },
], ],
) )
async def test_wildcards(hass: HomeAssistant, calls, setup_comp) -> None:
"""Test wildcards in trigger sentences."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": [
"play {album} by {artist}",
],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
}
},
)
await hass.services.async_call(
"conversation",
"process",
{
"text": "play the white album by the beatles",
},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"platform": "conversation",
"sentence": "play the white album by the beatles",
"slots": {
"album": "the white album",
"artist": "the beatles",
},
"details": {
"album": {
"name": "album",
"text": "the white album",
"value": "the white album",
},
"artist": {
"name": "artist",
"text": "the beatles",
"value": "the beatles",
},
},
}