Cache the names and area lists in the default agent (#86874)

* Cache the names and area lists in the default agent

fixes #86803

* add coverage to make sure the entity cache busts

* add areas test

* cover the last line
This commit is contained in:
J. Nick Koston 2023-01-29 02:16:29 -10:00 committed by GitHub
parent eebc338c3b
commit 691a234090
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 256 additions and 6 deletions

View File

@ -71,6 +71,8 @@ class DefaultAgent(AbstractConversationAgent):
# intent -> [sentences] # intent -> [sentences]
self._config_intents: dict[str, Any] = {} self._config_intents: dict[str, Any] = {}
self._areas_list: TextSlotList | None = None
self._names_list: TextSlotList | None = None
async def async_initialize(self, config_intents): async def async_initialize(self, config_intents):
"""Initialize the default agent.""" """Initialize the default agent."""
@ -81,6 +83,22 @@ class DefaultAgent(AbstractConversationAgent):
if config_intents: if config_intents:
self._config_intents = config_intents self._config_intents = config_intents
self.hass.bus.async_listen(
area_registry.EVENT_AREA_REGISTRY_UPDATED,
self._async_handle_area_registry_changed,
run_immediately=True,
)
self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
self._async_handle_entity_registry_changed,
run_immediately=True,
)
self.hass.bus.async_listen(
core.EVENT_STATE_CHANGED,
self._async_handle_state_changed,
run_immediately=True,
)
async def async_process(self, user_input: ConversationInput) -> ConversationResult: async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
language = user_input.language or self.hass.config.language language = user_input.language or self.hass.config.language
@ -312,8 +330,29 @@ class DefaultAgent(AbstractConversationAgent):
return lang_intents return lang_intents
@core.callback
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
"""Clear area area cache when the area registry has changed."""
self._areas_list = None
@core.callback
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
"""Clear names list cache when an entity changes aliases."""
if event.data["action"] == "update" and "aliases" not in event.data["changes"]:
return
self._names_list = None
@core.callback
def _async_handle_state_changed(self, event: core.Event) -> None:
"""Clear names list cache when a state is added or removed from the state machine."""
if event.data.get("old_state") and event.data.get("new_state"):
return
self._names_list = None
def _make_areas_list(self) -> TextSlotList: def _make_areas_list(self) -> TextSlotList:
"""Create slot list mapping area names/aliases to area ids.""" """Create slot list mapping area names/aliases to area ids."""
if self._areas_list is not None:
return self._areas_list
registry = area_registry.async_get(self.hass) registry = area_registry.async_get(self.hass)
areas = [] areas = []
for entry in registry.async_list_areas(): for entry in registry.async_list_areas():
@ -322,16 +361,18 @@ class DefaultAgent(AbstractConversationAgent):
for alias in entry.aliases: for alias in entry.aliases:
areas.append((alias, entry.id)) areas.append((alias, entry.id))
return TextSlotList.from_tuples(areas) self._areas_list = TextSlotList.from_tuples(areas)
return self._areas_list
def _make_names_list(self) -> TextSlotList: def _make_names_list(self) -> TextSlotList:
"""Create slot list mapping entity names/aliases to entity ids.""" """Create slot list mapping entity names/aliases to entity ids."""
if self._names_list is not None:
return self._names_list
states = self.hass.states.async_all() states = self.hass.states.async_all()
registry = entity_registry.async_get(self.hass) registry = entity_registry.async_get(self.hass)
names = [] names = []
for state in states: for state in states:
domain = state.entity_id.split(".", maxsplit=1)[0] context = {"domain": state.domain}
context = {"domain": domain}
entry = registry.async_get(state.entity_id) entry = registry.async_get(state.entity_id)
if entry is not None: if entry is not None:
@ -346,7 +387,8 @@ class DefaultAgent(AbstractConversationAgent):
# Default name # Default name
names.append((state.name, state.entity_id, context)) names.append((state.name, state.entity_id, context))
return TextSlotList.from_tuples(names) self._names_list = TextSlotList.from_tuples(names)
return self._names_list
def _get_error_text( def _get_error_text(
self, response_type: ResponseType, lang_intents: LanguageIntents self, response_type: ResponseType, lang_intents: LanguageIntents

View File

@ -7,10 +7,15 @@ import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
from homeassistant.helpers import entity_registry, intent from homeassistant.helpers import (
area_registry,
device_registry,
entity_registry,
intent,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_mock_service from tests.common import MockConfigEntry, async_mock_service
class OrderBeerIntentHandler(intent.IntentHandler): class OrderBeerIntentHandler(intent.IntentHandler):
@ -75,6 +80,143 @@ async def test_http_processing_intent(
} }
async def test_http_processing_intent_entity_added(
hass, init_components, hass_client, hass_admin_user
):
"""Test processing intent via HTTP API with entities added later.
We want to ensure that adding an entity later busts the cache
so that the new entity is available as well as any aliases.
"""
er = entity_registry.async_get(hass)
er.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
er.async_update_entity("light.kitchen", aliases={"my cool light"})
hass.states.async_set("light.kitchen", "off")
client = await hass_client()
resp = await client.post(
"/api/conversation/process", json={"text": "turn on my cool light"}
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"response_type": "action_done",
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Turned on my cool light",
}
},
"language": hass.config.language,
"data": {
"targets": [],
"success": [
{"id": "light.kitchen", "name": "kitchen", "type": "entity"}
],
"failed": [],
},
},
"conversation_id": None,
}
# Add an alias
er.async_get_or_create("light", "demo", "5678", suggested_object_id="late")
hass.states.async_set("light.late", "off", {"friendly_name": "friendly light"})
client = await hass_client()
resp = await client.post(
"/api/conversation/process", json={"text": "turn on friendly light"}
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"response_type": "action_done",
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Turned on friendly light",
}
},
"language": hass.config.language,
"data": {
"targets": [],
"success": [
{"id": "light.late", "name": "friendly light", "type": "entity"}
],
"failed": [],
},
},
"conversation_id": None,
}
# Now add an alias
er.async_update_entity("light.late", aliases={"late added light"})
client = await hass_client()
resp = await client.post(
"/api/conversation/process", json={"text": "turn on late added light"}
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"response_type": "action_done",
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Turned on late added light",
}
},
"language": hass.config.language,
"data": {
"targets": [],
"success": [
{"id": "light.late", "name": "friendly light", "type": "entity"}
],
"failed": [],
},
},
"conversation_id": None,
}
# Now delete the entity
er.async_remove("light.late")
client = await hass_client()
resp = await client.post(
"/api/conversation/process", json={"text": "turn on late added light"}
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"conversation_id": None,
"response": {
"card": {},
"data": {"code": "no_intent_match"},
"language": hass.config.language,
"response_type": "error",
"speech": {
"plain": {
"extra_data": None,
"speech": "Sorry, I couldn't understand " "that",
}
},
},
}
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on")) @pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
async def test_turn_on_intent(hass, init_components, sentence): async def test_turn_on_intent(hass, init_components, sentence):
"""Test calling the turn on intent.""" """Test calling the turn on intent."""
@ -569,3 +711,69 @@ async def test_non_default_response(hass, init_components):
) )
) )
assert result.response.speech["plain"]["speech"] == "Opened front door" assert result.response.speech["plain"]["speech"] == "Opened front door"
async def test_turn_on_area(hass, init_components):
"""Test turning on an area."""
er = entity_registry.async_get(hass)
dr = device_registry.async_get(hass)
ar = area_registry.async_get(hass)
entry = MockConfigEntry(domain="test")
device = dr.async_get_or_create(
config_entry_id=entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
kitchen_area = ar.async_create("kitchen")
dr.async_update_device(device.id, area_id=kitchen_area.id)
er.async_get_or_create("light", "demo", "1234", suggested_object_id="stove")
er.async_update_entity(
"light.stove", aliases={"my stove light"}, area_id=kitchen_area.id
)
hass.states.async_set("light.stove", "off")
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on lights in the kitchen"},
)
await hass.async_block_till_done()
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.stove"}
basement_area = ar.async_create("basement")
dr.async_update_device(device.id, area_id=basement_area.id)
er.async_update_entity("light.stove", area_id=basement_area.id)
calls.clear()
# Test that the area is updated
await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on lights in the kitchen"},
)
await hass.async_block_till_done()
assert len(calls) == 0
# Test the new area works
await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on lights in the basement"},
)
await hass.async_block_till_done()
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.stove"}