Cache intent recognition results (#131114)

This commit is contained in:
Michael Hansen 2024-11-22 19:57:42 -06:00 committed by GitHub
parent 8f9095ba67
commit f47840d83c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 382 additions and 95 deletions

View File

@ -3,8 +3,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict
from collections.abc import Awaitable, Callable, Iterable from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
import functools import functools
import logging import logging
from pathlib import Path from pathlib import Path
@ -102,6 +104,77 @@ class SentenceTriggerResult:
matched_triggers: dict[int, RecognizeResult] matched_triggers: dict[int, RecognizeResult]
class IntentMatchingStage(Enum):
"""Stages of intent matching."""
EXPOSED_ENTITIES_ONLY = auto()
"""Match against exposed entities only."""
ALL_ENTITIES = auto()
"""Match against all entities in Home Assistant."""
FUZZY = auto()
"""Capture names that are not known to Home Assistant."""
@dataclass(frozen=True)
class IntentCacheKey:
"""Key for IntentCache."""
text: str
"""User input text."""
language: str
"""Language of text."""
device_id: str | None
"""Device id from user input."""
@dataclass(frozen=True)
class IntentCacheValue:
"""Value for IntentCache."""
result: RecognizeResult | None
"""Result of intent recognition."""
stage: IntentMatchingStage
"""Stage where result was found."""
class IntentCache:
"""LRU cache for intent recognition results."""
def __init__(self, capacity: int) -> None:
"""Initialize cache."""
self.cache: OrderedDict[IntentCacheKey, IntentCacheValue] = OrderedDict()
self.capacity = capacity
def get(self, key: IntentCacheKey) -> IntentCacheValue | None:
"""Get value for cache or None."""
if key not in self.cache:
return None
# Move the key to the end to show it was recently used
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: IntentCacheKey, value: IntentCacheValue) -> None:
"""Put a value in the cache, evicting the least recently used item if necessary."""
if key in self.cache:
# Update value and mark as recently used
self.cache.move_to_end(key)
elif len(self.cache) >= self.capacity:
# Evict the oldest item
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self) -> None:
"""Clear the cache."""
self.cache.clear()
def _get_language_variations(language: str) -> Iterable[str]: def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region.""" """Generate language codes with and without region."""
yield language yield language
@ -160,6 +233,7 @@ class DefaultAgent(ConversationEntity):
# intent -> [sentences] # intent -> [sentences]
self._config_intents: dict[str, Any] = config_intents self._config_intents: dict[str, Any] = config_intents
self._slot_lists: dict[str, SlotList] | None = None self._slot_lists: dict[str, SlotList] | None = None
self._all_entity_names: TextSlotList | None = None
# Sentences that will trigger a callback (skipping intent recognition) # Sentences that will trigger a callback (skipping intent recognition)
self._trigger_sentences: list[TriggerData] = [] self._trigger_sentences: list[TriggerData] = []
@ -167,6 +241,9 @@ class DefaultAgent(ConversationEntity):
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
self._load_intents_lock = asyncio.Lock() self._load_intents_lock = asyncio.Lock()
# LRU cache to avoid unnecessary intent matching
self._intent_cache = IntentCache(capacity=128)
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
@ -417,18 +494,200 @@ class DefaultAgent(ConversationEntity):
strict_intents_only: bool, strict_intents_only: bool,
) -> RecognizeResult | None: ) -> RecognizeResult | None:
"""Search intents for a match to user input.""" """Search intents for a match to user input."""
strict_result = self._recognize_strict( skip_exposed_match = False
user_input, lang_intents, slot_lists, intent_context, language
)
if strict_result is not None: # Try cache first
# Successful strict match cache_key = IntentCacheKey(
return strict_result text=user_input.text, language=language, device_id=user_input.device_id
)
cache_value = self._intent_cache.get(cache_key)
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.EXPOSED_ENTITIES_ONLY
):
_LOGGER.debug("Got cached result for exposed entities")
return cache_value.result
# Continue with matching, but we know we won't succeed for exposed
# entities only.
skip_exposed_match = True
if not skip_exposed_match:
start_time = time.monotonic()
strict_result = self._recognize_strict(
user_input, lang_intents, slot_lists, intent_context, language
)
_LOGGER.debug(
"Checked exposed entities in %s second(s)",
time.monotonic() - start_time,
)
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(
result=strict_result,
stage=IntentMatchingStage.EXPOSED_ENTITIES_ONLY,
),
)
if strict_result is not None:
# Successful strict match with exposed entities
return strict_result
if strict_intents_only: if strict_intents_only:
# Don't try matching against all entities or doing a fuzzy match
return None return None
# Try again with all entities (including unexposed) # Try again with all entities (including unexposed)
skip_all_entities_match = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.ALL_ENTITIES
):
_LOGGER.debug("Got cached result for all entities")
return cache_value.result
# Continue with matching, but we know we won't succeed for all
# entities.
skip_all_entities_match = True
if not skip_all_entities_match:
all_entities_slot_lists = {
**slot_lists,
"name": self._get_all_entity_names(),
}
start_time = time.monotonic()
strict_result = self._recognize_strict(
user_input,
lang_intents,
all_entities_slot_lists,
intent_context,
language,
)
_LOGGER.debug(
"Checked all entities in %s second(s)", time.monotonic() - start_time
)
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
),
)
if strict_result is not None:
# Not a successful match, but useful for an error message.
# This should fail the intent handling phase (async_match_targets).
return strict_result
# Try again with missing entities enabled
skip_fuzzy_match = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.FUZZY
):
_LOGGER.debug("Got cached result for fuzzy match")
return cache_value.result
# We know we won't succeed for fuzzy matching.
skip_fuzzy_match = True
maybe_result: RecognizeResult | None = None
if not skip_fuzzy_match:
start_time = time.monotonic()
best_num_matched_entities = 0
best_num_unmatched_entities = 0
best_num_unmatched_ranges = 0
for result in recognize_all(
user_input.text,
lang_intents.intents,
slot_lists=slot_lists,
intent_context=intent_context,
allow_unmatched_entities=True,
):
if result.text_chunks_matched < 1:
# Skip results that don't match any literal text
continue
# Don't count missing entities that couldn't be filled from context
num_matched_entities = 0
for matched_entity in result.entities_list:
if matched_entity.name not in result.unmatched_entities:
num_matched_entities += 1
num_unmatched_entities = 0
num_unmatched_ranges = 0
for unmatched_entity in result.unmatched_entities_list:
if isinstance(unmatched_entity, UnmatchedTextEntity):
if unmatched_entity.text != MISSING_ENTITY:
num_unmatched_entities += 1
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
num_unmatched_ranges += 1
num_unmatched_entities += 1
else:
num_unmatched_entities += 1
if (
(maybe_result is None) # first result
or (num_matched_entities > best_num_matched_entities)
or (
# Fewer unmatched entities
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities < best_num_unmatched_entities)
)
or (
# Prefer unmatched ranges
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges > best_num_unmatched_ranges)
)
or (
# More literal text matched
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (
result.text_chunks_matched
> maybe_result.text_chunks_matched
)
)
or (
# Prefer match failures with entities
(result.text_chunks_matched == maybe_result.text_chunks_matched)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (
("name" in result.entities)
or ("name" in result.unmatched_entities)
)
)
):
maybe_result = result
best_num_matched_entities = num_matched_entities
best_num_unmatched_entities = num_unmatched_entities
best_num_unmatched_ranges = num_unmatched_ranges
# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
)
_LOGGER.debug(
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
)
return maybe_result
def _get_all_entity_names(self) -> TextSlotList:
"""Get slot list with all entity names in Home Assistant."""
if self._all_entity_names is not None:
return self._all_entity_names
entity_registry = er.async_get(self.hass) entity_registry = er.async_get(self.hass)
all_entity_names: list[tuple[str, str, dict[str, Any]]] = [] all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
@ -459,96 +718,10 @@ class DefaultAgent(ConversationEntity):
# Default name # Default name
all_entity_names.append((state.name, state.name, context)) all_entity_names.append((state.name, state.name, context))
slot_lists = { self._all_entity_names = TextSlotList.from_tuples(
**slot_lists, all_entity_names, allow_template=False
"name": TextSlotList.from_tuples(all_entity_names, allow_template=False),
}
strict_result = self._recognize_strict(
user_input,
lang_intents,
slot_lists,
intent_context,
language,
) )
return self._all_entity_names
if strict_result is not None:
# Not a successful match, but useful for an error message.
# This should fail the intent handling phase (async_match_targets).
return strict_result
# Try again with missing entities enabled
maybe_result: RecognizeResult | None = None
best_num_matched_entities = 0
best_num_unmatched_entities = 0
best_num_unmatched_ranges = 0
for result in recognize_all(
user_input.text,
lang_intents.intents,
slot_lists=slot_lists,
intent_context=intent_context,
allow_unmatched_entities=True,
):
if result.text_chunks_matched < 1:
# Skip results that don't match any literal text
continue
# Don't count missing entities that couldn't be filled from context
num_matched_entities = 0
for matched_entity in result.entities_list:
if matched_entity.name not in result.unmatched_entities:
num_matched_entities += 1
num_unmatched_entities = 0
num_unmatched_ranges = 0
for unmatched_entity in result.unmatched_entities_list:
if isinstance(unmatched_entity, UnmatchedTextEntity):
if unmatched_entity.text != MISSING_ENTITY:
num_unmatched_entities += 1
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
num_unmatched_ranges += 1
num_unmatched_entities += 1
else:
num_unmatched_entities += 1
if (
(maybe_result is None) # first result
or (num_matched_entities > best_num_matched_entities)
or (
# Fewer unmatched entities
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities < best_num_unmatched_entities)
)
or (
# Prefer unmatched ranges
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges > best_num_unmatched_ranges)
)
or (
# More literal text matched
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (result.text_chunks_matched > maybe_result.text_chunks_matched)
)
or (
# Prefer match failures with entities
(result.text_chunks_matched == maybe_result.text_chunks_matched)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (
("name" in result.entities)
or ("name" in result.unmatched_entities)
)
)
):
maybe_result = result
best_num_matched_entities = num_matched_entities
best_num_unmatched_entities = num_unmatched_entities
best_num_unmatched_ranges = num_unmatched_ranges
return maybe_result
def _recognize_strict( def _recognize_strict(
self, self,
@ -653,6 +826,9 @@ class DefaultAgent(ConversationEntity):
self._lang_intents.pop(language, None) self._lang_intents.pop(language, None)
_LOGGER.debug("Cleared intents for language: %s", language) _LOGGER.debug("Cleared intents for language: %s", language)
# Intents have changed, so we must clear the cache
self._intent_cache.clear()
async def async_prepare(self, language: str | None = None) -> None: async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language.""" """Load intents for a language."""
if language is None: if language is None:
@ -837,10 +1013,14 @@ class DefaultAgent(ConversationEntity):
if self._unsub_clear_slot_list is None: if self._unsub_clear_slot_list is None:
return return
self._slot_lists = None self._slot_lists = None
self._all_entity_names = None
for unsub in self._unsub_clear_slot_list: for unsub in self._unsub_clear_slot_list:
unsub() unsub()
self._unsub_clear_slot_list = None self._unsub_clear_slot_list = None
# Slot lists have changed, so we must clear the cache
self._intent_cache.clear()
@core.callback @core.callback
def _make_slot_lists(self) -> dict[str, SlotList]: def _make_slot_lists(self) -> dict[str, SlotList]:
"""Create slot lists with areas and entity names/aliases.""" """Create slot lists with areas and entity names/aliases."""

View File

@ -2833,3 +2833,110 @@ async def test_query_same_name_different_areas(
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
assert len(result.response.matched_states) == 1 assert len(result.response.matched_states) == 1
assert result.response.matched_states[0].entity_id == kitchen_light.entity_id assert result.response.matched_states[0].entity_id == kitchen_light.entity_id
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for exposed entities."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
entity_id = "light.test_light"
hass.states.async_set(entity_id, "off")
expose_entity(hass, entity_id, True)
await hass.async_block_till_done()
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True
# Unexposing clears the cache
expose_entity(hass, entity_id, False)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is None
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for all entities."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
entity_id = "light.test_light"
hass.states.async_set(entity_id, "off")
expose_entity(hass, entity_id, False) # not exposed
await hass.async_block_till_done()
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True
# Adding a new entity clears the cache
hass.states.async_set("light.new_light", "off")
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is None
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
"""Test that intent recognition results are cached for fuzzy matches."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
# There is no entity named test light
user_input = ConversationInput(
text="turn on test light",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert result.unmatched_entities["name"].text == "test light"
# Mark this result so we know it is from cache next time
mark = "_from_cache"
setattr(result, mark, True)
# Should be from cache this time
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True