From f47840d83c0b92466917b40c85a7bdf10c8e0464 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 22 Nov 2024 19:57:42 -0600 Subject: [PATCH] Cache intent recognition results (#131114) --- .../components/conversation/default_agent.py | 370 +++++++++++++----- .../conversation/test_default_agent.py | 107 +++++ 2 files changed, 382 insertions(+), 95 deletions(-) diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index c6d394a1366..20720b90423 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +from collections import OrderedDict from collections.abc import Awaitable, Callable, Iterable from dataclasses import dataclass +from enum import Enum, auto import functools import logging from pathlib import Path @@ -102,6 +104,77 @@ class SentenceTriggerResult: matched_triggers: dict[int, RecognizeResult] +class IntentMatchingStage(Enum): + """Stages of intent matching.""" + + EXPOSED_ENTITIES_ONLY = auto() + """Match against exposed entities only.""" + + ALL_ENTITIES = auto() + """Match against all entities in Home Assistant.""" + + FUZZY = auto() + """Capture names that are not known to Home Assistant.""" + + +@dataclass(frozen=True) +class IntentCacheKey: + """Key for IntentCache.""" + + text: str + """User input text.""" + + language: str + """Language of text.""" + + device_id: str | None + """Device id from user input.""" + + +@dataclass(frozen=True) +class IntentCacheValue: + """Value for IntentCache.""" + + result: RecognizeResult | None + """Result of intent recognition.""" + + stage: IntentMatchingStage + """Stage where result was found.""" + + +class IntentCache: + """LRU cache for intent recognition results.""" + + def __init__(self, capacity: int) -> None: + """Initialize cache.""" + self.cache: OrderedDict[IntentCacheKey, IntentCacheValue] = OrderedDict() + self.capacity = capacity + + def get(self, key: IntentCacheKey) -> IntentCacheValue | None: + """Get value for cache or None.""" + if key not in self.cache: + return None + + # Move the key to the end to show it was recently used + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key: IntentCacheKey, value: IntentCacheValue) -> None: + """Put a value in the cache, evicting the least recently used item if necessary.""" + if key in self.cache: + # Update value and mark as recently used + self.cache.move_to_end(key) + elif len(self.cache) >= self.capacity: + # Evict the oldest item + self.cache.popitem(last=False) + + self.cache[key] = value + + def clear(self) -> None: + """Clear the cache.""" + self.cache.clear() + + def _get_language_variations(language: str) -> Iterable[str]: """Generate language codes with and without region.""" yield language @@ -160,6 +233,7 @@ class DefaultAgent(ConversationEntity): # intent -> [sentences] self._config_intents: dict[str, Any] = config_intents self._slot_lists: dict[str, SlotList] | None = None + self._all_entity_names: TextSlotList | None = None # Sentences that will trigger a callback (skipping intent recognition) self._trigger_sentences: list[TriggerData] = [] @@ -167,6 +241,9 @@ class DefaultAgent(ConversationEntity): self._unsub_clear_slot_list: list[Callable[[], None]] | None = None self._load_intents_lock = asyncio.Lock() + # LRU cache to avoid unnecessary intent matching + self._intent_cache = IntentCache(capacity=128) + @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" @@ -417,18 +494,200 @@ class DefaultAgent(ConversationEntity): strict_intents_only: bool, ) -> RecognizeResult | None: """Search intents for a match to user input.""" - strict_result = self._recognize_strict( - user_input, lang_intents, slot_lists, intent_context, language - ) + skip_exposed_match = False - if strict_result is not None: - # Successful strict match - return strict_result + # Try cache first + cache_key = IntentCacheKey( + text=user_input.text, language=language, device_id=user_input.device_id + ) + cache_value = self._intent_cache.get(cache_key) + if cache_value is not None: + if (cache_value.result is not None) and ( + cache_value.stage == IntentMatchingStage.EXPOSED_ENTITIES_ONLY + ): + _LOGGER.debug("Got cached result for exposed entities") + return cache_value.result + + # Continue with matching, but we know we won't succeed for exposed + # entities only. + skip_exposed_match = True + + if not skip_exposed_match: + start_time = time.monotonic() + strict_result = self._recognize_strict( + user_input, lang_intents, slot_lists, intent_context, language + ) + _LOGGER.debug( + "Checked exposed entities in %s second(s)", + time.monotonic() - start_time, + ) + + # Update cache + self._intent_cache.put( + cache_key, + IntentCacheValue( + result=strict_result, + stage=IntentMatchingStage.EXPOSED_ENTITIES_ONLY, + ), + ) + + if strict_result is not None: + # Successful strict match with exposed entities + return strict_result if strict_intents_only: + # Don't try matching against all entities or doing a fuzzy match return None # Try again with all entities (including unexposed) + skip_all_entities_match = False + if cache_value is not None: + if (cache_value.result is not None) and ( + cache_value.stage == IntentMatchingStage.ALL_ENTITIES + ): + _LOGGER.debug("Got cached result for all entities") + return cache_value.result + + # Continue with matching, but we know we won't succeed for all + # entities. + skip_all_entities_match = True + + if not skip_all_entities_match: + all_entities_slot_lists = { + **slot_lists, + "name": self._get_all_entity_names(), + } + + start_time = time.monotonic() + strict_result = self._recognize_strict( + user_input, + lang_intents, + all_entities_slot_lists, + intent_context, + language, + ) + + _LOGGER.debug( + "Checked all entities in %s second(s)", time.monotonic() - start_time + ) + + # Update cache + self._intent_cache.put( + cache_key, + IntentCacheValue( + result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES + ), + ) + + if strict_result is not None: + # Not a successful match, but useful for an error message. + # This should fail the intent handling phase (async_match_targets). + return strict_result + + # Try again with missing entities enabled + skip_fuzzy_match = False + if cache_value is not None: + if (cache_value.result is not None) and ( + cache_value.stage == IntentMatchingStage.FUZZY + ): + _LOGGER.debug("Got cached result for fuzzy match") + return cache_value.result + + # We know we won't succeed for fuzzy matching. + skip_fuzzy_match = True + + maybe_result: RecognizeResult | None = None + if not skip_fuzzy_match: + start_time = time.monotonic() + best_num_matched_entities = 0 + best_num_unmatched_entities = 0 + best_num_unmatched_ranges = 0 + for result in recognize_all( + user_input.text, + lang_intents.intents, + slot_lists=slot_lists, + intent_context=intent_context, + allow_unmatched_entities=True, + ): + if result.text_chunks_matched < 1: + # Skip results that don't match any literal text + continue + + # Don't count missing entities that couldn't be filled from context + num_matched_entities = 0 + for matched_entity in result.entities_list: + if matched_entity.name not in result.unmatched_entities: + num_matched_entities += 1 + + num_unmatched_entities = 0 + num_unmatched_ranges = 0 + for unmatched_entity in result.unmatched_entities_list: + if isinstance(unmatched_entity, UnmatchedTextEntity): + if unmatched_entity.text != MISSING_ENTITY: + num_unmatched_entities += 1 + elif isinstance(unmatched_entity, UnmatchedRangeEntity): + num_unmatched_ranges += 1 + num_unmatched_entities += 1 + else: + num_unmatched_entities += 1 + + if ( + (maybe_result is None) # first result + or (num_matched_entities > best_num_matched_entities) + or ( + # Fewer unmatched entities + (num_matched_entities == best_num_matched_entities) + and (num_unmatched_entities < best_num_unmatched_entities) + ) + or ( + # Prefer unmatched ranges + (num_matched_entities == best_num_matched_entities) + and (num_unmatched_entities == best_num_unmatched_entities) + and (num_unmatched_ranges > best_num_unmatched_ranges) + ) + or ( + # More literal text matched + (num_matched_entities == best_num_matched_entities) + and (num_unmatched_entities == best_num_unmatched_entities) + and (num_unmatched_ranges == best_num_unmatched_ranges) + and ( + result.text_chunks_matched + > maybe_result.text_chunks_matched + ) + ) + or ( + # Prefer match failures with entities + (result.text_chunks_matched == maybe_result.text_chunks_matched) + and (num_unmatched_entities == best_num_unmatched_entities) + and (num_unmatched_ranges == best_num_unmatched_ranges) + and ( + ("name" in result.entities) + or ("name" in result.unmatched_entities) + ) + ) + ): + maybe_result = result + best_num_matched_entities = num_matched_entities + best_num_unmatched_entities = num_unmatched_entities + best_num_unmatched_ranges = num_unmatched_ranges + + # Update cache + self._intent_cache.put( + cache_key, + IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY), + ) + + _LOGGER.debug( + "Did fuzzy match in %s second(s)", time.monotonic() - start_time + ) + + return maybe_result + + def _get_all_entity_names(self) -> TextSlotList: + """Get slot list with all entity names in Home Assistant.""" + if self._all_entity_names is not None: + return self._all_entity_names + entity_registry = er.async_get(self.hass) all_entity_names: list[tuple[str, str, dict[str, Any]]] = [] @@ -459,96 +718,10 @@ class DefaultAgent(ConversationEntity): # Default name all_entity_names.append((state.name, state.name, context)) - slot_lists = { - **slot_lists, - "name": TextSlotList.from_tuples(all_entity_names, allow_template=False), - } - - strict_result = self._recognize_strict( - user_input, - lang_intents, - slot_lists, - intent_context, - language, + self._all_entity_names = TextSlotList.from_tuples( + all_entity_names, allow_template=False ) - - if strict_result is not None: - # Not a successful match, but useful for an error message. - # This should fail the intent handling phase (async_match_targets). - return strict_result - - # Try again with missing entities enabled - maybe_result: RecognizeResult | None = None - best_num_matched_entities = 0 - best_num_unmatched_entities = 0 - best_num_unmatched_ranges = 0 - for result in recognize_all( - user_input.text, - lang_intents.intents, - slot_lists=slot_lists, - intent_context=intent_context, - allow_unmatched_entities=True, - ): - if result.text_chunks_matched < 1: - # Skip results that don't match any literal text - continue - - # Don't count missing entities that couldn't be filled from context - num_matched_entities = 0 - for matched_entity in result.entities_list: - if matched_entity.name not in result.unmatched_entities: - num_matched_entities += 1 - - num_unmatched_entities = 0 - num_unmatched_ranges = 0 - for unmatched_entity in result.unmatched_entities_list: - if isinstance(unmatched_entity, UnmatchedTextEntity): - if unmatched_entity.text != MISSING_ENTITY: - num_unmatched_entities += 1 - elif isinstance(unmatched_entity, UnmatchedRangeEntity): - num_unmatched_ranges += 1 - num_unmatched_entities += 1 - else: - num_unmatched_entities += 1 - - if ( - (maybe_result is None) # first result - or (num_matched_entities > best_num_matched_entities) - or ( - # Fewer unmatched entities - (num_matched_entities == best_num_matched_entities) - and (num_unmatched_entities < best_num_unmatched_entities) - ) - or ( - # Prefer unmatched ranges - (num_matched_entities == best_num_matched_entities) - and (num_unmatched_entities == best_num_unmatched_entities) - and (num_unmatched_ranges > best_num_unmatched_ranges) - ) - or ( - # More literal text matched - (num_matched_entities == best_num_matched_entities) - and (num_unmatched_entities == best_num_unmatched_entities) - and (num_unmatched_ranges == best_num_unmatched_ranges) - and (result.text_chunks_matched > maybe_result.text_chunks_matched) - ) - or ( - # Prefer match failures with entities - (result.text_chunks_matched == maybe_result.text_chunks_matched) - and (num_unmatched_entities == best_num_unmatched_entities) - and (num_unmatched_ranges == best_num_unmatched_ranges) - and ( - ("name" in result.entities) - or ("name" in result.unmatched_entities) - ) - ) - ): - maybe_result = result - best_num_matched_entities = num_matched_entities - best_num_unmatched_entities = num_unmatched_entities - best_num_unmatched_ranges = num_unmatched_ranges - - return maybe_result + return self._all_entity_names def _recognize_strict( self, @@ -653,6 +826,9 @@ class DefaultAgent(ConversationEntity): self._lang_intents.pop(language, None) _LOGGER.debug("Cleared intents for language: %s", language) + # Intents have changed, so we must clear the cache + self._intent_cache.clear() + async def async_prepare(self, language: str | None = None) -> None: """Load intents for a language.""" if language is None: @@ -837,10 +1013,14 @@ class DefaultAgent(ConversationEntity): if self._unsub_clear_slot_list is None: return self._slot_lists = None + self._all_entity_names = None for unsub in self._unsub_clear_slot_list: unsub() self._unsub_clear_slot_list = None + # Slot lists have changed, so we must clear the cache + self._intent_cache.clear() + @core.callback def _make_slot_lists(self) -> dict[str, SlotList]: """Create slot lists with areas and entity names/aliases.""" diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 3c6b463670a..1e5e284a245 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -2833,3 +2833,110 @@ async def test_query_same_name_different_areas( assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER assert len(result.response.matched_states) == 1 assert result.response.matched_states[0].entity_id == kitchen_light.entity_id + + +@pytest.mark.usefixtures("init_components") +async def test_intent_cache_exposed(hass: HomeAssistant) -> None: + """Test that intent recognition results are cached for exposed entities.""" + agent = hass.data[DATA_DEFAULT_ENTITY] + assert isinstance(agent, default_agent.DefaultAgent) + + entity_id = "light.test_light" + hass.states.async_set(entity_id, "off") + expose_entity(hass, entity_id, True) + await hass.async_block_till_done() + + user_input = ConversationInput( + text="turn on test light", + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert result.entities["name"].text == "test light" + + # Mark this result so we know it is from cache next time + mark = "_from_cache" + setattr(result, mark, True) + + # Should be from cache this time + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert getattr(result, mark, None) is True + + # Unexposing clears the cache + expose_entity(hass, entity_id, False) + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert getattr(result, mark, None) is None + + +@pytest.mark.usefixtures("init_components") +async def test_intent_cache_all_entities(hass: HomeAssistant) -> None: + """Test that intent recognition results are cached for all entities.""" + agent = hass.data[DATA_DEFAULT_ENTITY] + assert isinstance(agent, default_agent.DefaultAgent) + + entity_id = "light.test_light" + hass.states.async_set(entity_id, "off") + expose_entity(hass, entity_id, False) # not exposed + await hass.async_block_till_done() + + user_input = ConversationInput( + text="turn on test light", + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert result.entities["name"].text == "test light" + + # Mark this result so we know it is from cache next time + mark = "_from_cache" + setattr(result, mark, True) + + # Should be from cache this time + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert getattr(result, mark, None) is True + + # Adding a new entity clears the cache + hass.states.async_set("light.new_light", "off") + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert getattr(result, mark, None) is None + + +@pytest.mark.usefixtures("init_components") +async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None: + """Test that intent recognition results are cached for fuzzy matches.""" + agent = hass.data[DATA_DEFAULT_ENTITY] + assert isinstance(agent, default_agent.DefaultAgent) + + # There is no entity named test light + user_input = ConversationInput( + text="turn on test light", + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert result.unmatched_entities["name"].text == "test light" + + # Mark this result so we know it is from cache next time + mark = "_from_cache" + setattr(result, mark, True) + + # Should be from cache this time + result = await agent.async_recognize_intent(user_input) + assert result is not None + assert getattr(result, mark, None) is True