diff --git a/homeassistant/components/conversation/const.py b/homeassistant/components/conversation/const.py index 04bfa373061..b79a557698f 100644 --- a/homeassistant/components/conversation/const.py +++ b/homeassistant/components/conversation/const.py @@ -1,3 +1,18 @@ """Const for conversation integration.""" DOMAIN = "conversation" + +DEFAULT_EXPOSED_DOMAINS = { + "climate", + "cover", + "fan", + "humidifier", + "light", + "lock", + "scene", + "script", + "sensor", + "switch", + "vacuum", + "water_heater", +} diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 6d9afcbafa0..b418275f857 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from dataclasses import dataclass import logging from pathlib import Path @@ -19,6 +19,7 @@ import yaml from homeassistant import core, setup from homeassistant.helpers import ( area_registry, + device_registry, entity_registry, intent, template, @@ -27,7 +28,7 @@ from homeassistant.helpers import ( from homeassistant.util.json import JsonObjectType, json_loads_object from .agent import AbstractConversationAgent, ConversationInput, ConversationResult -from .const import DOMAIN +from .const import DEFAULT_EXPOSED_DOMAINS, DOMAIN _LOGGER = logging.getLogger(__name__) _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" @@ -35,6 +36,11 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" REGEX_TYPE = type(re.compile("")) +def is_entity_exposed(state: core.State) -> bool: + """Return true if entity belongs to exposed domain list.""" + return state.domain in DEFAULT_EXPOSED_DOMAINS + + def json_load(fp: IO[str]) -> JsonObjectType: """Wrap json_loads for get_intents.""" return json_loads_object(fp.read()) @@ -77,8 +83,7 @@ class DefaultAgent(AbstractConversationAgent): # intent -> [sentences] self._config_intents: dict[str, Any] = {} - self._areas_list: TextSlotList | None = None - self._names_list: TextSlotList | None = None + self._slot_lists: dict[str, TextSlotList] | None = None async def async_initialize(self, config_intents): """Initialize the default agent.""" @@ -128,10 +133,7 @@ class DefaultAgent(AbstractConversationAgent): conversation_id, ) - slot_lists: dict[str, SlotList] = { - "area": self._make_areas_list(), - "name": self._make_names_list(), - } + slot_lists: Mapping[str, SlotList] = self._make_slot_lists() result = await self.hass.async_add_executor_job( self._recognize, @@ -419,45 +421,38 @@ class DefaultAgent(AbstractConversationAgent): @core.callback def _async_handle_area_registry_changed(self, event: core.Event) -> None: """Clear area area cache when the area registry has changed.""" - self._areas_list = None + self._slot_lists = None @core.callback def _async_handle_entity_registry_changed(self, event: core.Event) -> None: """Clear names list cache when an entity changes aliases.""" if event.data["action"] == "update" and "aliases" not in event.data["changes"]: return - self._names_list = None + self._slot_lists = None @core.callback def _async_handle_state_changed(self, event: core.Event) -> None: """Clear names list cache when a state is added or removed from the state machine.""" if event.data.get("old_state") and event.data.get("new_state"): return - self._names_list = None + self._slot_lists = None - def _make_areas_list(self) -> TextSlotList: - """Create slot list mapping area names/aliases to area ids.""" - if self._areas_list is not None: - return self._areas_list - registry = area_registry.async_get(self.hass) - areas = [] - for entry in registry.async_list_areas(): - areas.append((entry.name, entry.id)) - if entry.aliases: - for alias in entry.aliases: - areas.append((alias, entry.id)) + def _make_slot_lists(self) -> Mapping[str, SlotList]: + """Create slot lists with areas and entity names/aliases.""" + if self._slot_lists is not None: + return self._slot_lists - self._areas_list = TextSlotList.from_tuples(areas, allow_template=False) - return self._areas_list - - def _make_names_list(self) -> TextSlotList: - """Create slot list with entity names/aliases.""" - if self._names_list is not None: - return self._names_list - states = self.hass.states.async_all() + area_ids_with_entities: set[str] = set() + states = [ + state for state in self.hass.states.async_all() if is_entity_exposed(state) + ] entities = entity_registry.async_get(self.hass) - names = [] + devices = device_registry.async_get(self.hass) + + # Gather exposed entity names + entity_names = [] for state in states: + # Checked against "requires_context" and "excludes_context" in hassil context = {"domain": state.domain} entity = entities.async_get(state.entity_id) @@ -468,17 +463,42 @@ class DefaultAgent(AbstractConversationAgent): if entity.aliases: for alias in entity.aliases: - names.append((alias, alias, context)) + entity_names.append((alias, alias, context)) # Default name - names.append((state.name, state.name, context)) + entity_names.append((state.name, state.name, context)) + if entity.area_id: + # Expose area too + area_ids_with_entities.add(entity.area_id) + elif entity.device_id: + # Check device for area as well + device = devices.async_get(entity.device_id) + if (device is not None) and device.area_id: + area_ids_with_entities.add(device.area_id) else: # Default name - names.append((state.name, state.name, context)) + entity_names.append((state.name, state.name, context)) - self._names_list = TextSlotList.from_tuples(names, allow_template=False) - return self._names_list + # Gather areas from exposed entities + areas = area_registry.async_get(self.hass) + area_names = [] + for area_id in area_ids_with_entities: + area = areas.async_get_area(area_id) + if area is None: + continue + + area_names.append((area.name, area.id)) + if area.aliases: + for alias in area.aliases: + area_names.append((alias, area.id)) + + self._slot_lists = { + "area": TextSlotList.from_tuples(area_names, allow_template=False), + "name": TextSlotList.from_tuples(entity_names, allow_template=False), + } + + return self._slot_lists def _get_error_text( self, response_type: ResponseType, lang_intents: LanguageIntents diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 7621c42abbc..726ee4dc6e3 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -1,9 +1,18 @@ """Test for the default agent.""" +from unittest.mock import patch + import pytest from homeassistant.components import conversation +from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant -from homeassistant.helpers import entity, entity_registry, intent +from homeassistant.helpers import ( + area_registry, + device_registry, + entity, + entity_registry, + intent, +) from homeassistant.setup import async_setup_component from tests.common import async_mock_service @@ -44,3 +53,70 @@ async def test_hidden_entities_skipped( assert len(calls) == 0 assert result.response.response_type == intent.IntentResponseType.ERROR assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH + + +async def test_exposed_domains(hass: HomeAssistant, init_components) -> None: + """Test that we can't interact with entities that aren't exposed.""" + hass.states.async_set( + "media_player.test", "off", attributes={ATTR_FRIENDLY_NAME: "Test Media Player"} + ) + + result = await conversation.async_converse( + hass, "turn on test media player", None, Context(), None + ) + + # This is an intent match failure instead of a handle failure because the + # media player domain is not exposed. + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH + + +async def test_exposed_areas(hass: HomeAssistant, init_components) -> None: + """Test that only expose areas with an exposed entity/device.""" + areas = area_registry.async_get(hass) + area_kitchen = areas.async_get_or_create("kitchen") + area_bedroom = areas.async_get_or_create("bedroom") + + devices = device_registry.async_get(hass) + kitchen_device = devices.async_get_or_create( + config_entry_id="1234", connections=set(), identifiers={("demo", "id-1234")} + ) + devices.async_update_device(kitchen_device.id, area_id=area_kitchen.id) + + entities = entity_registry.async_get(hass) + kitchen_light = entities.async_get_or_create("light", "demo", "1234") + entities.async_update_entity(kitchen_light.entity_id, device_id=kitchen_device.id) + hass.states.async_set( + kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} + ) + + bedroom_light = entities.async_get_or_create("light", "demo", "5678") + entities.async_update_entity(bedroom_light.entity_id, area_id=area_bedroom.id) + hass.states.async_set( + bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"} + ) + + def is_entity_exposed(state): + return state.entity_id != bedroom_light.entity_id + + with patch( + "homeassistant.components.conversation.default_agent.is_entity_exposed", + is_entity_exposed, + ): + result = await conversation.async_converse( + hass, "turn on lights in the kitchen", None, Context(), None + ) + + # All is well for the exposed kitchen light + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + + # Bedroom is not exposed because it has no exposed entities + result = await conversation.async_converse( + hass, "turn on lights in the bedroom", None, Context(), None + ) + + # This should be an intent match failure because the area isn't in the slot list + assert result.response.response_type == intent.IntentResponseType.ERROR + assert ( + result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH + ) diff --git a/tests/helpers/test_intent.py b/tests/helpers/test_intent.py index b975ae4bd72..b706a8c7551 100644 --- a/tests/helpers/test_intent.py +++ b/tests/helpers/test_intent.py @@ -158,14 +158,19 @@ def test_async_validate_slots() -> None: ) -async def test_cant_turn_on_sun(hass: HomeAssistant) -> None: - """Test we can't turn on entities that don't support it.""" +async def test_cant_turn_on_sensor(hass: HomeAssistant) -> None: + """Test that we can't turn on entities that don't support it.""" assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) assert await async_setup_component(hass, "intent", {}) - assert await async_setup_component(hass, "sun", {}) + assert await async_setup_component(hass, "sensor", {}) + + hass.states.async_set( + "sensor.test", "123", attributes={ATTR_FRIENDLY_NAME: "Test Sensor"} + ) + result = await conversation.async_converse( - hass, "turn on sun", None, Context(), None + hass, "turn on test sensor", None, Context(), None ) assert result.response.response_type == intent.IntentResponseType.ERROR