Only expose default cloud domains in Assist default agent (#88274)

* Only expose default cloud domains in default agent

* Copy exposed domain list to conversation

* Implement requested changes

* Add test for exposed devices/areas
This commit is contained in:
Michael Hansen 2023-02-17 15:19:22 -06:00 committed by GitHub
parent 331102e592
commit 325674ec44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 157 additions and 41 deletions

View File

@ -1,3 +1,18 @@
"""Const for conversation integration."""
DOMAIN = "conversation"
DEFAULT_EXPOSED_DOMAINS = {
"climate",
"cover",
"fan",
"humidifier",
"light",
"lock",
"scene",
"script",
"sensor",
"switch",
"vacuum",
"water_heater",
}

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
import logging
from pathlib import Path
@ -19,6 +19,7 @@ import yaml
from homeassistant import core, setup
from homeassistant.helpers import (
area_registry,
device_registry,
entity_registry,
intent,
template,
@ -27,7 +28,7 @@ from homeassistant.helpers import (
from homeassistant.util.json import JsonObjectType, json_loads_object
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN
from .const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
_LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@ -35,6 +36,11 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
REGEX_TYPE = type(re.compile(""))
def is_entity_exposed(state: core.State) -> bool:
"""Return true if entity belongs to exposed domain list."""
return state.domain in DEFAULT_EXPOSED_DOMAINS
def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents."""
return json_loads_object(fp.read())
@ -77,8 +83,7 @@ class DefaultAgent(AbstractConversationAgent):
# intent -> [sentences]
self._config_intents: dict[str, Any] = {}
self._areas_list: TextSlotList | None = None
self._names_list: TextSlotList | None = None
self._slot_lists: dict[str, TextSlotList] | None = None
async def async_initialize(self, config_intents):
"""Initialize the default agent."""
@ -128,10 +133,7 @@ class DefaultAgent(AbstractConversationAgent):
conversation_id,
)
slot_lists: dict[str, SlotList] = {
"area": self._make_areas_list(),
"name": self._make_names_list(),
}
slot_lists: Mapping[str, SlotList] = self._make_slot_lists()
result = await self.hass.async_add_executor_job(
self._recognize,
@ -419,45 +421,38 @@ class DefaultAgent(AbstractConversationAgent):
@core.callback
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
"""Clear area area cache when the area registry has changed."""
self._areas_list = None
self._slot_lists = None
@core.callback
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
"""Clear names list cache when an entity changes aliases."""
if event.data["action"] == "update" and "aliases" not in event.data["changes"]:
return
self._names_list = None
self._slot_lists = None
@core.callback
def _async_handle_state_changed(self, event: core.Event) -> None:
"""Clear names list cache when a state is added or removed from the state machine."""
if event.data.get("old_state") and event.data.get("new_state"):
return
self._names_list = None
self._slot_lists = None
def _make_areas_list(self) -> TextSlotList:
"""Create slot list mapping area names/aliases to area ids."""
if self._areas_list is not None:
return self._areas_list
registry = area_registry.async_get(self.hass)
areas = []
for entry in registry.async_list_areas():
areas.append((entry.name, entry.id))
if entry.aliases:
for alias in entry.aliases:
areas.append((alias, entry.id))
def _make_slot_lists(self) -> Mapping[str, SlotList]:
"""Create slot lists with areas and entity names/aliases."""
if self._slot_lists is not None:
return self._slot_lists
self._areas_list = TextSlotList.from_tuples(areas, allow_template=False)
return self._areas_list
def _make_names_list(self) -> TextSlotList:
"""Create slot list with entity names/aliases."""
if self._names_list is not None:
return self._names_list
states = self.hass.states.async_all()
area_ids_with_entities: set[str] = set()
states = [
state for state in self.hass.states.async_all() if is_entity_exposed(state)
]
entities = entity_registry.async_get(self.hass)
names = []
devices = device_registry.async_get(self.hass)
# Gather exposed entity names
entity_names = []
for state in states:
# Checked against "requires_context" and "excludes_context" in hassil
context = {"domain": state.domain}
entity = entities.async_get(state.entity_id)
@ -468,17 +463,42 @@ class DefaultAgent(AbstractConversationAgent):
if entity.aliases:
for alias in entity.aliases:
names.append((alias, alias, context))
entity_names.append((alias, alias, context))
# Default name
names.append((state.name, state.name, context))
entity_names.append((state.name, state.name, context))
if entity.area_id:
# Expose area too
area_ids_with_entities.add(entity.area_id)
elif entity.device_id:
# Check device for area as well
device = devices.async_get(entity.device_id)
if (device is not None) and device.area_id:
area_ids_with_entities.add(device.area_id)
else:
# Default name
names.append((state.name, state.name, context))
entity_names.append((state.name, state.name, context))
self._names_list = TextSlotList.from_tuples(names, allow_template=False)
return self._names_list
# Gather areas from exposed entities
areas = area_registry.async_get(self.hass)
area_names = []
for area_id in area_ids_with_entities:
area = areas.async_get_area(area_id)
if area is None:
continue
area_names.append((area.name, area.id))
if area.aliases:
for alias in area.aliases:
area_names.append((alias, area.id))
self._slot_lists = {
"area": TextSlotList.from_tuples(area_names, allow_template=False),
"name": TextSlotList.from_tuples(entity_names, allow_template=False),
}
return self._slot_lists
def _get_error_text(
self, response_type: ResponseType, lang_intents: LanguageIntents

View File

@ -1,9 +1,18 @@
"""Test for the default agent."""
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant
from homeassistant.helpers import entity, entity_registry, intent
from homeassistant.helpers import (
area_registry,
device_registry,
entity,
entity_registry,
intent,
)
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
@ -44,3 +53,70 @@ async def test_hidden_entities_skipped(
assert len(calls) == 0
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
async def test_exposed_domains(hass: HomeAssistant, init_components) -> None:
"""Test that we can't interact with entities that aren't exposed."""
hass.states.async_set(
"media_player.test", "off", attributes={ATTR_FRIENDLY_NAME: "Test Media Player"}
)
result = await conversation.async_converse(
hass, "turn on test media player", None, Context(), None
)
# This is an intent match failure instead of a handle failure because the
# media player domain is not exposed.
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
async def test_exposed_areas(hass: HomeAssistant, init_components) -> None:
"""Test that only expose areas with an exposed entity/device."""
areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen")
area_bedroom = areas.async_get_or_create("bedroom")
devices = device_registry.async_get(hass)
kitchen_device = devices.async_get_or_create(
config_entry_id="1234", connections=set(), identifiers={("demo", "id-1234")}
)
devices.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
entities = entity_registry.async_get(hass)
kitchen_light = entities.async_get_or_create("light", "demo", "1234")
entities.async_update_entity(kitchen_light.entity_id, device_id=kitchen_device.id)
hass.states.async_set(
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
bedroom_light = entities.async_get_or_create("light", "demo", "5678")
entities.async_update_entity(bedroom_light.entity_id, area_id=area_bedroom.id)
hass.states.async_set(
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
)
def is_entity_exposed(state):
return state.entity_id != bedroom_light.entity_id
with patch(
"homeassistant.components.conversation.default_agent.is_entity_exposed",
is_entity_exposed,
):
result = await conversation.async_converse(
hass, "turn on lights in the kitchen", None, Context(), None
)
# All is well for the exposed kitchen light
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
# Bedroom is not exposed because it has no exposed entities
result = await conversation.async_converse(
hass, "turn on lights in the bedroom", None, Context(), None
)
# This should be an intent match failure because the area isn't in the slot list
assert result.response.response_type == intent.IntentResponseType.ERROR
assert (
result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
)

View File

@ -158,14 +158,19 @@ def test_async_validate_slots() -> None:
)
async def test_cant_turn_on_sun(hass: HomeAssistant) -> None:
"""Test we can't turn on entities that don't support it."""
async def test_cant_turn_on_sensor(hass: HomeAssistant) -> None:
"""Test that we can't turn on entities that don't support it."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "sun", {})
assert await async_setup_component(hass, "sensor", {})
hass.states.async_set(
"sensor.test", "123", attributes={ATTR_FRIENDLY_NAME: "Test Sensor"}
)
result = await conversation.async_converse(
hass, "turn on sun", None, Context(), None
hass, "turn on test sensor", None, Context(), None
)
assert result.response.response_type == intent.IntentResponseType.ERROR