Only expose default cloud domains in Assist default agent (#88274)

* Only expose default cloud domains in default agent

* Copy exposed domain list to conversation

* Implement requested changes

* Add test for exposed devices/areas
This commit is contained in:
Michael Hansen 2023-02-17 15:19:22 -06:00 committed by GitHub
parent 331102e592
commit 325674ec44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 157 additions and 41 deletions

View File

@ -1,3 +1,18 @@
"""Const for conversation integration.""" """Const for conversation integration."""
DOMAIN = "conversation" DOMAIN = "conversation"
DEFAULT_EXPOSED_DOMAINS = {
"climate",
"cover",
"fan",
"humidifier",
"light",
"lock",
"scene",
"script",
"sensor",
"switch",
"vacuum",
"water_heater",
}

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from pathlib import Path from pathlib import Path
@ -19,6 +19,7 @@ import yaml
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry, area_registry,
device_registry,
entity_registry, entity_registry,
intent, intent,
template, template,
@ -27,7 +28,7 @@ from homeassistant.helpers import (
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN from .const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@ -35,6 +36,11 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
def is_entity_exposed(state: core.State) -> bool:
"""Return true if entity belongs to exposed domain list."""
return state.domain in DEFAULT_EXPOSED_DOMAINS
def json_load(fp: IO[str]) -> JsonObjectType: def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents.""" """Wrap json_loads for get_intents."""
return json_loads_object(fp.read()) return json_loads_object(fp.read())
@ -77,8 +83,7 @@ class DefaultAgent(AbstractConversationAgent):
# intent -> [sentences] # intent -> [sentences]
self._config_intents: dict[str, Any] = {} self._config_intents: dict[str, Any] = {}
self._areas_list: TextSlotList | None = None self._slot_lists: dict[str, TextSlotList] | None = None
self._names_list: TextSlotList | None = None
async def async_initialize(self, config_intents): async def async_initialize(self, config_intents):
"""Initialize the default agent.""" """Initialize the default agent."""
@ -128,10 +133,7 @@ class DefaultAgent(AbstractConversationAgent):
conversation_id, conversation_id,
) )
slot_lists: dict[str, SlotList] = { slot_lists: Mapping[str, SlotList] = self._make_slot_lists()
"area": self._make_areas_list(),
"name": self._make_names_list(),
}
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
self._recognize, self._recognize,
@ -419,45 +421,38 @@ class DefaultAgent(AbstractConversationAgent):
@core.callback @core.callback
def _async_handle_area_registry_changed(self, event: core.Event) -> None: def _async_handle_area_registry_changed(self, event: core.Event) -> None:
"""Clear area area cache when the area registry has changed.""" """Clear area area cache when the area registry has changed."""
self._areas_list = None self._slot_lists = None
@core.callback @core.callback
def _async_handle_entity_registry_changed(self, event: core.Event) -> None: def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
"""Clear names list cache when an entity changes aliases.""" """Clear names list cache when an entity changes aliases."""
if event.data["action"] == "update" and "aliases" not in event.data["changes"]: if event.data["action"] == "update" and "aliases" not in event.data["changes"]:
return return
self._names_list = None self._slot_lists = None
@core.callback @core.callback
def _async_handle_state_changed(self, event: core.Event) -> None: def _async_handle_state_changed(self, event: core.Event) -> None:
"""Clear names list cache when a state is added or removed from the state machine.""" """Clear names list cache when a state is added or removed from the state machine."""
if event.data.get("old_state") and event.data.get("new_state"): if event.data.get("old_state") and event.data.get("new_state"):
return return
self._names_list = None self._slot_lists = None
def _make_areas_list(self) -> TextSlotList: def _make_slot_lists(self) -> Mapping[str, SlotList]:
"""Create slot list mapping area names/aliases to area ids.""" """Create slot lists with areas and entity names/aliases."""
if self._areas_list is not None: if self._slot_lists is not None:
return self._areas_list return self._slot_lists
registry = area_registry.async_get(self.hass)
areas = []
for entry in registry.async_list_areas():
areas.append((entry.name, entry.id))
if entry.aliases:
for alias in entry.aliases:
areas.append((alias, entry.id))
self._areas_list = TextSlotList.from_tuples(areas, allow_template=False) area_ids_with_entities: set[str] = set()
return self._areas_list states = [
state for state in self.hass.states.async_all() if is_entity_exposed(state)
def _make_names_list(self) -> TextSlotList: ]
"""Create slot list with entity names/aliases."""
if self._names_list is not None:
return self._names_list
states = self.hass.states.async_all()
entities = entity_registry.async_get(self.hass) entities = entity_registry.async_get(self.hass)
names = [] devices = device_registry.async_get(self.hass)
# Gather exposed entity names
entity_names = []
for state in states: for state in states:
# Checked against "requires_context" and "excludes_context" in hassil
context = {"domain": state.domain} context = {"domain": state.domain}
entity = entities.async_get(state.entity_id) entity = entities.async_get(state.entity_id)
@ -468,17 +463,42 @@ class DefaultAgent(AbstractConversationAgent):
if entity.aliases: if entity.aliases:
for alias in entity.aliases: for alias in entity.aliases:
names.append((alias, alias, context)) entity_names.append((alias, alias, context))
# Default name # Default name
names.append((state.name, state.name, context)) entity_names.append((state.name, state.name, context))
if entity.area_id:
# Expose area too
area_ids_with_entities.add(entity.area_id)
elif entity.device_id:
# Check device for area as well
device = devices.async_get(entity.device_id)
if (device is not None) and device.area_id:
area_ids_with_entities.add(device.area_id)
else: else:
# Default name # Default name
names.append((state.name, state.name, context)) entity_names.append((state.name, state.name, context))
self._names_list = TextSlotList.from_tuples(names, allow_template=False) # Gather areas from exposed entities
return self._names_list areas = area_registry.async_get(self.hass)
area_names = []
for area_id in area_ids_with_entities:
area = areas.async_get_area(area_id)
if area is None:
continue
area_names.append((area.name, area.id))
if area.aliases:
for alias in area.aliases:
area_names.append((alias, area.id))
self._slot_lists = {
"area": TextSlotList.from_tuples(area_names, allow_template=False),
"name": TextSlotList.from_tuples(entity_names, allow_template=False),
}
return self._slot_lists
def _get_error_text( def _get_error_text(
self, response_type: ResponseType, lang_intents: LanguageIntents self, response_type: ResponseType, lang_intents: LanguageIntents

View File

@ -1,9 +1,18 @@
"""Test for the default agent.""" """Test for the default agent."""
from unittest.mock import patch
import pytest import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant
from homeassistant.helpers import entity, entity_registry, intent from homeassistant.helpers import (
area_registry,
device_registry,
entity,
entity_registry,
intent,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_mock_service from tests.common import async_mock_service
@ -44,3 +53,70 @@ async def test_hidden_entities_skipped(
assert len(calls) == 0 assert len(calls) == 0
assert result.response.response_type == intent.IntentResponseType.ERROR assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
async def test_exposed_domains(hass: HomeAssistant, init_components) -> None:
"""Test that we can't interact with entities that aren't exposed."""
hass.states.async_set(
"media_player.test", "off", attributes={ATTR_FRIENDLY_NAME: "Test Media Player"}
)
result = await conversation.async_converse(
hass, "turn on test media player", None, Context(), None
)
# This is an intent match failure instead of a handle failure because the
# media player domain is not exposed.
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
async def test_exposed_areas(hass: HomeAssistant, init_components) -> None:
"""Test that only expose areas with an exposed entity/device."""
areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen")
area_bedroom = areas.async_get_or_create("bedroom")
devices = device_registry.async_get(hass)
kitchen_device = devices.async_get_or_create(
config_entry_id="1234", connections=set(), identifiers={("demo", "id-1234")}
)
devices.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
entities = entity_registry.async_get(hass)
kitchen_light = entities.async_get_or_create("light", "demo", "1234")
entities.async_update_entity(kitchen_light.entity_id, device_id=kitchen_device.id)
hass.states.async_set(
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
bedroom_light = entities.async_get_or_create("light", "demo", "5678")
entities.async_update_entity(bedroom_light.entity_id, area_id=area_bedroom.id)
hass.states.async_set(
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
)
def is_entity_exposed(state):
return state.entity_id != bedroom_light.entity_id
with patch(
"homeassistant.components.conversation.default_agent.is_entity_exposed",
is_entity_exposed,
):
result = await conversation.async_converse(
hass, "turn on lights in the kitchen", None, Context(), None
)
# All is well for the exposed kitchen light
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
# Bedroom is not exposed because it has no exposed entities
result = await conversation.async_converse(
hass, "turn on lights in the bedroom", None, Context(), None
)
# This should be an intent match failure because the area isn't in the slot list
assert result.response.response_type == intent.IntentResponseType.ERROR
assert (
result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
)

View File

@ -158,14 +158,19 @@ def test_async_validate_slots() -> None:
) )
async def test_cant_turn_on_sun(hass: HomeAssistant) -> None: async def test_cant_turn_on_sensor(hass: HomeAssistant) -> None:
"""Test we can't turn on entities that don't support it.""" """Test that we can't turn on entities that don't support it."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {}) assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "sun", {}) assert await async_setup_component(hass, "sensor", {})
hass.states.async_set(
"sensor.test", "123", attributes={ATTR_FRIENDLY_NAME: "Test Sensor"}
)
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "turn on sun", None, Context(), None hass, "turn on test sensor", None, Context(), None
) )
assert result.response.response_type == intent.IntentResponseType.ERROR assert result.response.response_type == intent.IntentResponseType.ERROR