mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Cache intent recognition results (#131114)
This commit is contained in:
parent
8f9095ba67
commit
f47840d83c
@ -3,8 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -102,6 +104,77 @@ class SentenceTriggerResult:
|
||||
matched_triggers: dict[int, RecognizeResult]
|
||||
|
||||
|
||||
class IntentMatchingStage(Enum):
    """Phases an intent match can come from.

    Members are listed in the order the recognizer attempts them.
    """

    EXPOSED_ENTITIES_ONLY = auto()
    """Matching is restricted to exposed entities."""

    ALL_ENTITIES = auto()
    """Matching considers every entity in Home Assistant."""

    FUZZY = auto()
    """Names unknown to Home Assistant are captured as well."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntentCacheKey:
|
||||
"""Key for IntentCache."""
|
||||
|
||||
text: str
|
||||
"""User input text."""
|
||||
|
||||
language: str
|
||||
"""Language of text."""
|
||||
|
||||
device_id: str | None
|
||||
"""Device id from user input."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IntentCacheValue:
    """Immutable entry stored in IntentCache."""

    result: RecognizeResult | None
    """Recognition result; None records that the stage found no match."""

    stage: IntentMatchingStage
    """Matching stage that produced (or failed to produce) the result."""
|
||||
|
||||
|
||||
class IntentCache:
    """LRU cache for intent recognition results.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used; both ``get`` and ``put`` move the touched entry to the end.
    """

    def __init__(self, capacity: int) -> None:
        """Initialize cache.

        ``capacity`` is the maximum number of entries kept. A capacity of
        zero or less means nothing is retained.
        """
        self.cache: OrderedDict[IntentCacheKey, IntentCacheValue] = OrderedDict()
        self.capacity = capacity

    def get(self, key: IntentCacheKey) -> IntentCacheValue | None:
        """Return the cached value for key, or None if it is not cached."""
        try:
            # Move the key to the end to show it was recently used
            self.cache.move_to_end(key)
        except KeyError:
            return None

        return self.cache[key]

    def put(self, key: IntentCacheKey, value: IntentCacheValue) -> None:
        """Put a value in the cache, evicting the least recently used item if necessary."""
        # Insert or overwrite, then mark as most recently used.
        self.cache[key] = value
        self.cache.move_to_end(key)

        # Evict oldest entries until within capacity. Flooring the limit at
        # zero keeps a non-positive capacity from calling popitem() on an
        # empty OrderedDict, which would raise KeyError.
        limit = max(self.capacity, 0)
        while len(self.cache) > limit:
            self.cache.popitem(last=False)

    def clear(self) -> None:
        """Clear the cache."""
        self.cache.clear()
|
||||
|
||||
|
||||
def _get_language_variations(language: str) -> Iterable[str]:
|
||||
"""Generate language codes with and without region."""
|
||||
yield language
|
||||
@ -160,6 +233,7 @@ class DefaultAgent(ConversationEntity):
|
||||
# intent -> [sentences]
|
||||
self._config_intents: dict[str, Any] = config_intents
|
||||
self._slot_lists: dict[str, SlotList] | None = None
|
||||
self._all_entity_names: TextSlotList | None = None
|
||||
|
||||
# Sentences that will trigger a callback (skipping intent recognition)
|
||||
self._trigger_sentences: list[TriggerData] = []
|
||||
@ -167,6 +241,9 @@ class DefaultAgent(ConversationEntity):
|
||||
self._unsub_clear_slot_list: list[Callable[[], None]] | None = None
|
||||
self._load_intents_lock = asyncio.Lock()
|
||||
|
||||
# LRU cache to avoid unnecessary intent matching
|
||||
self._intent_cache = IntentCache(capacity=128)
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
@ -417,18 +494,200 @@ class DefaultAgent(ConversationEntity):
|
||||
strict_intents_only: bool,
|
||||
) -> RecognizeResult | None:
|
||||
"""Search intents for a match to user input."""
|
||||
strict_result = self._recognize_strict(
|
||||
user_input, lang_intents, slot_lists, intent_context, language
|
||||
)
|
||||
skip_exposed_match = False
|
||||
|
||||
if strict_result is not None:
|
||||
# Successful strict match
|
||||
return strict_result
|
||||
# Try cache first
|
||||
cache_key = IntentCacheKey(
|
||||
text=user_input.text, language=language, device_id=user_input.device_id
|
||||
)
|
||||
cache_value = self._intent_cache.get(cache_key)
|
||||
if cache_value is not None:
|
||||
if (cache_value.result is not None) and (
|
||||
cache_value.stage == IntentMatchingStage.EXPOSED_ENTITIES_ONLY
|
||||
):
|
||||
_LOGGER.debug("Got cached result for exposed entities")
|
||||
return cache_value.result
|
||||
|
||||
# Continue with matching, but we know we won't succeed for exposed
|
||||
# entities only.
|
||||
skip_exposed_match = True
|
||||
|
||||
if not skip_exposed_match:
|
||||
start_time = time.monotonic()
|
||||
strict_result = self._recognize_strict(
|
||||
user_input, lang_intents, slot_lists, intent_context, language
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Checked exposed entities in %s second(s)",
|
||||
time.monotonic() - start_time,
|
||||
)
|
||||
|
||||
# Update cache
|
||||
self._intent_cache.put(
|
||||
cache_key,
|
||||
IntentCacheValue(
|
||||
result=strict_result,
|
||||
stage=IntentMatchingStage.EXPOSED_ENTITIES_ONLY,
|
||||
),
|
||||
)
|
||||
|
||||
if strict_result is not None:
|
||||
# Successful strict match with exposed entities
|
||||
return strict_result
|
||||
|
||||
if strict_intents_only:
|
||||
# Don't try matching against all entities or doing a fuzzy match
|
||||
return None
|
||||
|
||||
# Try again with all entities (including unexposed)
|
||||
skip_all_entities_match = False
|
||||
if cache_value is not None:
|
||||
if (cache_value.result is not None) and (
|
||||
cache_value.stage == IntentMatchingStage.ALL_ENTITIES
|
||||
):
|
||||
_LOGGER.debug("Got cached result for all entities")
|
||||
return cache_value.result
|
||||
|
||||
# Continue with matching, but we know we won't succeed for all
|
||||
# entities.
|
||||
skip_all_entities_match = True
|
||||
|
||||
if not skip_all_entities_match:
|
||||
all_entities_slot_lists = {
|
||||
**slot_lists,
|
||||
"name": self._get_all_entity_names(),
|
||||
}
|
||||
|
||||
start_time = time.monotonic()
|
||||
strict_result = self._recognize_strict(
|
||||
user_input,
|
||||
lang_intents,
|
||||
all_entities_slot_lists,
|
||||
intent_context,
|
||||
language,
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Checked all entities in %s second(s)", time.monotonic() - start_time
|
||||
)
|
||||
|
||||
# Update cache
|
||||
self._intent_cache.put(
|
||||
cache_key,
|
||||
IntentCacheValue(
|
||||
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
|
||||
),
|
||||
)
|
||||
|
||||
if strict_result is not None:
|
||||
# Not a successful match, but useful for an error message.
|
||||
# This should fail the intent handling phase (async_match_targets).
|
||||
return strict_result
|
||||
|
||||
# Try again with missing entities enabled
|
||||
skip_fuzzy_match = False
|
||||
if cache_value is not None:
|
||||
if (cache_value.result is not None) and (
|
||||
cache_value.stage == IntentMatchingStage.FUZZY
|
||||
):
|
||||
_LOGGER.debug("Got cached result for fuzzy match")
|
||||
return cache_value.result
|
||||
|
||||
# We know we won't succeed for fuzzy matching.
|
||||
skip_fuzzy_match = True
|
||||
|
||||
maybe_result: RecognizeResult | None = None
|
||||
if not skip_fuzzy_match:
|
||||
start_time = time.monotonic()
|
||||
best_num_matched_entities = 0
|
||||
best_num_unmatched_entities = 0
|
||||
best_num_unmatched_ranges = 0
|
||||
for result in recognize_all(
|
||||
user_input.text,
|
||||
lang_intents.intents,
|
||||
slot_lists=slot_lists,
|
||||
intent_context=intent_context,
|
||||
allow_unmatched_entities=True,
|
||||
):
|
||||
if result.text_chunks_matched < 1:
|
||||
# Skip results that don't match any literal text
|
||||
continue
|
||||
|
||||
# Don't count missing entities that couldn't be filled from context
|
||||
num_matched_entities = 0
|
||||
for matched_entity in result.entities_list:
|
||||
if matched_entity.name not in result.unmatched_entities:
|
||||
num_matched_entities += 1
|
||||
|
||||
num_unmatched_entities = 0
|
||||
num_unmatched_ranges = 0
|
||||
for unmatched_entity in result.unmatched_entities_list:
|
||||
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
||||
if unmatched_entity.text != MISSING_ENTITY:
|
||||
num_unmatched_entities += 1
|
||||
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
||||
num_unmatched_ranges += 1
|
||||
num_unmatched_entities += 1
|
||||
else:
|
||||
num_unmatched_entities += 1
|
||||
|
||||
if (
|
||||
(maybe_result is None) # first result
|
||||
or (num_matched_entities > best_num_matched_entities)
|
||||
or (
|
||||
# Fewer unmatched entities
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities < best_num_unmatched_entities)
|
||||
)
|
||||
or (
|
||||
# Prefer unmatched ranges
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
||||
)
|
||||
or (
|
||||
# More literal text matched
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||
and (
|
||||
result.text_chunks_matched
|
||||
> maybe_result.text_chunks_matched
|
||||
)
|
||||
)
|
||||
or (
|
||||
# Prefer match failures with entities
|
||||
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||
and (
|
||||
("name" in result.entities)
|
||||
or ("name" in result.unmatched_entities)
|
||||
)
|
||||
)
|
||||
):
|
||||
maybe_result = result
|
||||
best_num_matched_entities = num_matched_entities
|
||||
best_num_unmatched_entities = num_unmatched_entities
|
||||
best_num_unmatched_ranges = num_unmatched_ranges
|
||||
|
||||
# Update cache
|
||||
self._intent_cache.put(
|
||||
cache_key,
|
||||
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
|
||||
)
|
||||
|
||||
return maybe_result
|
||||
|
||||
def _get_all_entity_names(self) -> TextSlotList:
|
||||
"""Get slot list with all entity names in Home Assistant."""
|
||||
if self._all_entity_names is not None:
|
||||
return self._all_entity_names
|
||||
|
||||
entity_registry = er.async_get(self.hass)
|
||||
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
@ -459,96 +718,10 @@ class DefaultAgent(ConversationEntity):
|
||||
# Default name
|
||||
all_entity_names.append((state.name, state.name, context))
|
||||
|
||||
slot_lists = {
|
||||
**slot_lists,
|
||||
"name": TextSlotList.from_tuples(all_entity_names, allow_template=False),
|
||||
}
|
||||
|
||||
strict_result = self._recognize_strict(
|
||||
user_input,
|
||||
lang_intents,
|
||||
slot_lists,
|
||||
intent_context,
|
||||
language,
|
||||
self._all_entity_names = TextSlotList.from_tuples(
|
||||
all_entity_names, allow_template=False
|
||||
)
|
||||
|
||||
if strict_result is not None:
|
||||
# Not a successful match, but useful for an error message.
|
||||
# This should fail the intent handling phase (async_match_targets).
|
||||
return strict_result
|
||||
|
||||
# Try again with missing entities enabled
|
||||
maybe_result: RecognizeResult | None = None
|
||||
best_num_matched_entities = 0
|
||||
best_num_unmatched_entities = 0
|
||||
best_num_unmatched_ranges = 0
|
||||
for result in recognize_all(
|
||||
user_input.text,
|
||||
lang_intents.intents,
|
||||
slot_lists=slot_lists,
|
||||
intent_context=intent_context,
|
||||
allow_unmatched_entities=True,
|
||||
):
|
||||
if result.text_chunks_matched < 1:
|
||||
# Skip results that don't match any literal text
|
||||
continue
|
||||
|
||||
# Don't count missing entities that couldn't be filled from context
|
||||
num_matched_entities = 0
|
||||
for matched_entity in result.entities_list:
|
||||
if matched_entity.name not in result.unmatched_entities:
|
||||
num_matched_entities += 1
|
||||
|
||||
num_unmatched_entities = 0
|
||||
num_unmatched_ranges = 0
|
||||
for unmatched_entity in result.unmatched_entities_list:
|
||||
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
||||
if unmatched_entity.text != MISSING_ENTITY:
|
||||
num_unmatched_entities += 1
|
||||
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
||||
num_unmatched_ranges += 1
|
||||
num_unmatched_entities += 1
|
||||
else:
|
||||
num_unmatched_entities += 1
|
||||
|
||||
if (
|
||||
(maybe_result is None) # first result
|
||||
or (num_matched_entities > best_num_matched_entities)
|
||||
or (
|
||||
# Fewer unmatched entities
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities < best_num_unmatched_entities)
|
||||
)
|
||||
or (
|
||||
# Prefer unmatched ranges
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
||||
)
|
||||
or (
|
||||
# More literal text matched
|
||||
(num_matched_entities == best_num_matched_entities)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||
and (result.text_chunks_matched > maybe_result.text_chunks_matched)
|
||||
)
|
||||
or (
|
||||
# Prefer match failures with entities
|
||||
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||
and (
|
||||
("name" in result.entities)
|
||||
or ("name" in result.unmatched_entities)
|
||||
)
|
||||
)
|
||||
):
|
||||
maybe_result = result
|
||||
best_num_matched_entities = num_matched_entities
|
||||
best_num_unmatched_entities = num_unmatched_entities
|
||||
best_num_unmatched_ranges = num_unmatched_ranges
|
||||
|
||||
return maybe_result
|
||||
return self._all_entity_names
|
||||
|
||||
def _recognize_strict(
|
||||
self,
|
||||
@ -653,6 +826,9 @@ class DefaultAgent(ConversationEntity):
|
||||
self._lang_intents.pop(language, None)
|
||||
_LOGGER.debug("Cleared intents for language: %s", language)
|
||||
|
||||
# Intents have changed, so we must clear the cache
|
||||
self._intent_cache.clear()
|
||||
|
||||
async def async_prepare(self, language: str | None = None) -> None:
|
||||
"""Load intents for a language."""
|
||||
if language is None:
|
||||
@ -837,10 +1013,14 @@ class DefaultAgent(ConversationEntity):
|
||||
if self._unsub_clear_slot_list is None:
|
||||
return
|
||||
self._slot_lists = None
|
||||
self._all_entity_names = None
|
||||
for unsub in self._unsub_clear_slot_list:
|
||||
unsub()
|
||||
self._unsub_clear_slot_list = None
|
||||
|
||||
# Slot lists have changed, so we must clear the cache
|
||||
self._intent_cache.clear()
|
||||
|
||||
@core.callback
|
||||
def _make_slot_lists(self) -> dict[str, SlotList]:
|
||||
"""Create slot lists with areas and entity names/aliases."""
|
||||
|
@ -2833,3 +2833,110 @@ async def test_query_same_name_different_areas(
|
||||
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
||||
assert len(result.response.matched_states) == 1
|
||||
assert result.response.matched_states[0].entity_id == kitchen_light.entity_id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
    """Check exposed-entity matches are cached until exposure changes."""
    agent = hass.data[DATA_DEFAULT_ENTITY]
    assert isinstance(agent, default_agent.DefaultAgent)

    light_id = "light.test_light"
    hass.states.async_set(light_id, "off")
    expose_entity(hass, light_id, True)
    await hass.async_block_till_done()

    request = ConversationInput(
        text="turn on test light",
        context=Context(),
        conversation_id=None,
        device_id=None,
        language=hass.config.language,
        agent_id=None,
    )

    first = await agent.async_recognize_intent(request)
    assert first is not None
    assert first.entities["name"].text == "test light"

    # Tag the returned object: seeing the tag on a later result means that
    # same object was handed back instead of a freshly computed one.
    cache_marker = "_from_cache"
    setattr(first, cache_marker, True)

    # Second call is expected to return the tagged (cached) object
    second = await agent.async_recognize_intent(request)
    assert second is not None
    assert getattr(second, cache_marker, None) is True

    # Unexposing the entity is expected to clear the cache, so the next
    # result must be a new, untagged object
    expose_entity(hass, light_id, False)
    third = await agent.async_recognize_intent(request)
    assert third is not None
    assert getattr(third, cache_marker, None) is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
    """Check all-entity matches are cached until a new entity appears."""
    agent = hass.data[DATA_DEFAULT_ENTITY]
    assert isinstance(agent, default_agent.DefaultAgent)

    light_id = "light.test_light"
    hass.states.async_set(light_id, "off")
    # Keep the entity unexposed so the exposed-only stage cannot match it
    expose_entity(hass, light_id, False)
    await hass.async_block_till_done()

    request = ConversationInput(
        text="turn on test light",
        context=Context(),
        conversation_id=None,
        device_id=None,
        language=hass.config.language,
        agent_id=None,
    )

    first = await agent.async_recognize_intent(request)
    assert first is not None
    assert first.entities["name"].text == "test light"

    # Tag the returned object: seeing the tag on a later result means that
    # same object was handed back instead of a freshly computed one.
    cache_marker = "_from_cache"
    setattr(first, cache_marker, True)

    # Second call is expected to return the tagged (cached) object
    second = await agent.async_recognize_intent(request)
    assert second is not None
    assert getattr(second, cache_marker, None) is True

    # Adding a new entity is expected to clear the cache, so the next
    # result must be a new, untagged object
    hass.states.async_set("light.new_light", "off")
    third = await agent.async_recognize_intent(request)
    assert third is not None
    assert getattr(third, cache_marker, None) is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
    """Check fuzzy-match results are served from the cache on repeat input."""
    agent = hass.data[DATA_DEFAULT_ENTITY]
    assert isinstance(agent, default_agent.DefaultAgent)

    # No entity called "test light" is created, so the name stays unmatched
    request = ConversationInput(
        text="turn on test light",
        context=Context(),
        conversation_id=None,
        device_id=None,
        language=hass.config.language,
        agent_id=None,
    )

    first = await agent.async_recognize_intent(request)
    assert first is not None
    assert first.unmatched_entities["name"].text == "test light"

    # Tag the returned object: seeing the tag on a later result means that
    # same object was handed back instead of a freshly computed one.
    cache_marker = "_from_cache"
    setattr(first, cache_marker, True)

    # Second call is expected to return the tagged (cached) object
    second = await agent.async_recognize_intent(request)
    assert second is not None
    assert getattr(second, cache_marker, None) is True
|
||||
|
Loading…
x
Reference in New Issue
Block a user