mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Cache the names and area lists in the default agent (#86874)
* Cache the names and area lists in the default agent fixes #86803 * add coverage to make sure the entity cache busts * add areas test * cover the last line
This commit is contained in:
parent
eebc338c3b
commit
691a234090
@ -71,6 +71,8 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
|
|
||||||
# intent -> [sentences]
|
# intent -> [sentences]
|
||||||
self._config_intents: dict[str, Any] = {}
|
self._config_intents: dict[str, Any] = {}
|
||||||
|
self._areas_list: TextSlotList | None = None
|
||||||
|
self._names_list: TextSlotList | None = None
|
||||||
|
|
||||||
async def async_initialize(self, config_intents):
|
async def async_initialize(self, config_intents):
|
||||||
"""Initialize the default agent."""
|
"""Initialize the default agent."""
|
||||||
@ -81,6 +83,22 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
if config_intents:
|
if config_intents:
|
||||||
self._config_intents = config_intents
|
self._config_intents = config_intents
|
||||||
|
|
||||||
|
self.hass.bus.async_listen(
|
||||||
|
area_registry.EVENT_AREA_REGISTRY_UPDATED,
|
||||||
|
self._async_handle_area_registry_changed,
|
||||||
|
run_immediately=True,
|
||||||
|
)
|
||||||
|
self.hass.bus.async_listen(
|
||||||
|
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
|
self._async_handle_entity_registry_changed,
|
||||||
|
run_immediately=True,
|
||||||
|
)
|
||||||
|
self.hass.bus.async_listen(
|
||||||
|
core.EVENT_STATE_CHANGED,
|
||||||
|
self._async_handle_state_changed,
|
||||||
|
run_immediately=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
language = user_input.language or self.hass.config.language
|
language = user_input.language or self.hass.config.language
|
||||||
@ -312,8 +330,29 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
|
|
||||||
return lang_intents
|
return lang_intents
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
|
||||||
|
"""Clear area area cache when the area registry has changed."""
|
||||||
|
self._areas_list = None
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
|
||||||
|
"""Clear names list cache when an entity changes aliases."""
|
||||||
|
if event.data["action"] == "update" and "aliases" not in event.data["changes"]:
|
||||||
|
return
|
||||||
|
self._names_list = None
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def _async_handle_state_changed(self, event: core.Event) -> None:
|
||||||
|
"""Clear names list cache when a state is added or removed from the state machine."""
|
||||||
|
if event.data.get("old_state") and event.data.get("new_state"):
|
||||||
|
return
|
||||||
|
self._names_list = None
|
||||||
|
|
||||||
def _make_areas_list(self) -> TextSlotList:
|
def _make_areas_list(self) -> TextSlotList:
|
||||||
"""Create slot list mapping area names/aliases to area ids."""
|
"""Create slot list mapping area names/aliases to area ids."""
|
||||||
|
if self._areas_list is not None:
|
||||||
|
return self._areas_list
|
||||||
registry = area_registry.async_get(self.hass)
|
registry = area_registry.async_get(self.hass)
|
||||||
areas = []
|
areas = []
|
||||||
for entry in registry.async_list_areas():
|
for entry in registry.async_list_areas():
|
||||||
@ -322,16 +361,18 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
for alias in entry.aliases:
|
for alias in entry.aliases:
|
||||||
areas.append((alias, entry.id))
|
areas.append((alias, entry.id))
|
||||||
|
|
||||||
return TextSlotList.from_tuples(areas)
|
self._areas_list = TextSlotList.from_tuples(areas)
|
||||||
|
return self._areas_list
|
||||||
|
|
||||||
def _make_names_list(self) -> TextSlotList:
|
def _make_names_list(self) -> TextSlotList:
|
||||||
"""Create slot list mapping entity names/aliases to entity ids."""
|
"""Create slot list mapping entity names/aliases to entity ids."""
|
||||||
|
if self._names_list is not None:
|
||||||
|
return self._names_list
|
||||||
states = self.hass.states.async_all()
|
states = self.hass.states.async_all()
|
||||||
registry = entity_registry.async_get(self.hass)
|
registry = entity_registry.async_get(self.hass)
|
||||||
names = []
|
names = []
|
||||||
for state in states:
|
for state in states:
|
||||||
domain = state.entity_id.split(".", maxsplit=1)[0]
|
context = {"domain": state.domain}
|
||||||
context = {"domain": domain}
|
|
||||||
|
|
||||||
entry = registry.async_get(state.entity_id)
|
entry = registry.async_get(state.entity_id)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
@ -346,7 +387,8 @@ class DefaultAgent(AbstractConversationAgent):
|
|||||||
# Default name
|
# Default name
|
||||||
names.append((state.name, state.entity_id, context))
|
names.append((state.name, state.entity_id, context))
|
||||||
|
|
||||||
return TextSlotList.from_tuples(names)
|
self._names_list = TextSlotList.from_tuples(names)
|
||||||
|
return self._names_list
|
||||||
|
|
||||||
def _get_error_text(
|
def _get_error_text(
|
||||||
self, response_type: ResponseType, lang_intents: LanguageIntents
|
self, response_type: ResponseType, lang_intents: LanguageIntents
|
||||||
|
@ -7,10 +7,15 @@ import pytest
|
|||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
||||||
from homeassistant.helpers import entity_registry, intent
|
from homeassistant.helpers import (
|
||||||
|
area_registry,
|
||||||
|
device_registry,
|
||||||
|
entity_registry,
|
||||||
|
intent,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import async_mock_service
|
from tests.common import MockConfigEntry, async_mock_service
|
||||||
|
|
||||||
|
|
||||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
@ -75,6 +80,143 @@ async def test_http_processing_intent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_http_processing_intent_entity_added(
|
||||||
|
hass, init_components, hass_client, hass_admin_user
|
||||||
|
):
|
||||||
|
"""Test processing intent via HTTP API with entities added later.
|
||||||
|
|
||||||
|
We want to ensure that adding an entity later busts the cache
|
||||||
|
so that the new entity is available as well as any aliases.
|
||||||
|
"""
|
||||||
|
er = entity_registry.async_get(hass)
|
||||||
|
er.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
|
||||||
|
er.async_update_entity("light.kitchen", aliases={"my cool light"})
|
||||||
|
hass.states.async_set("light.kitchen", "off")
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/conversation/process", json={"text": "turn on my cool light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
assert data == {
|
||||||
|
"response": {
|
||||||
|
"response_type": "action_done",
|
||||||
|
"card": {},
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Turned on my cool light",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"data": {
|
||||||
|
"targets": [],
|
||||||
|
"success": [
|
||||||
|
{"id": "light.kitchen", "name": "kitchen", "type": "entity"}
|
||||||
|
],
|
||||||
|
"failed": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add an alias
|
||||||
|
er.async_get_or_create("light", "demo", "5678", suggested_object_id="late")
|
||||||
|
hass.states.async_set("light.late", "off", {"friendly_name": "friendly light"})
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/conversation/process", json={"text": "turn on friendly light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
assert data == {
|
||||||
|
"response": {
|
||||||
|
"response_type": "action_done",
|
||||||
|
"card": {},
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Turned on friendly light",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"data": {
|
||||||
|
"targets": [],
|
||||||
|
"success": [
|
||||||
|
{"id": "light.late", "name": "friendly light", "type": "entity"}
|
||||||
|
],
|
||||||
|
"failed": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Now add an alias
|
||||||
|
er.async_update_entity("light.late", aliases={"late added light"})
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/conversation/process", json={"text": "turn on late added light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
assert data == {
|
||||||
|
"response": {
|
||||||
|
"response_type": "action_done",
|
||||||
|
"card": {},
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Turned on late added light",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"data": {
|
||||||
|
"targets": [],
|
||||||
|
"success": [
|
||||||
|
{"id": "light.late", "name": "friendly light", "type": "entity"}
|
||||||
|
],
|
||||||
|
"failed": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Now delete the entity
|
||||||
|
er.async_remove("light.late")
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/conversation/process", json={"text": "turn on late added light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
assert data == {
|
||||||
|
"conversation_id": None,
|
||||||
|
"response": {
|
||||||
|
"card": {},
|
||||||
|
"data": {"code": "no_intent_match"},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"response_type": "error",
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Sorry, I couldn't understand " "that",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
|
@pytest.mark.parametrize("sentence", ("turn on kitchen", "turn kitchen on"))
|
||||||
async def test_turn_on_intent(hass, init_components, sentence):
|
async def test_turn_on_intent(hass, init_components, sentence):
|
||||||
"""Test calling the turn on intent."""
|
"""Test calling the turn on intent."""
|
||||||
@ -569,3 +711,69 @@ async def test_non_default_response(hass, init_components):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert result.response.speech["plain"]["speech"] == "Opened front door"
|
assert result.response.speech["plain"]["speech"] == "Opened front door"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_turn_on_area(hass, init_components):
|
||||||
|
"""Test turning on an area."""
|
||||||
|
er = entity_registry.async_get(hass)
|
||||||
|
dr = device_registry.async_get(hass)
|
||||||
|
ar = area_registry.async_get(hass)
|
||||||
|
entry = MockConfigEntry(domain="test")
|
||||||
|
|
||||||
|
device = dr.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
|
||||||
|
kitchen_area = ar.async_create("kitchen")
|
||||||
|
dr.async_update_device(device.id, area_id=kitchen_area.id)
|
||||||
|
|
||||||
|
er.async_get_or_create("light", "demo", "1234", suggested_object_id="stove")
|
||||||
|
er.async_update_entity(
|
||||||
|
"light.stove", aliases={"my stove light"}, area_id=kitchen_area.id
|
||||||
|
)
|
||||||
|
hass.states.async_set("light.stove", "off")
|
||||||
|
|
||||||
|
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{conversation.ATTR_TEXT: "turn on lights in the kitchen"},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == HASS_DOMAIN
|
||||||
|
assert call.service == "turn_on"
|
||||||
|
assert call.data == {"entity_id": "light.stove"}
|
||||||
|
|
||||||
|
basement_area = ar.async_create("basement")
|
||||||
|
dr.async_update_device(device.id, area_id=basement_area.id)
|
||||||
|
er.async_update_entity("light.stove", area_id=basement_area.id)
|
||||||
|
calls.clear()
|
||||||
|
|
||||||
|
# Test that the area is updated
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{conversation.ATTR_TEXT: "turn on lights in the kitchen"},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
# Test the new area works
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{conversation.ATTR_TEXT: "turn on lights in the basement"},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == HASS_DOMAIN
|
||||||
|
assert call.service == "turn_on"
|
||||||
|
assert call.data == {"entity_id": "light.stove"}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user