mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Cache intent recognition results (#131114)
This commit is contained in:
parent
8f9095ba67
commit
f47840d83c
@ -3,8 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import OrderedDict
|
||||||
from collections.abc import Awaitable, Callable, Iterable
|
from collections.abc import Awaitable, Callable, Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -102,6 +104,77 @@ class SentenceTriggerResult:
|
|||||||
matched_triggers: dict[int, RecognizeResult]
|
matched_triggers: dict[int, RecognizeResult]
|
||||||
|
|
||||||
|
|
||||||
|
class IntentMatchingStage(Enum):
|
||||||
|
"""Stages of intent matching."""
|
||||||
|
|
||||||
|
EXPOSED_ENTITIES_ONLY = auto()
|
||||||
|
"""Match against exposed entities only."""
|
||||||
|
|
||||||
|
ALL_ENTITIES = auto()
|
||||||
|
"""Match against all entities in Home Assistant."""
|
||||||
|
|
||||||
|
FUZZY = auto()
|
||||||
|
"""Capture names that are not known to Home Assistant."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IntentCacheKey:
|
||||||
|
"""Key for IntentCache."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
"""User input text."""
|
||||||
|
|
||||||
|
language: str
|
||||||
|
"""Language of text."""
|
||||||
|
|
||||||
|
device_id: str | None
|
||||||
|
"""Device id from user input."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IntentCacheValue:
|
||||||
|
"""Value for IntentCache."""
|
||||||
|
|
||||||
|
result: RecognizeResult | None
|
||||||
|
"""Result of intent recognition."""
|
||||||
|
|
||||||
|
stage: IntentMatchingStage
|
||||||
|
"""Stage where result was found."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntentCache:
|
||||||
|
"""LRU cache for intent recognition results."""
|
||||||
|
|
||||||
|
def __init__(self, capacity: int) -> None:
|
||||||
|
"""Initialize cache."""
|
||||||
|
self.cache: OrderedDict[IntentCacheKey, IntentCacheValue] = OrderedDict()
|
||||||
|
self.capacity = capacity
|
||||||
|
|
||||||
|
def get(self, key: IntentCacheKey) -> IntentCacheValue | None:
|
||||||
|
"""Get value for cache or None."""
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Move the key to the end to show it was recently used
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
return self.cache[key]
|
||||||
|
|
||||||
|
def put(self, key: IntentCacheKey, value: IntentCacheValue) -> None:
|
||||||
|
"""Put a value in the cache, evicting the least recently used item if necessary."""
|
||||||
|
if key in self.cache:
|
||||||
|
# Update value and mark as recently used
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
elif len(self.cache) >= self.capacity:
|
||||||
|
# Evict the oldest item
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear the cache."""
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
def _get_language_variations(language: str) -> Iterable[str]:
|
def _get_language_variations(language: str) -> Iterable[str]:
|
||||||
"""Generate language codes with and without region."""
|
"""Generate language codes with and without region."""
|
||||||
yield language
|
yield language
|
||||||
@ -160,6 +233,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# intent -> [sentences]
|
# intent -> [sentences]
|
||||||
self._config_intents: dict[str, Any] = config_intents
|
self._config_intents: dict[str, Any] = config_intents
|
||||||
self._slot_lists: dict[str, SlotList] | None = None
|
self._slot_lists: dict[str, SlotList] | None = None
|
||||||
|
self._all_entity_names: TextSlotList | None = None
|
||||||
|
|
||||||
# Sentences that will trigger a callback (skipping intent recognition)
|
# Sentences that will trigger a callback (skipping intent recognition)
|
||||||
self._trigger_sentences: list[TriggerData] = []
|
self._trigger_sentences: list[TriggerData] = []
|
||||||
@ -167,6 +241,9 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
|
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
|
||||||
self._load_intents_lock = asyncio.Lock()
|
self._load_intents_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# LRU cache to avoid unnecessary intent matching
|
||||||
|
self._intent_cache = IntentCache(capacity=128)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
@ -417,18 +494,200 @@ class DefaultAgent(ConversationEntity):
|
|||||||
strict_intents_only: bool,
|
strict_intents_only: bool,
|
||||||
) -> RecognizeResult | None:
|
) -> RecognizeResult | None:
|
||||||
"""Search intents for a match to user input."""
|
"""Search intents for a match to user input."""
|
||||||
strict_result = self._recognize_strict(
|
skip_exposed_match = False
|
||||||
user_input, lang_intents, slot_lists, intent_context, language
|
|
||||||
)
|
|
||||||
|
|
||||||
if strict_result is not None:
|
# Try cache first
|
||||||
# Successful strict match
|
cache_key = IntentCacheKey(
|
||||||
return strict_result
|
text=user_input.text, language=language, device_id=user_input.device_id
|
||||||
|
)
|
||||||
|
cache_value = self._intent_cache.get(cache_key)
|
||||||
|
if cache_value is not None:
|
||||||
|
if (cache_value.result is not None) and (
|
||||||
|
cache_value.stage == IntentMatchingStage.EXPOSED_ENTITIES_ONLY
|
||||||
|
):
|
||||||
|
_LOGGER.debug("Got cached result for exposed entities")
|
||||||
|
return cache_value.result
|
||||||
|
|
||||||
|
# Continue with matching, but we know we won't succeed for exposed
|
||||||
|
# entities only.
|
||||||
|
skip_exposed_match = True
|
||||||
|
|
||||||
|
if not skip_exposed_match:
|
||||||
|
start_time = time.monotonic()
|
||||||
|
strict_result = self._recognize_strict(
|
||||||
|
user_input, lang_intents, slot_lists, intent_context, language
|
||||||
|
)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Checked exposed entities in %s second(s)",
|
||||||
|
time.monotonic() - start_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._intent_cache.put(
|
||||||
|
cache_key,
|
||||||
|
IntentCacheValue(
|
||||||
|
result=strict_result,
|
||||||
|
stage=IntentMatchingStage.EXPOSED_ENTITIES_ONLY,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if strict_result is not None:
|
||||||
|
# Successful strict match with exposed entities
|
||||||
|
return strict_result
|
||||||
|
|
||||||
if strict_intents_only:
|
if strict_intents_only:
|
||||||
|
# Don't try matching against all entities or doing a fuzzy match
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Try again with all entities (including unexposed)
|
# Try again with all entities (including unexposed)
|
||||||
|
skip_all_entities_match = False
|
||||||
|
if cache_value is not None:
|
||||||
|
if (cache_value.result is not None) and (
|
||||||
|
cache_value.stage == IntentMatchingStage.ALL_ENTITIES
|
||||||
|
):
|
||||||
|
_LOGGER.debug("Got cached result for all entities")
|
||||||
|
return cache_value.result
|
||||||
|
|
||||||
|
# Continue with matching, but we know we won't succeed for all
|
||||||
|
# entities.
|
||||||
|
skip_all_entities_match = True
|
||||||
|
|
||||||
|
if not skip_all_entities_match:
|
||||||
|
all_entities_slot_lists = {
|
||||||
|
**slot_lists,
|
||||||
|
"name": self._get_all_entity_names(),
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.monotonic()
|
||||||
|
strict_result = self._recognize_strict(
|
||||||
|
user_input,
|
||||||
|
lang_intents,
|
||||||
|
all_entities_slot_lists,
|
||||||
|
intent_context,
|
||||||
|
language,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Checked all entities in %s second(s)", time.monotonic() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._intent_cache.put(
|
||||||
|
cache_key,
|
||||||
|
IntentCacheValue(
|
||||||
|
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if strict_result is not None:
|
||||||
|
# Not a successful match, but useful for an error message.
|
||||||
|
# This should fail the intent handling phase (async_match_targets).
|
||||||
|
return strict_result
|
||||||
|
|
||||||
|
# Try again with missing entities enabled
|
||||||
|
skip_fuzzy_match = False
|
||||||
|
if cache_value is not None:
|
||||||
|
if (cache_value.result is not None) and (
|
||||||
|
cache_value.stage == IntentMatchingStage.FUZZY
|
||||||
|
):
|
||||||
|
_LOGGER.debug("Got cached result for fuzzy match")
|
||||||
|
return cache_value.result
|
||||||
|
|
||||||
|
# We know we won't succeed for fuzzy matching.
|
||||||
|
skip_fuzzy_match = True
|
||||||
|
|
||||||
|
maybe_result: RecognizeResult | None = None
|
||||||
|
if not skip_fuzzy_match:
|
||||||
|
start_time = time.monotonic()
|
||||||
|
best_num_matched_entities = 0
|
||||||
|
best_num_unmatched_entities = 0
|
||||||
|
best_num_unmatched_ranges = 0
|
||||||
|
for result in recognize_all(
|
||||||
|
user_input.text,
|
||||||
|
lang_intents.intents,
|
||||||
|
slot_lists=slot_lists,
|
||||||
|
intent_context=intent_context,
|
||||||
|
allow_unmatched_entities=True,
|
||||||
|
):
|
||||||
|
if result.text_chunks_matched < 1:
|
||||||
|
# Skip results that don't match any literal text
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Don't count missing entities that couldn't be filled from context
|
||||||
|
num_matched_entities = 0
|
||||||
|
for matched_entity in result.entities_list:
|
||||||
|
if matched_entity.name not in result.unmatched_entities:
|
||||||
|
num_matched_entities += 1
|
||||||
|
|
||||||
|
num_unmatched_entities = 0
|
||||||
|
num_unmatched_ranges = 0
|
||||||
|
for unmatched_entity in result.unmatched_entities_list:
|
||||||
|
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
||||||
|
if unmatched_entity.text != MISSING_ENTITY:
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
||||||
|
num_unmatched_ranges += 1
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
else:
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
|
||||||
|
if (
|
||||||
|
(maybe_result is None) # first result
|
||||||
|
or (num_matched_entities > best_num_matched_entities)
|
||||||
|
or (
|
||||||
|
# Fewer unmatched entities
|
||||||
|
(num_matched_entities == best_num_matched_entities)
|
||||||
|
and (num_unmatched_entities < best_num_unmatched_entities)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# Prefer unmatched ranges
|
||||||
|
(num_matched_entities == best_num_matched_entities)
|
||||||
|
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||||
|
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# More literal text matched
|
||||||
|
(num_matched_entities == best_num_matched_entities)
|
||||||
|
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||||
|
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||||
|
and (
|
||||||
|
result.text_chunks_matched
|
||||||
|
> maybe_result.text_chunks_matched
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# Prefer match failures with entities
|
||||||
|
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
||||||
|
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||||
|
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||||
|
and (
|
||||||
|
("name" in result.entities)
|
||||||
|
or ("name" in result.unmatched_entities)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
maybe_result = result
|
||||||
|
best_num_matched_entities = num_matched_entities
|
||||||
|
best_num_unmatched_entities = num_unmatched_entities
|
||||||
|
best_num_unmatched_ranges = num_unmatched_ranges
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._intent_cache.put(
|
||||||
|
cache_key,
|
||||||
|
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
return maybe_result
|
||||||
|
|
||||||
|
def _get_all_entity_names(self) -> TextSlotList:
|
||||||
|
"""Get slot list with all entity names in Home Assistant."""
|
||||||
|
if self._all_entity_names is not None:
|
||||||
|
return self._all_entity_names
|
||||||
|
|
||||||
entity_registry = er.async_get(self.hass)
|
entity_registry = er.async_get(self.hass)
|
||||||
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
||||||
|
|
||||||
@ -459,96 +718,10 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# Default name
|
# Default name
|
||||||
all_entity_names.append((state.name, state.name, context))
|
all_entity_names.append((state.name, state.name, context))
|
||||||
|
|
||||||
slot_lists = {
|
self._all_entity_names = TextSlotList.from_tuples(
|
||||||
**slot_lists,
|
all_entity_names, allow_template=False
|
||||||
"name": TextSlotList.from_tuples(all_entity_names, allow_template=False),
|
|
||||||
}
|
|
||||||
|
|
||||||
strict_result = self._recognize_strict(
|
|
||||||
user_input,
|
|
||||||
lang_intents,
|
|
||||||
slot_lists,
|
|
||||||
intent_context,
|
|
||||||
language,
|
|
||||||
)
|
)
|
||||||
|
return self._all_entity_names
|
||||||
if strict_result is not None:
|
|
||||||
# Not a successful match, but useful for an error message.
|
|
||||||
# This should fail the intent handling phase (async_match_targets).
|
|
||||||
return strict_result
|
|
||||||
|
|
||||||
# Try again with missing entities enabled
|
|
||||||
maybe_result: RecognizeResult | None = None
|
|
||||||
best_num_matched_entities = 0
|
|
||||||
best_num_unmatched_entities = 0
|
|
||||||
best_num_unmatched_ranges = 0
|
|
||||||
for result in recognize_all(
|
|
||||||
user_input.text,
|
|
||||||
lang_intents.intents,
|
|
||||||
slot_lists=slot_lists,
|
|
||||||
intent_context=intent_context,
|
|
||||||
allow_unmatched_entities=True,
|
|
||||||
):
|
|
||||||
if result.text_chunks_matched < 1:
|
|
||||||
# Skip results that don't match any literal text
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Don't count missing entities that couldn't be filled from context
|
|
||||||
num_matched_entities = 0
|
|
||||||
for matched_entity in result.entities_list:
|
|
||||||
if matched_entity.name not in result.unmatched_entities:
|
|
||||||
num_matched_entities += 1
|
|
||||||
|
|
||||||
num_unmatched_entities = 0
|
|
||||||
num_unmatched_ranges = 0
|
|
||||||
for unmatched_entity in result.unmatched_entities_list:
|
|
||||||
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
|
||||||
if unmatched_entity.text != MISSING_ENTITY:
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
|
||||||
num_unmatched_ranges += 1
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
else:
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
|
|
||||||
if (
|
|
||||||
(maybe_result is None) # first result
|
|
||||||
or (num_matched_entities > best_num_matched_entities)
|
|
||||||
or (
|
|
||||||
# Fewer unmatched entities
|
|
||||||
(num_matched_entities == best_num_matched_entities)
|
|
||||||
and (num_unmatched_entities < best_num_unmatched_entities)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# Prefer unmatched ranges
|
|
||||||
(num_matched_entities == best_num_matched_entities)
|
|
||||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
|
||||||
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# More literal text matched
|
|
||||||
(num_matched_entities == best_num_matched_entities)
|
|
||||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
|
||||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
|
||||||
and (result.text_chunks_matched > maybe_result.text_chunks_matched)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# Prefer match failures with entities
|
|
||||||
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
|
||||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
|
||||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
|
||||||
and (
|
|
||||||
("name" in result.entities)
|
|
||||||
or ("name" in result.unmatched_entities)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
maybe_result = result
|
|
||||||
best_num_matched_entities = num_matched_entities
|
|
||||||
best_num_unmatched_entities = num_unmatched_entities
|
|
||||||
best_num_unmatched_ranges = num_unmatched_ranges
|
|
||||||
|
|
||||||
return maybe_result
|
|
||||||
|
|
||||||
def _recognize_strict(
|
def _recognize_strict(
|
||||||
self,
|
self,
|
||||||
@ -653,6 +826,9 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self._lang_intents.pop(language, None)
|
self._lang_intents.pop(language, None)
|
||||||
_LOGGER.debug("Cleared intents for language: %s", language)
|
_LOGGER.debug("Cleared intents for language: %s", language)
|
||||||
|
|
||||||
|
# Intents have changed, so we must clear the cache
|
||||||
|
self._intent_cache.clear()
|
||||||
|
|
||||||
async def async_prepare(self, language: str | None = None) -> None:
|
async def async_prepare(self, language: str | None = None) -> None:
|
||||||
"""Load intents for a language."""
|
"""Load intents for a language."""
|
||||||
if language is None:
|
if language is None:
|
||||||
@ -837,10 +1013,14 @@ class DefaultAgent(ConversationEntity):
|
|||||||
if self._unsub_clear_slot_list is None:
|
if self._unsub_clear_slot_list is None:
|
||||||
return
|
return
|
||||||
self._slot_lists = None
|
self._slot_lists = None
|
||||||
|
self._all_entity_names = None
|
||||||
for unsub in self._unsub_clear_slot_list:
|
for unsub in self._unsub_clear_slot_list:
|
||||||
unsub()
|
unsub()
|
||||||
self._unsub_clear_slot_list = None
|
self._unsub_clear_slot_list = None
|
||||||
|
|
||||||
|
# Slot lists have changed, so we must clear the cache
|
||||||
|
self._intent_cache.clear()
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def _make_slot_lists(self) -> dict[str, SlotList]:
|
def _make_slot_lists(self) -> dict[str, SlotList]:
|
||||||
"""Create slot lists with areas and entity names/aliases."""
|
"""Create slot lists with areas and entity names/aliases."""
|
||||||
|
@ -2833,3 +2833,110 @@ async def test_query_same_name_different_areas(
|
|||||||
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
||||||
assert len(result.response.matched_states) == 1
|
assert len(result.response.matched_states) == 1
|
||||||
assert result.response.matched_states[0].entity_id == kitchen_light.entity_id
|
assert result.response.matched_states[0].entity_id == kitchen_light.entity_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("init_components")
|
||||||
|
async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that intent recognition results are cached for exposed entities."""
|
||||||
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
|
entity_id = "light.test_light"
|
||||||
|
hass.states.async_set(entity_id, "off")
|
||||||
|
expose_entity(hass, entity_id, True)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
user_input = ConversationInput(
|
||||||
|
text="turn on test light",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert result.entities["name"].text == "test light"
|
||||||
|
|
||||||
|
# Mark this result so we know it is from cache next time
|
||||||
|
mark = "_from_cache"
|
||||||
|
setattr(result, mark, True)
|
||||||
|
|
||||||
|
# Should be from cache this time
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert getattr(result, mark, None) is True
|
||||||
|
|
||||||
|
# Unexposing clears the cache
|
||||||
|
expose_entity(hass, entity_id, False)
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert getattr(result, mark, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("init_components")
|
||||||
|
async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that intent recognition results are cached for all entities."""
|
||||||
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
|
entity_id = "light.test_light"
|
||||||
|
hass.states.async_set(entity_id, "off")
|
||||||
|
expose_entity(hass, entity_id, False) # not exposed
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
user_input = ConversationInput(
|
||||||
|
text="turn on test light",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert result.entities["name"].text == "test light"
|
||||||
|
|
||||||
|
# Mark this result so we know it is from cache next time
|
||||||
|
mark = "_from_cache"
|
||||||
|
setattr(result, mark, True)
|
||||||
|
|
||||||
|
# Should be from cache this time
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert getattr(result, mark, None) is True
|
||||||
|
|
||||||
|
# Adding a new entity clears the cache
|
||||||
|
hass.states.async_set("light.new_light", "off")
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert getattr(result, mark, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("init_components")
|
||||||
|
async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that intent recognition results are cached for fuzzy matches."""
|
||||||
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
assert isinstance(agent, default_agent.DefaultAgent)
|
||||||
|
|
||||||
|
# There is no entity named test light
|
||||||
|
user_input = ConversationInput(
|
||||||
|
text="turn on test light",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert result.unmatched_entities["name"].text == "test light"
|
||||||
|
|
||||||
|
# Mark this result so we know it is from cache next time
|
||||||
|
mark = "_from_cache"
|
||||||
|
setattr(result, mark, True)
|
||||||
|
|
||||||
|
# Should be from cache this time
|
||||||
|
result = await agent.async_recognize_intent(user_input)
|
||||||
|
assert result is not None
|
||||||
|
assert getattr(result, mark, None) is True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user