diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 74a8383dd8b..16094ff797a 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -16,7 +16,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from .agent import AbstractConversationAgent, ConversationResult -from .default_agent import DefaultAgent, async_register +from .default_agent import DefaultAgent _LOGGER = logging.getLogger(__name__) @@ -27,7 +27,6 @@ DOMAIN = "conversation" REGEX_TYPE = type(re.compile("")) DATA_AGENT = "conversation_agent" -DATA_CONFIG = "conversation_config" SERVICE_PROCESS = "process" SERVICE_RELOAD = "reload" @@ -47,22 +46,6 @@ SERVICE_RELOAD_SCHEMA = vol.Schema( ) -CONFIG_SCHEMA = vol.Schema( - { - DOMAIN: vol.Schema( - { - vol.Optional("intents"): vol.Schema( - {cv.string: vol.All(cv.ensure_list, [cv.string])} - ) - } - ) - }, - extra=vol.ALLOW_EXTRA, -) - -async_register = bind_hass(async_register) - - @core.callback @bind_hass def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None): @@ -72,7 +55,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" - hass.data[DATA_CONFIG] = config async def handle_process(service: core.ServiceCall) -> None: """Parse text into commands.""" @@ -228,7 +210,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent: """Get the active conversation agent.""" if (agent := hass.data.get(DATA_AGENT)) is None: agent = hass.data[DATA_AGENT] = DefaultAgent(hass) - await agent.async_initialize(hass.data.get(DATA_CONFIG)) + await agent.async_initialize() return agent diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 34d27583f3d..fff28e02ced 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -20,30 +20,12 @@ from homeassistant.helpers import area_registry, entity_registry, intent from .agent import AbstractConversationAgent, ConversationResult from .const import DOMAIN -from .util import create_matcher _LOGGER = logging.getLogger(__name__) REGEX_TYPE = type(re.compile("")) -@core.callback -def async_register(hass, intent_type, utterances): - """Register utterances and any custom intents for the default agent. - - Registrations don't require conversations to be loaded. They will become - active once the conversation component is loaded. - """ - intents = hass.data.setdefault(DOMAIN, {}) - conf = intents.setdefault(intent_type, []) - - for utterance in utterances: - if isinstance(utterance, REGEX_TYPE): - conf.append(utterance) - else: - conf.append(create_matcher(utterance)) - - @dataclass class LanguageIntents: """Loaded intents for a language.""" @@ -62,16 +44,11 @@ class DefaultAgent(AbstractConversationAgent): self._lang_intents: dict[str, LanguageIntents] = {} self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) - async def async_initialize(self, config): + async def async_initialize(self): """Initialize the default agent.""" if "intent" not in self.hass.config.components: await setup.async_setup_component(self.hass, "intent", {}) - if config and config.get(DOMAIN): - _LOGGER.warning( - "Custom intent sentences have been moved to config/custom_sentences" - ) - self.hass.data.setdefault(DOMAIN, {}) async def async_process( diff --git a/homeassistant/components/humidifier/intent.py b/homeassistant/components/humidifier/intent.py index 4d28cf5838c..d949874cc67 100644 --- a/homeassistant/components/humidifier/intent.py +++ b/homeassistant/components/humidifier/intent.py @@ -41,10 +41,18 @@ class HumidityHandler(intent.IntentHandler): """Handle the hass intent.""" hass = intent_obj.hass slots = self.async_validate_slots(intent_obj.slots) - state = intent.async_match_state( - hass, slots["name"]["value"], hass.states.async_all(DOMAIN) + states = list( + intent.async_match_states( + hass, + name=slots["name"]["value"], + states=hass.states.async_all(DOMAIN), + ) ) + if not states: + raise intent.IntentHandleError("No entities matched") + + state = states[0] service_data = {ATTR_ENTITY_ID: state.entity_id} humidity = slots["humidity"]["value"] @@ -85,12 +93,18 @@ class SetModeHandler(intent.IntentHandler): """Handle the hass intent.""" hass = intent_obj.hass slots = self.async_validate_slots(intent_obj.slots) - state = intent.async_match_state( - hass, - slots["name"]["value"], - hass.states.async_all(DOMAIN), + states = list( + intent.async_match_states( + hass, + name=slots["name"]["value"], + states=hass.states.async_all(DOMAIN), + ) ) + if not states: + raise intent.IntentHandleError("No entities matched") + + state = states[0] service_data = {ATTR_ENTITY_ID: state.entity_id} intent.async_test_feature(state, HumidifierEntityFeature.MODES, "modes") diff --git a/homeassistant/components/intent/__init__.py b/homeassistant/components/intent/__init__.py index c6ca9212c74..9171f5b9fc0 100644 --- a/homeassistant/components/intent/__init__.py +++ b/homeassistant/components/intent/__init__.py @@ -2,9 +2,19 @@ import voluptuous as vol from homeassistant.components import http +from homeassistant.components.cover import ( + DOMAIN as COVER_DOMAIN, + SERVICE_CLOSE_COVER, + SERVICE_OPEN_COVER, +) from homeassistant.components.http.data_validator import RequestDataValidator -from homeassistant.const import SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON -from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant +from homeassistant.const import ( + ATTR_ENTITY_ID, + SERVICE_TOGGLE, + SERVICE_TURN_OFF, + SERVICE_TURN_ON, +) +from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant, State from homeassistant.helpers import config_validation as cv, integration_platform, intent from homeassistant.helpers.typing import ConfigType @@ -21,13 +31,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: intent.async_register( hass, - intent.ServiceIntentHandler( + OnOffIntentHandler( intent.INTENT_TURN_ON, HA_DOMAIN, SERVICE_TURN_ON, "Turned {} on" ), ) intent.async_register( hass, - intent.ServiceIntentHandler( + OnOffIntentHandler( intent.INTENT_TURN_OFF, HA_DOMAIN, SERVICE_TURN_OFF, "Turned {} off" ), ) @@ -41,6 +51,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True +class OnOffIntentHandler(intent.ServiceIntentHandler): + """Intent handler for on/off that handles covers too.""" + + async def async_call_service(self, intent_obj: intent.Intent, state: State) -> None: + """Call service on entity with special case for covers.""" + hass = intent_obj.hass + + if state.domain == COVER_DOMAIN: + # on = open + # off = close + await hass.services.async_call( + COVER_DOMAIN, + SERVICE_OPEN_COVER + if self.service == SERVICE_TURN_ON + else SERVICE_CLOSE_COVER, + {ATTR_ENTITY_ID: state.entity_id}, + context=intent_obj.context, + ) + else: + # Fall back to homeassistant.turn_on/off + await super().async_call_service(intent_obj, state) + + async def _async_process_intent(hass: HomeAssistant, domain: str, platform): """Process the intents of an integration.""" await platform.async_setup_intents(hass) diff --git a/homeassistant/components/light/intent.py b/homeassistant/components/light/intent.py index e85602f763a..5ee60459128 100644 --- a/homeassistant/components/light/intent.py +++ b/homeassistant/components/light/intent.py @@ -1,12 +1,15 @@ """Intents for the light integration.""" from __future__ import annotations +import asyncio +import logging +from typing import Any + import voluptuous as vol from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON -from homeassistant.core import HomeAssistant, State -from homeassistant.helpers import intent -import homeassistant.helpers.config_validation as cv +from homeassistant.core import HomeAssistant +from homeassistant.helpers import area_registry, config_validation as cv, intent import homeassistant.util.color as color_util from . import ( @@ -18,6 +21,8 @@ from . import ( color_supported, ) +_LOGGER = logging.getLogger(__name__) + INTENT_SET = "HassLightSet" @@ -26,30 +31,14 @@ async def async_setup_intents(hass: HomeAssistant) -> None: intent.async_register(hass, SetIntentHandler()) -def _test_supports_color(state: State) -> None: - """Test if state supports colors.""" - supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES) - if not color_supported(supported_color_modes): - raise intent.IntentHandleError( - f"Entity {state.name} does not support changing colors" - ) - - -def _test_supports_brightness(state: State) -> None: - """Test if state supports brightness.""" - supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES) - if not brightness_supported(supported_color_modes): - raise intent.IntentHandleError( - f"Entity {state.name} does not support changing brightness" - ) - - class SetIntentHandler(intent.IntentHandler): """Handle set color intents.""" intent_type = INTENT_SET slot_schema = { - vol.Required("name"): cv.string, + vol.Any("name", "area"): cv.string, + vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), + vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]), vol.Optional("color"): color_util.color_name_to_rgb, vol.Optional("brightness"): vol.All(vol.Coerce(int), vol.Range(0, 100)), } @@ -57,36 +46,116 @@ class SetIntentHandler(intent.IntentHandler): async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: """Handle the hass intent.""" hass = intent_obj.hass + service_data: dict[str, Any] = {} + speech_parts: list[str] = [] slots = self.async_validate_slots(intent_obj.slots) - state = intent.async_match_state( - hass, slots["name"]["value"], hass.states.async_all(DOMAIN) + + name: str | None = slots.get("name", {}).get("value") + if name == "all": + # Don't match on name if targeting all entities + name = None + + # Look up area first to fail early + area_name = slots.get("area", {}).get("value") + area: area_registry.AreaEntry | None = None + if area_name is not None: + areas = area_registry.async_get(hass) + area = areas.async_get_area(area_name) or areas.async_get_area_by_name( + area_name + ) + if area is None: + raise intent.IntentHandleError(f"No area named {area_name}") + + # Optional domain/device class filters. + # Convert to sets for speed. + domains: set[str] | None = None + device_classes: set[str] | None = None + + if "domain" in slots: + domains = set(slots["domain"]["value"]) + + if "device_class" in slots: + device_classes = set(slots["device_class"]["value"]) + + states = list( + intent.async_match_states( + hass, + name=name, + area=area, + domains=domains, + device_classes=device_classes, + ) ) - service_data = {ATTR_ENTITY_ID: state.entity_id} - speech_parts = [] + if not states: + raise intent.IntentHandleError("No entities matched") if "color" in slots: - _test_supports_color(state) service_data[ATTR_RGB_COLOR] = slots["color"]["value"] # Use original passed in value of the color because we don't have # human readable names for that internally. speech_parts.append(f"the color {intent_obj.slots['color']['value']}") if "brightness" in slots: - _test_supports_brightness(state) service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"] speech_parts.append(f"{slots['brightness']['value']}% brightness") - await hass.services.async_call( - DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context + response = intent_obj.create_response() + needs_brightness = ATTR_BRIGHTNESS_PCT in service_data + needs_color = ATTR_RGB_COLOR in service_data + + success_results: list[intent.IntentResponseTarget] = [] + failed_results: list[intent.IntentResponseTarget] = [] + service_coros = [] + + if area is not None: + success_results.append( + intent.IntentResponseTarget( + type=intent.IntentResponseTargetType.AREA, + name=area.name, + id=area.id, + ) + ) + speech_name = area.name + else: + speech_name = states[0].name + + for state in states: + target = intent.IntentResponseTarget( + type=intent.IntentResponseTargetType.ENTITY, + name=state.name, + id=state.entity_id, + ) + + # Test brightness/color + supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES) + if (needs_color and not color_supported(supported_color_modes)) or ( + needs_brightness and not brightness_supported(supported_color_modes) + ): + failed_results.append(target) + continue + + service_coros.append( + hass.services.async_call( + DOMAIN, + SERVICE_TURN_ON, + {**service_data, ATTR_ENTITY_ID: state.entity_id}, + context=intent_obj.context, + ) + ) + success_results.append(target) + + # Handle service calls in parallel. + await asyncio.gather(*service_coros) + + response.async_set_results( + success_results=success_results, failed_results=failed_results ) - response = intent_obj.create_response() - if not speech_parts: # No attributes changed - speech = f"Turned on {state.name}" + speech = f"Turned on {speech_name}" else: - parts = [f"Changed {state.name} to"] + parts = [f"Changed {speech_name} to"] for index, part in enumerate(speech_parts): if index == 0: parts.append(f" {part}") @@ -97,4 +166,5 @@ class SetIntentHandler(intent.IntentHandler): speech = "".join(parts) response.async_set_speech(speech) + return response diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index ba6461e1d60..c1b0b9d3a3f 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Iterable +from collections.abc import Collection, Iterable import dataclasses from dataclasses import dataclass from enum import Enum @@ -11,7 +11,11 @@ from typing import Any, TypeVar import voluptuous as vol -from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + ATTR_ENTITY_ID, + ATTR_SUPPORTED_FEATURES, +) from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import bind_hass @@ -110,51 +114,117 @@ class IntentUnexpectedError(IntentError): """Unexpected error while handling intent.""" +def _is_device_class( + state: State, + entity: entity_registry.RegistryEntry | None, + device_classes: Collection[str], +) -> bool: + """Return true if entity device class matches.""" + # Try entity first + if (entity is not None) and (entity.device_class is not None): + # Entity device class can be None or blank as "unset" + if entity.device_class in device_classes: + return True + + # Fall back to state attribute + device_class = state.attributes.get(ATTR_DEVICE_CLASS) + return (device_class is not None) and (device_class in device_classes) + + +def _has_name( + state: State, entity: entity_registry.RegistryEntry | None, name: str +) -> bool: + """Return true if entity name or alias matches.""" + if name in (state.entity_id, state.name.casefold()): + return True + + # Check aliases + if (entity is not None) and entity.aliases: + for alias in entity.aliases: + if name == alias.casefold(): + return True + + return False + + @callback @bind_hass -def async_match_state( - hass: HomeAssistant, name: str, states: Iterable[State] | None = None -) -> State: - """Find a state that matches the name.""" +def async_match_states( + hass: HomeAssistant, + name: str | None = None, + area_name: str | None = None, + area: area_registry.AreaEntry | None = None, + domains: Collection[str] | None = None, + device_classes: Collection[str] | None = None, + states: Iterable[State] | None = None, + entities: entity_registry.EntityRegistry | None = None, + areas: area_registry.AreaRegistry | None = None, +) -> Iterable[State]: + """Find states that match the constraints.""" if states is None: + # All states states = hass.states.async_all() - name = name.casefold() - state: State | None = None - registry = entity_registry.async_get(hass) + if entities is None: + entities = entity_registry.async_get(hass) - for maybe_state in states: - # Check entity id and name - if name in (maybe_state.entity_id, maybe_state.name.casefold()): - state = maybe_state - else: - # Check aliases - entry = registry.async_get(maybe_state.entity_id) - if (entry is not None) and entry.aliases: - for alias in entry.aliases: - if name == alias.casefold(): - state = maybe_state - break + # Gather entities + states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]] = [] + for state in states: + entity = entities.async_get(state.entity_id) + if (entity is not None) and entity.entity_category: + # Skip diagnostic entities + continue - if state is not None: - break + states_and_entities.append((state, entity)) - if state is None: - raise IntentHandleError(f"Unable to find an entity called {name}") + # Filter by domain and device class + if domains: + states_and_entities = [ + (state, entity) + for state, entity in states_and_entities + if state.domain in domains + ] - return state + if device_classes: + # Check device class in state attribute and in entity entry (if available) + states_and_entities = [ + (state, entity) + for state, entity in states_and_entities + if _is_device_class(state, entity, device_classes) + ] + if (area is None) and (area_name is not None): + # Look up area by name + if areas is None: + areas = area_registry.async_get(hass) -@callback -@bind_hass -def async_match_area( - hass: HomeAssistant, area_name: str -) -> area_registry.AreaEntry | None: - """Find an area that matches the name.""" - registry = area_registry.async_get(hass) - return registry.async_get_area(area_name) or registry.async_get_area_by_name( - area_name - ) + # id or name + area = areas.async_get_area(area_name) or areas.async_get_area_by_name( + area_name + ) + assert area is not None, f"No area named {area_name}" + + if area is not None: + # Filter by area + states_and_entities = [ + (state, entity) + for state, entity in states_and_entities + if (entity is not None) and (entity.area_id == area.id) + ] + + if name is not None: + # Filter by name + name = name.casefold() + + for state, entity in states_and_entities: + if _has_name(state, entity, name): + yield state + break + else: + # Not filtered by name + for state, _entity in states_and_entities: + yield state @callback @@ -229,102 +299,103 @@ class ServiceIntentHandler(IntentHandler): hass = intent_obj.hass slots = self.async_validate_slots(intent_obj.slots) - if "area" in slots: - # Entities in an area - area_name = slots["area"]["value"] - area = async_match_area(hass, area_name) - assert area is not None - assert area.id is not None + name: str | None = slots.get("name", {}).get("value") + if name == "all": + # Don't match on name if targeting all entities + name = None - # Optional domain filter - domains: set[str] | None = None - if "domain" in slots: - domains = set(slots["domain"]["value"]) + # Look up area first to fail early + area_name = slots.get("area", {}).get("value") + area: area_registry.AreaEntry | None = None + if area_name is not None: + areas = area_registry.async_get(hass) + area = areas.async_get_area(area_name) or areas.async_get_area_by_name( + area_name + ) + if area is None: + raise IntentHandleError(f"No area named {area_name}") - # Optional device class filter - device_classes: set[str] | None = None - if "device_class" in slots: - device_classes = set(slots["device_class"]["value"]) + # Optional domain/device class filters. + # Convert to sets for speed. + domains: set[str] | None = None + device_classes: set[str] | None = None - success_results = [ + if "domain" in slots: + domains = set(slots["domain"]["value"]) + + if "device_class" in slots: + device_classes = set(slots["device_class"]["value"]) + + states = list( + async_match_states( + hass, + name=name, + area=area, + domains=domains, + device_classes=device_classes, + ) + ) + + if not states: + raise IntentHandleError("No entities matched") + + response = await self.async_handle_states(intent_obj, states, area) + + return response + + async def async_handle_states( + self, + intent_obj: Intent, + states: list[State], + area: area_registry.AreaEntry | None = None, + ) -> IntentResponse: + """Complete action on matched entity states.""" + assert states + success_results: list[IntentResponseTarget] = [] + response = intent_obj.create_response() + + if area is not None: + success_results.append( IntentResponseTarget( type=IntentResponseTargetType.AREA, name=area.name, id=area.id ) - ] - service_coros = [] - registry = entity_registry.async_get(hass) - for entity_entry in entity_registry.async_entries_for_area( - registry, area.id - ): - if entity_entry.entity_category: - # Skip diagnostic entities - continue - - if domains and (entity_entry.domain not in domains): - # Skip entity not in the domain - continue - - if device_classes and (entity_entry.device_class not in device_classes): - # Skip entity with wrong device class - continue - - service_coros.append( - hass.services.async_call( - self.domain, - self.service, - {ATTR_ENTITY_ID: entity_entry.entity_id}, - context=intent_obj.context, - ) - ) - - state = hass.states.get(entity_entry.entity_id) - assert state is not None - - success_results.append( - IntentResponseTarget( - type=IntentResponseTargetType.ENTITY, - name=state.name, - id=entity_entry.entity_id, - ), - ) - - if not service_coros: - raise IntentHandleError("No entities matched") - - # Handle service calls in parallel. - # We will need to handle partial failures here. - await asyncio.gather(*service_coros) - - response = intent_obj.create_response() - response.async_set_speech(self.speech.format(area.name)) - response.async_set_results( - success_results=success_results, ) + speech_name = area.name else: - # Single entity - state = async_match_state(hass, slots["name"]["value"]) + speech_name = states[0].name - await hass.services.async_call( - self.domain, - self.service, - {ATTR_ENTITY_ID: state.entity_id}, - context=intent_obj.context, + service_coros = [] + for state in states: + service_coros.append(self.async_call_service(intent_obj, state)) + success_results.append( + IntentResponseTarget( + type=IntentResponseTargetType.ENTITY, + name=state.name, + id=state.entity_id, + ), ) - response = intent_obj.create_response() - response.async_set_speech(self.speech.format(state.name)) - response.async_set_results( - success_results=[ - IntentResponseTarget( - type=IntentResponseTargetType.ENTITY, - name=state.name, - id=state.entity_id, - ), - ], - ) + # Handle service calls in parallel. + # We will need to handle partial failures here. + await asyncio.gather(*service_coros) + + response.async_set_results( + success_results=success_results, + ) + response.async_set_speech(self.speech.format(speech_name)) return response + async def async_call_service(self, intent_obj: Intent, state: State) -> None: + """Call service on entity.""" + hass = intent_obj.hass + await hass.services.async_call( + self.domain, + self.service, + {ATTR_ENTITY_ID: state.entity_id}, + context=intent_obj.context, + ) + class IntentCategory(Enum): """Category of an intent.""" diff --git a/tests/components/light/test_intent.py b/tests/components/light/test_intent.py index 458e27bc6c6..9c665ded03b 100644 --- a/tests/components/light/test_intent.py +++ b/tests/components/light/test_intent.py @@ -2,7 +2,7 @@ from homeassistant.components import light from homeassistant.components.light import ATTR_SUPPORTED_COLOR_MODES, ColorMode, intent from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON -from homeassistant.helpers.intent import IntentHandleError, async_handle +from homeassistant.helpers.intent import async_handle from tests.common import async_mock_service @@ -40,17 +40,16 @@ async def test_intent_set_color_tests_feature(hass): calls = async_mock_service(hass, light.DOMAIN, light.SERVICE_TURN_ON) await intent.async_setup_intents(hass) - try: - await async_handle( - hass, - "test", - intent.INTENT_SET, - {"name": {"value": "Hello"}, "color": {"value": "blue"}}, - ) - assert False, "handling intent should have raised" - except IntentHandleError as err: - assert str(err) == "Entity hello does not support changing colors" + response = await async_handle( + hass, + "test", + intent.INTENT_SET, + {"name": {"value": "Hello"}, "color": {"value": "blue"}}, + ) + # Response should contain one failed target + assert len(response.success_results) == 0 + assert len(response.failed_results) == 1 assert len(calls) == 0 diff --git a/tests/helpers/test_intent.py b/tests/helpers/test_intent.py index 1d7aaeba366..f190f41072f 100644 --- a/tests/helpers/test_intent.py +++ b/tests/helpers/test_intent.py @@ -3,9 +3,15 @@ import pytest import voluptuous as vol +from homeassistant.components.switch import SwitchDeviceClass from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.core import State -from homeassistant.helpers import config_validation as cv, entity_registry, intent +from homeassistant.helpers import ( + area_registry, + config_validation as cv, + entity_registry, + intent, +) class MockIntentHandler(intent.IntentHandler): @@ -16,25 +22,74 @@ class MockIntentHandler(intent.IntentHandler): self.slot_schema = slot_schema -async def test_async_match_state(hass): +async def test_async_match_states(hass): """Test async_match_state helper.""" + areas = area_registry.async_get(hass) + area_kitchen = areas.async_get_or_create("kitchen") + area_bedroom = areas.async_get_or_create("bedroom") + state1 = State( "light.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} ) state2 = State( - "switch.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen switch"} + "switch.bedroom", "on", attributes={ATTR_FRIENDLY_NAME: "bedroom switch"} ) - registry = entity_registry.async_get(hass) - registry.async_get_or_create( - "switch", "demo", "1234", suggested_object_id="kitchen" + + # Put entities into different areas + entities = entity_registry.async_get(hass) + entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen") + entities.async_update_entity(state1.entity_id, area_id=area_kitchen.id) + + entities.async_get_or_create( + "switch", "demo", "1234", suggested_object_id="bedroom" + ) + entities.async_update_entity( + state2.entity_id, + area_id=area_bedroom.id, + device_class=SwitchDeviceClass.OUTLET, + aliases={"kill switch"}, ) - registry.async_update_entity(state2.entity_id, aliases={"kill switch"}) - state = intent.async_match_state(hass, "kitchen light", [state1, state2]) - assert state is state1 + # Match on name + assert [state1] == list( + intent.async_match_states(hass, name="kitchen light", states=[state1, state2]) + ) - state = intent.async_match_state(hass, "kill switch", [state1, state2]) - assert state is state2 + # Test alias + assert [state2] == list( + intent.async_match_states(hass, name="kill switch", states=[state1, state2]) + ) + + # Name + area + assert [state1] == list( + intent.async_match_states( + hass, name="kitchen light", area_name="kitchen", states=[state1, state2] + ) + ) + + # Wrong area + assert not list( + intent.async_match_states( + hass, name="kitchen light", area_name="bedroom", states=[state1, state2] + ) + ) + + # Domain + area + assert [state2] == list( + intent.async_match_states( + hass, domains={"switch"}, area_name="bedroom", states=[state1, state2] + ) + ) + + # Device class + area + assert [state2] == list( + intent.async_match_states( + hass, + device_classes={SwitchDeviceClass.OUTLET}, + area_name="bedroom", + states=[state1, state2], + ) + ) def test_async_validate_slots():