From a8acde62ff501c4d368c50c052b00fcbdaec48db Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 17 Nov 2023 07:34:14 -0600 Subject: [PATCH] Use device area as context during intent recognition (#103939) * Use device area as context during intent recognition * Use guard clauses --- .../components/conversation/default_agent.py | 45 ++++++- .../conversation/test_default_agent.py | 122 ++++++++++++++++++ 2 files changed, 162 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 9dcf70dda80..c1282bbbac1 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -188,11 +188,14 @@ class DefaultAgent(AbstractConversationAgent): return None slot_lists = self._make_slot_lists() + intent_context = self._make_intent_context(user_input) + result = await self.hass.async_add_executor_job( self._recognize, user_input, lang_intents, slot_lists, + intent_context, ) return result @@ -221,15 +224,24 @@ class DefaultAgent(AbstractConversationAgent): # loaded in async_recognize. assert lang_intents is not None + # Include slot values from intent_context, such as the name of the + # device's area. + slots = { + entity_name: {"value": entity_value} + for entity_name, entity_value in result.context.items() + } + + # Override context with result entities + slots.update( + {entity.name: {"value": entity.value} for entity in result.entities_list} + ) + try: intent_response = await intent.async_handle( self.hass, DOMAIN, result.intent.name, - { - entity.name: {"value": entity.value} - for entity in result.entities_list - }, + slots, user_input.text, user_input.context, language, @@ -277,12 +289,16 @@ class DefaultAgent(AbstractConversationAgent): user_input: ConversationInput, lang_intents: LanguageIntents, slot_lists: dict[str, SlotList], + intent_context: dict[str, Any] | None, ) -> RecognizeResult | None: """Search intents for a match to user input.""" # Prioritize matches with entity names above area names maybe_result: RecognizeResult | None = None for result in recognize_all( - user_input.text, lang_intents.intents, slot_lists=slot_lists + user_input.text, + lang_intents.intents, + slot_lists=slot_lists, + intent_context=intent_context, ): if "name" in result.entities: return result @@ -623,6 +639,25 @@ class DefaultAgent(AbstractConversationAgent): return self._slot_lists + def _make_intent_context( + self, user_input: ConversationInput + ) -> dict[str, Any] | None: + """Return intent recognition context for user input.""" + if not user_input.device_id: + return None + + devices = dr.async_get(self.hass) + device = devices.async_get(user_input.device_id) + if (device is None) or (device.area_id is None): + return None + + areas = ar.async_get(self.hass) + device_area = areas.async_get_area(device.area_id) + if device_area is None: + return None + + return {"area": device_area.name} + def _get_error_text( self, response_type: ResponseType, lang_intents: LanguageIntents | None ) -> str: diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index c75c96ca59b..bc85cdf604c 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -1,4 +1,5 @@ """Test for the default agent.""" +from collections import defaultdict from unittest.mock import AsyncMock, patch import pytest @@ -293,3 +294,124 @@ async def test_nevermind_item(hass: HomeAssistant, init_components) -> None: assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert not result.response.speech + + +async def test_device_area_context( + hass: HomeAssistant, + init_components, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test that including a device_id will target a specific area.""" + turn_on_calls = async_mock_service(hass, "light", "turn_on") + turn_off_calls = async_mock_service(hass, "light", "turn_off") + + area_kitchen = area_registry.async_get_or_create("kitchen") + area_bedroom = area_registry.async_get_or_create("bedroom") + + # Create 2 lights in each area + area_lights = defaultdict(list) + for area in (area_kitchen, area_bedroom): + for i in range(2): + light_entity = entity_registry.async_get_or_create( + "light", "demo", f"{area.name}-light-{i}" + ) + entity_registry.async_update_entity(light_entity.entity_id, area_id=area.id) + hass.states.async_set( + light_entity.entity_id, + "off", + attributes={ATTR_FRIENDLY_NAME: f"{area.name} light {i}"}, + ) + area_lights[area.name].append(light_entity) + + # Create voice satellites in each area + entry = MockConfigEntry() + entry.add_to_hass(hass) + + kitchen_satellite = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-satellite-kitchen")}, + ) + device_registry.async_update_device(kitchen_satellite.id, area_id=area_kitchen.id) + + bedroom_satellite = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-satellite-bedroom")}, + ) + device_registry.async_update_device(bedroom_satellite.id, area_id=area_bedroom.id) + + # Turn on all lights in the area of a device + result = await conversation.async_converse( + hass, + "turn on all lights", + None, + Context(), + None, + device_id=kitchen_satellite.id, + ) + await hass.async_block_till_done() + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + + # Verify only kitchen lights were targeted + assert {s.entity_id for s in result.response.matched_states} == { + e.entity_id for e in area_lights["kitchen"] + } + assert {c.data["entity_id"][0] for c in turn_on_calls} == { + e.entity_id for e in area_lights["kitchen"] + } + turn_on_calls.clear() + + # Ensure we can still target other areas by name + result = await conversation.async_converse( + hass, + "turn on all lights in the bedroom", + None, + Context(), + None, + device_id=kitchen_satellite.id, + ) + await hass.async_block_till_done() + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + + # Verify only bedroom lights were targeted + assert {s.entity_id for s in result.response.matched_states} == { + e.entity_id for e in area_lights["bedroom"] + } + assert {c.data["entity_id"][0] for c in turn_on_calls} == { + e.entity_id for e in area_lights["bedroom"] + } + turn_on_calls.clear() + + # Turn off all lights in the area of the otherkj device + result = await conversation.async_converse( + hass, + "turn all lights off", + None, + Context(), + None, + device_id=bedroom_satellite.id, + ) + await hass.async_block_till_done() + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + + # Verify only bedroom lights were targeted + assert {s.entity_id for s in result.response.matched_states} == { + e.entity_id for e in area_lights["bedroom"] + } + assert {c.data["entity_id"][0] for c in turn_off_calls} == { + e.entity_id for e in area_lights["bedroom"] + } + turn_off_calls.clear() + + # Not providing a device id should not match + for command in ("on", "off"): + result = await conversation.async_converse( + hass, f"turn {command} all lights", None, Context(), None + ) + assert result.response.response_type == intent.IntentResponseType.ERROR + assert ( + result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH + )