Cache intent recognition results (#131114)

This commit is contained in:
Michael Hansen 2024-11-22 19:57:42 -06:00 committed by GitHub
parent 8f9095ba67
commit f47840d83c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 382 additions and 95 deletions

View File

@ -3,8 +3,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict
from collections.abc import Awaitable, Callable, Iterable from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
import functools import functools
import logging import logging
from pathlib import Path from pathlib import Path
@ -102,6 +104,77 @@ class SentenceTriggerResult:
matched_triggers: dict[int, RecognizeResult] matched_triggers: dict[int, RecognizeResult]
class IntentMatchingStage(Enum):
"""Stages of intent matching."""
EXPOSED_ENTITIES_ONLY = auto()
"""Match against exposed entities only."""
ALL_ENTITIES = auto()
"""Match against all entities in Home Assistant."""
FUZZY = auto()
"""Capture names that are not known to Home Assistant."""
@dataclass(frozen=True)
class IntentCacheKey:
"""Key for IntentCache."""
text: str
"""User input text."""
language: str
"""Language of text."""
device_id: str | None
"""Device id from user input."""
@dataclass(frozen=True)
class IntentCacheValue:
"""Value for IntentCache."""
result: RecognizeResult | None
"""Result of intent recognition."""
stage: IntentMatchingStage
"""Stage where result was found."""
class IntentCache:
"""LRU cache for intent recognition results."""
def __init__(self, capacity: int) -> None:
"""Initialize cache."""
self.cache: OrderedDict[IntentCacheKey, IntentCacheValue] = OrderedDict()
self.capacity = capacity
def get(self, key: IntentCacheKey) -> IntentCacheValue | None:
"""Get value for cache or None."""
if key not in self.cache:
return None
# Move the key to the end to show it was recently used
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: IntentCacheKey, value: IntentCacheValue) -> None:
"""Put a value in the cache, evicting the least recently used item if necessary."""
if key in self.cache:
# Update value and mark as recently used
self.cache.move_to_end(key)
elif len(self.cache) >= self.capacity:
# Evict the oldest item
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self) -> None:
"""Clear the cache."""
self.cache.clear()
def _get_language_variations(language: str) -> Iterable[str]: def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region.""" """Generate language codes with and without region."""
yield language yield language
@ -160,6 +233,7 @@ class DefaultAgent(ConversationEntity):
# intent -> [sentences] # intent -> [sentences]
self._config_intents: dict[str, Any] = config_intents self._config_intents: dict[str, Any] = config_intents
self._slot_lists: dict[str, SlotList] | None = None self._slot_lists: dict[str, SlotList] | None = None
self._all_entity_names: TextSlotList | None = None
# Sentences that will trigger a callback (skipping intent recognition) # Sentences that will trigger a callback (skipping intent recognition)
self._trigger_sentences: list[TriggerData] = [] self._trigger_sentences: list[TriggerData] = []
@ -167,6 +241,9 @@ class DefaultAgent(ConversationEntity):
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
self._load_intents_lock = asyncio.Lock() self._load_intents_lock = asyncio.Lock()
# LRU cache to avoid unnecessary intent matching
self._intent_cache = IntentCache(capacity=128)
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
@ -417,68 +494,111 @@ class DefaultAgent(ConversationEntity):
strict_intents_only: bool, strict_intents_only: bool,
) -> RecognizeResult | None: ) -> RecognizeResult | None:
"""Search intents for a match to user input.""" """Search intents for a match to user input."""
skip_exposed_match = False
# Try cache first
cache_key = IntentCacheKey(
text=user_input.text, language=language, device_id=user_input.device_id
)
cache_value = self._intent_cache.get(cache_key)
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.EXPOSED_ENTITIES_ONLY
):
_LOGGER.debug("Got cached result for exposed entities")
return cache_value.result
# Continue with matching, but we know we won't succeed for exposed
# entities only.
skip_exposed_match = True
if not skip_exposed_match:
start_time = time.monotonic()
strict_result = self._recognize_strict( strict_result = self._recognize_strict(
user_input, lang_intents, slot_lists, intent_context, language user_input, lang_intents, slot_lists, intent_context, language
) )
_LOGGER.debug(
"Checked exposed entities in %s second(s)",
time.monotonic() - start_time,
)
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(
result=strict_result,
stage=IntentMatchingStage.EXPOSED_ENTITIES_ONLY,
),
)
if strict_result is not None: if strict_result is not None:
# Successful strict match # Successful strict match with exposed entities
return strict_result return strict_result
if strict_intents_only: if strict_intents_only:
# Don't try matching against all entities or doing a fuzzy match
return None return None
# Try again with all entities (including unexposed) # Try again with all entities (including unexposed)
entity_registry = er.async_get(self.hass) skip_all_entities_match = False
all_entity_names: list[tuple[str, str, dict[str, Any]]] = [] if cache_value is not None:
if (cache_value.result is not None) and (
for state in self.hass.states.async_all(): cache_value.stage == IntentMatchingStage.ALL_ENTITIES
context = {"domain": state.domain}
if state.attributes:
# Include some attributes
for attr in DEFAULT_EXPOSED_ATTRIBUTES:
if attr not in state.attributes:
continue
context[attr] = state.attributes[attr]
if entity := entity_registry.async_get(state.entity_id):
# Skip config/hidden entities
if (entity.entity_category is not None) or (
entity.hidden_by is not None
): ):
continue _LOGGER.debug("Got cached result for all entities")
return cache_value.result
if entity.aliases: # Continue with matching, but we know we won't succeed for all
# Also add aliases # entities.
for alias in entity.aliases: skip_all_entities_match = True
if not alias.strip():
continue
all_entity_names.append((alias, alias, context)) if not skip_all_entities_match:
all_entities_slot_lists = {
# Default name
all_entity_names.append((state.name, state.name, context))
slot_lists = {
**slot_lists, **slot_lists,
"name": TextSlotList.from_tuples(all_entity_names, allow_template=False), "name": self._get_all_entity_names(),
} }
start_time = time.monotonic()
strict_result = self._recognize_strict( strict_result = self._recognize_strict(
user_input, user_input,
lang_intents, lang_intents,
slot_lists, all_entities_slot_lists,
intent_context, intent_context,
language, language,
) )
_LOGGER.debug(
"Checked all entities in %s second(s)", time.monotonic() - start_time
)
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
),
)
if strict_result is not None: if strict_result is not None:
# Not a successful match, but useful for an error message. # Not a successful match, but useful for an error message.
# This should fail the intent handling phase (async_match_targets). # This should fail the intent handling phase (async_match_targets).
return strict_result return strict_result
# Try again with missing entities enabled # Try again with missing entities enabled
skip_fuzzy_match = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.FUZZY
):
_LOGGER.debug("Got cached result for fuzzy match")
return cache_value.result
# We know we won't succeed for fuzzy matching.
skip_fuzzy_match = True
maybe_result: RecognizeResult | None = None maybe_result: RecognizeResult | None = None
if not skip_fuzzy_match:
start_time = time.monotonic()
best_num_matched_entities = 0 best_num_matched_entities = 0
best_num_unmatched_entities = 0 best_num_unmatched_entities = 0
best_num_unmatched_ranges = 0 best_num_unmatched_ranges = 0
@ -530,7 +650,10 @@ class DefaultAgent(ConversationEntity):
(num_matched_entities == best_num_matched_entities) (num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities) and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges) and (num_unmatched_ranges == best_num_unmatched_ranges)
and (result.text_chunks_matched > maybe_result.text_chunks_matched) and (
result.text_chunks_matched
> maybe_result.text_chunks_matched
)
) )
or ( or (
# Prefer match failures with entities # Prefer match failures with entities
@ -548,8 +671,58 @@ class DefaultAgent(ConversationEntity):
best_num_unmatched_entities = num_unmatched_entities best_num_unmatched_entities = num_unmatched_entities
best_num_unmatched_ranges = num_unmatched_ranges best_num_unmatched_ranges = num_unmatched_ranges
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
)
_LOGGER.debug(
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
)
return maybe_result return maybe_result
def _get_all_entity_names(self) -> TextSlotList:
"""Get slot list with all entity names in Home Assistant."""
if self._all_entity_names is not None:
return self._all_entity_names
entity_registry = er.async_get(self.hass)
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
for state in self.hass.states.async_all():
context = {"domain": state.domain}
if state.attributes:
# Include some attributes
for attr in DEFAULT_EXPOSED_ATTRIBUTES:
if attr not in state.attributes:
continue
context[attr] = state.attributes[attr]
if entity := entity_registry.async_get(state.entity_id):
# Skip config/hidden entities
if (entity.entity_category is not None) or (
entity.hidden_by is not None
):
continue
if entity.aliases:
# Also add aliases
for alias in entity.aliases:
if not alias.strip():
continue
all_entity_names.append((alias, alias, context))
# Default name
all_entity_names.append((state.name, state.name, context))
self._all_entity_names = TextSlotList.from_tuples(
all_entity_names, allow_template=False
)
return self._all_entity_names
def _recognize_strict( def _recognize_strict(
self, self,
user_input: ConversationInput, user_input: ConversationInput,
@ -653,6 +826,9 @@ class DefaultAgent(ConversationEntity):
self._lang_intents.pop(language, None) self._lang_intents.pop(language, None)
_LOGGER.debug("Cleared intents for language: %s", language) _LOGGER.debug("Cleared intents for language: %s", language)
# Intents have changed, so we must clear the cache
self._intent_cache.clear()
async def async_prepare(self, language: str | None = None) -> None: async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language.""" """Load intents for a language."""
if language is None: if language is None:
@ -837,10 +1013,14 @@ class DefaultAgent(ConversationEntity):
if self._unsub_clear_slot_list is None: if self._unsub_clear_slot_list is None:
return return
self._slot_lists = None self._slot_lists = None
self._all_entity_names = None
for unsub in self._unsub_clear_slot_list: for unsub in self._unsub_clear_slot_list:
unsub() unsub()
self._unsub_clear_slot_list = None self._unsub_clear_slot_list = None
# Slot lists have changed, so we must clear the cache
self._intent_cache.clear()
@core.callback @core.callback
def _make_slot_lists(self) -> dict[str, SlotList]: def _make_slot_lists(self) -> dict[str, SlotList]:
"""Create slot lists with areas and entity names/aliases.""" """Create slot lists with areas and entity names/aliases."""

View File

@ -2833,3 +2833,110 @@ async def test_query_same_name_different_areas(
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
assert len(result.response.matched_states) == 1 assert len(result.response.matched_states) == 1
assert result.response.matched_states[0].entity_id == kitchen_light.entity_id assert result.response.matched_states[0].entity_id == kitchen_light.entity_id
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for exposed entities."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
entity_id = "light.test_light"
hass.states.async_set(entity_id, "off")
expose_entity(hass, entity_id, True)
await hass.async_block_till_done()
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True
# Unexposing clears the cache
expose_entity(hass, entity_id, False)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is None
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for all entities."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
entity_id = "light.test_light"
hass.states.async_set(entity_id, "off")
expose_entity(hass, entity_id, False) # not exposed
await hass.async_block_till_done()
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True
# Adding a new entity clears the cache
hass.states.async_set("light.new_light", "off")
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is None
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for fuzzy matches."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
# There is no entity named test light
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.unmatched_entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True