From 167e66d45c27e295b6468119beadea563f68d2df Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Mon, 18 Mar 2024 22:32:23 +0100 Subject: [PATCH] Add labels to service target (#113753) --- homeassistant/const.py | 3 + homeassistant/helpers/config_validation.py | 7 + homeassistant/helpers/service.py | 50 ++++- tests/helpers/test_service.py | 205 ++++++++++++++++++++- 4 files changed, 256 insertions(+), 9 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index 28409ba0907..ac934dcc0d6 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -513,6 +513,9 @@ ATTR_DEVICE_ID: Final = "device_id" # Contains one string or a list of strings, each being an floor id ATTR_FLOOR_ID: Final = "floor_id" +# Contains one string or a list of strings, each being an label id +ATTR_LABEL_ID: Final = "label_id" + # String with a friendly name for the entity ATTR_FRIENDLY_NAME: Final = "friendly_name" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index fbe98ccd387..bf666cf2e03 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -30,6 +30,7 @@ from homeassistant.const import ( ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_FLOOR_ID, + ATTR_LABEL_ID, CONF_ABOVE, CONF_ALIAS, CONF_ATTRIBUTE, @@ -1220,6 +1221,9 @@ ENTITY_SERVICE_FIELDS = { vol.Optional(ATTR_FLOOR_ID): vol.Any( ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ), + vol.Optional(ATTR_LABEL_ID): vol.Any( + ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) + ), } TARGET_SERVICE_FIELDS = { @@ -1240,6 +1244,9 @@ TARGET_SERVICE_FIELDS = { vol.Optional(ATTR_FLOOR_ID): vol.Any( ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ), + vol.Optional(ATTR_LABEL_ID): vol.Any( + ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) + ), } diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index df8284c0b4c..34cee93d971 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -19,6 +19,7 @@ from homeassistant.const import ( ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_FLOOR_ID, + ATTR_LABEL_ID, CONF_ENTITY_ID, CONF_SERVICE, CONF_SERVICE_DATA, @@ -55,6 +56,7 @@ from . import ( device_registry, entity_registry, floor_registry, + label_registry, template, translation, ) @@ -196,7 +198,7 @@ class ServiceParams(TypedDict): class ServiceTargetSelector: """Class to hold a target selector for a service.""" - __slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids") + __slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids", "label_ids") def __init__(self, service_call: ServiceCall) -> None: """Extract ids from service call data.""" @@ -205,6 +207,7 @@ class ServiceTargetSelector: device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID) area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID) floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID) + label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID) self.entity_ids = ( set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() @@ -216,12 +219,19 @@ class ServiceTargetSelector: self.floor_ids = ( set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() ) + self.label_ids = ( + set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() + ) @property def has_any_selector(self) -> bool: """Determine if any selectors are present.""" return bool( - self.entity_ids or self.device_ids or self.area_ids or self.floor_ids + self.entity_ids + or self.device_ids + or self.area_ids + or self.floor_ids + or self.label_ids ) @@ -232,7 +242,7 @@ class SelectedEntities: # Entities that were explicitly mentioned. referenced: set[str] = dataclasses.field(default_factory=set) - # Entities that were referenced via device/area/floor ID. + # Entities that were referenced via device/area/floor/label ID. # Should not trigger a warning when they don't exist. indirectly_referenced: set[str] = dataclasses.field(default_factory=set) @@ -240,6 +250,7 @@ class SelectedEntities: missing_devices: set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set) missing_floors: set[str] = dataclasses.field(default_factory=set) + missing_labels: set[str] = dataclasses.field(default_factory=set) # Referenced devices referenced_devices: set[str] = dataclasses.field(default_factory=set) @@ -253,6 +264,7 @@ class SelectedEntities: ("areas", self.missing_areas), ("devices", self.missing_devices), ("entities", missing_entities), + ("labels", self.missing_labels), ): if items: parts.append(f"{label} {', '.join(sorted(items))}") @@ -467,7 +479,7 @@ def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: @bind_hass -def async_extract_referenced_entity_ids( +def async_extract_referenced_entity_ids( # noqa: C901 hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> SelectedEntities: """Extract referenced entity IDs from a service call.""" @@ -483,13 +495,19 @@ def async_extract_referenced_entity_ids( selected.referenced.update(entity_ids) - if not selector.device_ids and not selector.area_ids and not selector.floor_ids: + if ( + not selector.device_ids + and not selector.area_ids + and not selector.floor_ids + and not selector.label_ids + ): return selected ent_reg = entity_registry.async_get(hass) dev_reg = device_registry.async_get(hass) area_reg = area_registry.async_get(hass) floor_reg = floor_registry.async_get(hass) + label_reg = label_registry.async_get(hass) for floor_id in selector.floor_ids: if floor_id not in floor_reg.floors: @@ -503,6 +521,28 @@ def async_extract_referenced_entity_ids( if device_id not in dev_reg.devices: selected.missing_devices.add(device_id) + for label_id in selector.label_ids: + if label_id not in label_reg.labels: + selected.missing_labels.add(label_id) + + # Find areas, devices & entities for targeted labels + if selector.label_ids: + for area_entry in area_reg.areas.values(): + if area_entry.labels.intersection(selector.label_ids): + selected.referenced_areas.add(area_entry.id) + + for device_entry in dev_reg.devices.values(): + if device_entry.labels.intersection(selector.label_ids): + selected.referenced_devices.add(device_entry.id) + + for entity_entry in ent_reg.entities.values(): + if ( + entity_entry.entity_category is None + and entity_entry.hidden_by is None + and entity_entry.labels.intersection(selector.label_ids) + ): + selected.indirectly_referenced.add(entity_entry.entity_id) + # Find areas for targeted floors if selector.floor_ids: for area_entry in area_reg.areas.values(): diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 2209e60a37d..4f889478460 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -266,6 +266,120 @@ def floor_area_mock(hass: HomeAssistant) -> None: ) +@pytest.fixture +def label_mock(hass: HomeAssistant) -> None: + """Mock including label info.""" + hass.states.async_set("light.bowl", STATE_ON) + hass.states.async_set("light.ceiling", STATE_OFF) + hass.states.async_set("light.kitchen", STATE_OFF) + + area_with_labels = ar.AreaEntry( + id="area-with-labels", + name="Area with labels", + aliases={}, + normalized_name="with_labels", + floor_id=None, + icon=None, + labels={"label_area"}, + picture=None, + ) + area_without_labels = ar.AreaEntry( + id="area-no-labels", + name="Area without labels", + aliases={}, + normalized_name="without_labels", + floor_id=None, + icon=None, + labels=set(), + picture=None, + ) + mock_area_registry( + hass, + { + area_with_labels.id: area_with_labels, + area_without_labels.id: area_without_labels, + }, + ) + + device_has_label1 = dr.DeviceEntry(labels={"label1"}) + device_has_label2 = dr.DeviceEntry(labels={"label2"}) + device_has_labels = dr.DeviceEntry( + labels={"label1", "label2"}, area_id=area_with_labels.id + ) + device_no_labels = dr.DeviceEntry( + id="device-no-labels", area_id=area_without_labels.id + ) + + mock_device_registry( + hass, + { + device_has_label1.id: device_has_label1, + device_has_label2.id: device_has_label2, + device_has_labels.id: device_has_labels, + device_no_labels.id: device_no_labels, + }, + ) + + entity_with_my_label = er.RegistryEntry( + entity_id="light.with_my_label", + unique_id="with_my_label", + platform="test", + labels={"my-label"}, + ) + hidden_entity_with_my_label = er.RegistryEntry( + entity_id="light.hidden_with_my_label", + unique_id="hidden_with_my_label", + platform="test", + labels={"my-label"}, + hidden_by=er.RegistryEntryHider.USER, + ) + config_entity_with_my_label = er.RegistryEntry( + entity_id="light.config_with_my_label", + unique_id="config_with_my_label", + platform="test", + labels={"my-label"}, + entity_category=EntityCategory.CONFIG, + ) + entity_with_label1_from_device = er.RegistryEntry( + entity_id="light.with_label1_from_device", + unique_id="with_label1_from_device", + platform="test", + device_id=device_has_label1.id, + ) + entity_with_label1_and_label2_from_device = er.RegistryEntry( + entity_id="light.with_label1_and_label2_from_device", + unique_id="with_label1_and_label2_from_device", + platform="test", + labels={"label1"}, + device_id=device_has_label2.id, + ) + entity_with_labels_from_device = er.RegistryEntry( + entity_id="light.with_labels_from_device", + unique_id="with_labels_from_device", + platform="test", + device_id=device_has_labels.id, + ) + entity_with_no_labels = er.RegistryEntry( + entity_id="light.no_labels", + unique_id="no_labels", + platform="test", + device_id=device_no_labels.id, + ) + + mock_registry( + hass, + { + config_entity_with_my_label.entity_id: config_entity_with_my_label, + entity_with_label1_and_label2_from_device.entity_id: entity_with_label1_and_label2_from_device, + entity_with_label1_from_device.entity_id: entity_with_label1_from_device, + entity_with_labels_from_device.entity_id: entity_with_labels_from_device, + entity_with_my_label.entity_id: entity_with_my_label, + entity_with_no_labels.entity_id: entity_with_no_labels, + hidden_entity_with_my_label.entity_id: hidden_entity_with_my_label, + }, + ) + + async def test_call_from_config(hass: HomeAssistant) -> None: """Test the sync wrapper of service.async_call_from_config.""" calls = async_mock_service(hass, "test_domain", "test_service") @@ -629,6 +743,44 @@ async def test_extract_entity_ids_from_floor(hass: HomeAssistant) -> None: ) +@pytest.mark.usefixtures("label_mock") +async def test_extract_entity_ids_from_labels(hass: HomeAssistant) -> None: + """Test extract_entity_ids method with labels.""" + call = ServiceCall("light", "turn_on", {"label_id": "my-label"}) + + assert { + "light.with_my_label", + } == await service.async_extract_entity_ids(hass, call) + + call = ServiceCall("light", "turn_on", {"label_id": "label1"}) + + assert { + "light.with_label1_from_device", + "light.with_labels_from_device", + "light.with_label1_and_label2_from_device", + } == await service.async_extract_entity_ids(hass, call) + + call = ServiceCall("light", "turn_on", {"label_id": ["label2"]}) + + assert { + "light.with_labels_from_device", + "light.with_label1_and_label2_from_device", + } == await service.async_extract_entity_ids(hass, call) + + call = ServiceCall("light", "turn_on", {"label_id": ["label_area"]}) + + assert { + "light.with_labels_from_device", + } == await service.async_extract_entity_ids(hass, call) + + assert ( + await service.async_extract_entity_ids( + hass, ServiceCall("light", "turn_on", {"label_id": ENTITY_MATCH_NONE}) + ) + == set() + ) + + async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: """Test async_get_all_descriptions.""" group_config = {DOMAIN_GROUP: {}} @@ -1578,6 +1730,49 @@ async def test_extract_from_service_area_id( ] +@pytest.mark.usefixtures("label_mock") +async def test_extract_from_service_label_id(hass: HomeAssistant) -> None: + """Test the extraction using label ID as reference.""" + entities = [ + MockEntity(name="with_my_label", entity_id="light.with_my_label"), + MockEntity(name="no_labels", entity_id="light.no_labels"), + MockEntity( + name="with_labels_from_device", entity_id="light.with_labels_from_device" + ), + ] + + call = ServiceCall("light", "turn_on", {"label_id": "label_area"}) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 1 + assert extracted[0].entity_id == "light.with_labels_from_device" + + call = ServiceCall("light", "turn_on", {"label_id": "my-label"}) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 1 + assert extracted[0].entity_id == "light.with_my_label" + + call = ServiceCall("light", "turn_on", {"label_id": ["my-label", "label1"]}) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 2 + assert sorted(ent.entity_id for ent in extracted) == [ + "light.with_labels_from_device", + "light.with_my_label", + ] + + call = ServiceCall( + "light", + "turn_on", + {"label_id": ["my-label", "label1"], "device_id": "device-no-labels"}, + ) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 3 + assert sorted(ent.entity_id for ent in extracted) == [ + "light.no_labels", + "light.with_labels_from_device", + "light.with_my_label", + ] + + async def test_entity_service_call_warn_referenced( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: @@ -1590,13 +1785,14 @@ async def test_entity_service_call_warn_referenced( "entity_id": "non.existent", "device_id": "non-existent-device", "floor_id": "non-existent-floor", + "label_id": "non-existent-label", }, ) await service.entity_service_call(hass, {}, "", call) assert ( "Referenced floors non-existent-floor, areas non-existent-area, " - "devices non-existent-device, entities non.existent are missing " - "or not currently available" + "devices non-existent-device, entities non.existent, " + "labels non-existent-label are missing or not currently available" ) in caplog.text @@ -1612,14 +1808,15 @@ async def test_async_extract_entities_warn_referenced( "entity_id": "non.existent", "device_id": "non-existent-device", "floor_id": "non-existent-floor", + "label_id": "non-existent-label", }, ) extracted = await service.async_extract_entities(hass, {}, call) assert len(extracted) == 0 assert ( "Referenced floors non-existent-floor, areas non-existent-area, " - "devices non-existent-device, entities non.existent are missing " - "or not currently available" + "devices non-existent-device, entities non.existent, " + "labels non-existent-label are missing or not currently available" ) in caplog.text