mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Only expose default cloud domains in Assist default agent (#88274)
* Only expose default cloud domains in default agent * Copy exposed domain list to conversation * Implement requested changes * Add test for exposed devices/areas
This commit is contained in:
parent
331102e592
commit
325674ec44
@ -1,3 +1,18 @@
|
||||
"""Const for conversation integration."""
|
||||
|
||||
DOMAIN = "conversation"
|
||||
|
||||
DEFAULT_EXPOSED_DOMAINS = {
|
||||
"climate",
|
||||
"cover",
|
||||
"fan",
|
||||
"humidifier",
|
||||
"light",
|
||||
"lock",
|
||||
"scene",
|
||||
"script",
|
||||
"sensor",
|
||||
"switch",
|
||||
"vacuum",
|
||||
"water_heater",
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -19,6 +19,7 @@ import yaml
|
||||
from homeassistant import core, setup
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
device_registry,
|
||||
entity_registry,
|
||||
intent,
|
||||
template,
|
||||
@ -27,7 +28,7 @@ from homeassistant.helpers import (
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .const import DOMAIN
|
||||
from .const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
||||
@ -35,6 +36,11 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
|
||||
|
||||
def is_entity_exposed(state: core.State) -> bool:
|
||||
"""Return true if entity belongs to exposed domain list."""
|
||||
return state.domain in DEFAULT_EXPOSED_DOMAINS
|
||||
|
||||
|
||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||
"""Wrap json_loads for get_intents."""
|
||||
return json_loads_object(fp.read())
|
||||
@ -77,8 +83,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
|
||||
# intent -> [sentences]
|
||||
self._config_intents: dict[str, Any] = {}
|
||||
self._areas_list: TextSlotList | None = None
|
||||
self._names_list: TextSlotList | None = None
|
||||
self._slot_lists: dict[str, TextSlotList] | None = None
|
||||
|
||||
async def async_initialize(self, config_intents):
|
||||
"""Initialize the default agent."""
|
||||
@ -128,10 +133,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
slot_lists: dict[str, SlotList] = {
|
||||
"area": self._make_areas_list(),
|
||||
"name": self._make_names_list(),
|
||||
}
|
||||
slot_lists: Mapping[str, SlotList] = self._make_slot_lists()
|
||||
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._recognize,
|
||||
@ -419,45 +421,38 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
@core.callback
|
||||
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
|
||||
"""Clear area area cache when the area registry has changed."""
|
||||
self._areas_list = None
|
||||
self._slot_lists = None
|
||||
|
||||
@core.callback
|
||||
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
|
||||
"""Clear names list cache when an entity changes aliases."""
|
||||
if event.data["action"] == "update" and "aliases" not in event.data["changes"]:
|
||||
return
|
||||
self._names_list = None
|
||||
self._slot_lists = None
|
||||
|
||||
@core.callback
|
||||
def _async_handle_state_changed(self, event: core.Event) -> None:
|
||||
"""Clear names list cache when a state is added or removed from the state machine."""
|
||||
if event.data.get("old_state") and event.data.get("new_state"):
|
||||
return
|
||||
self._names_list = None
|
||||
self._slot_lists = None
|
||||
|
||||
def _make_areas_list(self) -> TextSlotList:
|
||||
"""Create slot list mapping area names/aliases to area ids."""
|
||||
if self._areas_list is not None:
|
||||
return self._areas_list
|
||||
registry = area_registry.async_get(self.hass)
|
||||
areas = []
|
||||
for entry in registry.async_list_areas():
|
||||
areas.append((entry.name, entry.id))
|
||||
if entry.aliases:
|
||||
for alias in entry.aliases:
|
||||
areas.append((alias, entry.id))
|
||||
def _make_slot_lists(self) -> Mapping[str, SlotList]:
|
||||
"""Create slot lists with areas and entity names/aliases."""
|
||||
if self._slot_lists is not None:
|
||||
return self._slot_lists
|
||||
|
||||
self._areas_list = TextSlotList.from_tuples(areas, allow_template=False)
|
||||
return self._areas_list
|
||||
|
||||
def _make_names_list(self) -> TextSlotList:
|
||||
"""Create slot list with entity names/aliases."""
|
||||
if self._names_list is not None:
|
||||
return self._names_list
|
||||
states = self.hass.states.async_all()
|
||||
area_ids_with_entities: set[str] = set()
|
||||
states = [
|
||||
state for state in self.hass.states.async_all() if is_entity_exposed(state)
|
||||
]
|
||||
entities = entity_registry.async_get(self.hass)
|
||||
names = []
|
||||
devices = device_registry.async_get(self.hass)
|
||||
|
||||
# Gather exposed entity names
|
||||
entity_names = []
|
||||
for state in states:
|
||||
# Checked against "requires_context" and "excludes_context" in hassil
|
||||
context = {"domain": state.domain}
|
||||
|
||||
entity = entities.async_get(state.entity_id)
|
||||
@ -468,17 +463,42 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
|
||||
if entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
names.append((alias, alias, context))
|
||||
entity_names.append((alias, alias, context))
|
||||
|
||||
# Default name
|
||||
names.append((state.name, state.name, context))
|
||||
entity_names.append((state.name, state.name, context))
|
||||
|
||||
if entity.area_id:
|
||||
# Expose area too
|
||||
area_ids_with_entities.add(entity.area_id)
|
||||
elif entity.device_id:
|
||||
# Check device for area as well
|
||||
device = devices.async_get(entity.device_id)
|
||||
if (device is not None) and device.area_id:
|
||||
area_ids_with_entities.add(device.area_id)
|
||||
else:
|
||||
# Default name
|
||||
names.append((state.name, state.name, context))
|
||||
entity_names.append((state.name, state.name, context))
|
||||
|
||||
self._names_list = TextSlotList.from_tuples(names, allow_template=False)
|
||||
return self._names_list
|
||||
# Gather areas from exposed entities
|
||||
areas = area_registry.async_get(self.hass)
|
||||
area_names = []
|
||||
for area_id in area_ids_with_entities:
|
||||
area = areas.async_get_area(area_id)
|
||||
if area is None:
|
||||
continue
|
||||
|
||||
area_names.append((area.name, area.id))
|
||||
if area.aliases:
|
||||
for alias in area.aliases:
|
||||
area_names.append((alias, area.id))
|
||||
|
||||
self._slot_lists = {
|
||||
"area": TextSlotList.from_tuples(area_names, allow_template=False),
|
||||
"name": TextSlotList.from_tuples(entity_names, allow_template=False),
|
||||
}
|
||||
|
||||
return self._slot_lists
|
||||
|
||||
def _get_error_text(
|
||||
self, response_type: ResponseType, lang_intents: LanguageIntents
|
||||
|
@ -1,9 +1,18 @@
|
||||
"""Test for the default agent."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant
|
||||
from homeassistant.helpers import entity, entity_registry, intent
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
device_registry,
|
||||
entity,
|
||||
entity_registry,
|
||||
intent,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
@ -44,3 +53,70 @@ async def test_hidden_entities_skipped(
|
||||
assert len(calls) == 0
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
||||
|
||||
|
||||
async def test_exposed_domains(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test that we can't interact with entities that aren't exposed."""
|
||||
hass.states.async_set(
|
||||
"media_player.test", "off", attributes={ATTR_FRIENDLY_NAME: "Test Media Player"}
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "turn on test media player", None, Context(), None
|
||||
)
|
||||
|
||||
# This is an intent match failure instead of a handle failure because the
|
||||
# media player domain is not exposed.
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
||||
|
||||
|
||||
async def test_exposed_areas(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test that only expose areas with an exposed entity/device."""
|
||||
areas = area_registry.async_get(hass)
|
||||
area_kitchen = areas.async_get_or_create("kitchen")
|
||||
area_bedroom = areas.async_get_or_create("bedroom")
|
||||
|
||||
devices = device_registry.async_get(hass)
|
||||
kitchen_device = devices.async_get_or_create(
|
||||
config_entry_id="1234", connections=set(), identifiers={("demo", "id-1234")}
|
||||
)
|
||||
devices.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
|
||||
|
||||
entities = entity_registry.async_get(hass)
|
||||
kitchen_light = entities.async_get_or_create("light", "demo", "1234")
|
||||
entities.async_update_entity(kitchen_light.entity_id, device_id=kitchen_device.id)
|
||||
hass.states.async_set(
|
||||
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||
)
|
||||
|
||||
bedroom_light = entities.async_get_or_create("light", "demo", "5678")
|
||||
entities.async_update_entity(bedroom_light.entity_id, area_id=area_bedroom.id)
|
||||
hass.states.async_set(
|
||||
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
|
||||
)
|
||||
|
||||
def is_entity_exposed(state):
|
||||
return state.entity_id != bedroom_light.entity_id
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.default_agent.is_entity_exposed",
|
||||
is_entity_exposed,
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass, "turn on lights in the kitchen", None, Context(), None
|
||||
)
|
||||
|
||||
# All is well for the exposed kitchen light
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
|
||||
# Bedroom is not exposed because it has no exposed entities
|
||||
result = await conversation.async_converse(
|
||||
hass, "turn on lights in the bedroom", None, Context(), None
|
||||
)
|
||||
|
||||
# This should be an intent match failure because the area isn't in the slot list
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert (
|
||||
result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
||||
)
|
||||
|
@ -158,14 +158,19 @@ def test_async_validate_slots() -> None:
|
||||
)
|
||||
|
||||
|
||||
async def test_cant_turn_on_sun(hass: HomeAssistant) -> None:
|
||||
"""Test we can't turn on entities that don't support it."""
|
||||
async def test_cant_turn_on_sensor(hass: HomeAssistant) -> None:
|
||||
"""Test that we can't turn on entities that don't support it."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "sun", {})
|
||||
assert await async_setup_component(hass, "sensor", {})
|
||||
|
||||
hass.states.async_set(
|
||||
"sensor.test", "123", attributes={ATTR_FRIENDLY_NAME: "Test Sensor"}
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "turn on sun", None, Context(), None
|
||||
hass, "turn on test sensor", None, Context(), None
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
|
Loading…
x
Reference in New Issue
Block a user