From 2aadd643edb1f6da283322b35120d5a989b2445b Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Thu, 14 Mar 2024 19:02:23 +0100 Subject: [PATCH] Add floors to service target (#110850) --- homeassistant/const.py | 3 + homeassistant/helpers/config_validation.py | 7 ++ homeassistant/helpers/service.py | 45 ++++++++--- tests/helpers/test_service.py | 91 +++++++++++++++++++--- 4 files changed, 125 insertions(+), 21 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index 6331b5d46ce..04561f489dd 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -510,6 +510,9 @@ ATTR_AREA_ID: Final = "area_id" # Contains one string, the device ID ATTR_DEVICE_ID: Final = "device_id" +# Contains one string or a list of strings, each being an floor id +ATTR_FLOOR_ID: Final = "floor_id" + # String with a friendly name for the entity ATTR_FRIENDLY_NAME: Final = "friendly_name" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index c545862f48a..8f9e0d5353f 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -29,6 +29,7 @@ from homeassistant.const import ( ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID, + ATTR_FLOOR_ID, CONF_ABOVE, CONF_ALIAS, CONF_ATTRIBUTE, @@ -1216,6 +1217,9 @@ ENTITY_SERVICE_FIELDS = { vol.Optional(ATTR_AREA_ID): vol.Any( ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ), + vol.Optional(ATTR_FLOOR_ID): vol.Any( + ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) + ), } TARGET_SERVICE_FIELDS = { @@ -1233,6 +1237,9 @@ TARGET_SERVICE_FIELDS = { vol.Optional(ATTR_AREA_ID): vol.Any( ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ), + vol.Optional(ATTR_FLOOR_ID): vol.Any( + ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) + ), } diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 99c15c3412e..b58e0831ccb 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -18,6 +18,7 @@ from homeassistant.const import ( ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID, + ATTR_FLOOR_ID, CONF_ENTITY_ID, CONF_SERVICE, CONF_SERVICE_DATA, @@ -53,6 +54,7 @@ from . import ( config_validation as cv, device_registry, entity_registry, + floor_registry, template, translation, ) @@ -194,7 +196,7 @@ class ServiceParams(TypedDict): class ServiceTargetSelector: """Class to hold a target selector for a service.""" - __slots__ = ("entity_ids", "device_ids", "area_ids") + __slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids") def __init__(self, service_call: ServiceCall) -> None: """Extract ids from service call data.""" @@ -202,6 +204,7 @@ class ServiceTargetSelector: entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID) device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID) area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID) + floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID) self.entity_ids = ( set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() @@ -210,11 +213,16 @@ class ServiceTargetSelector: set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() ) self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() + self.floor_ids = ( + set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() + ) @property def has_any_selector(self) -> bool: """Determine if any selectors are present.""" - return bool(self.entity_ids or self.device_ids or self.area_ids) + return bool( + self.entity_ids or self.device_ids or self.area_ids or self.floor_ids + ) @dataclasses.dataclass(slots=True) @@ -224,21 +232,24 @@ class SelectedEntities: # Entities that were explicitly mentioned. referenced: set[str] = dataclasses.field(default_factory=set) - # Entities that were referenced via device/area ID. + # Entities that were referenced via device/area/floor ID. # Should not trigger a warning when they don't exist. indirectly_referenced: set[str] = dataclasses.field(default_factory=set) # Referenced items that could not be found. missing_devices: set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set) + missing_floors: set[str] = dataclasses.field(default_factory=set) # Referenced devices referenced_devices: set[str] = dataclasses.field(default_factory=set) + referenced_areas: set[str] = dataclasses.field(default_factory=set) def log_missing(self, missing_entities: set[str]) -> None: """Log about missing items.""" parts = [] for label, items in ( + ("floors", self.missing_floors), ("areas", self.missing_areas), ("devices", self.missing_devices), ("entities", missing_entities), @@ -472,37 +483,49 @@ def async_extract_referenced_entity_ids( selected.referenced.update(entity_ids) - if not selector.device_ids and not selector.area_ids: + if not selector.device_ids and not selector.area_ids and not selector.floor_ids: return selected ent_reg = entity_registry.async_get(hass) dev_reg = device_registry.async_get(hass) area_reg = area_registry.async_get(hass) + floor_reg = floor_registry.async_get(hass) - for device_id in selector.device_ids: - if device_id not in dev_reg.devices: - selected.missing_devices.add(device_id) + for floor_id in selector.floor_ids: + if floor_id not in floor_reg.floors: + selected.missing_floors.add(floor_id) for area_id in selector.area_ids: if area_id not in area_reg.areas: selected.missing_areas.add(area_id) + for device_id in selector.device_ids: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) + + # Find areas for targeted floors + if selector.floor_ids: + for area_entry in area_reg.areas.values(): + if area_entry.id and area_entry.floor_id in selector.floor_ids: + selected.referenced_areas.add(area_entry.id) + # Find devices for targeted areas selected.referenced_devices.update(selector.device_ids) - if selector.area_ids: + selected.referenced_areas.update(selector.area_ids) + if selected.referenced_areas: for device_entry in dev_reg.devices.values(): - if device_entry.area_id in selector.area_ids: + if device_entry.area_id in selected.referenced_areas: selected.referenced_devices.add(device_entry.id) - if not selector.area_ids and not selected.referenced_devices: + if not selected.referenced_areas and not selected.referenced_devices: return selected entities = ent_reg.entities # Add indirectly referenced by area selected.indirectly_referenced.update( entry.entity_id - for area_id in selector.area_ids + for area_id in selected.referenced_areas # The entity's area matches a targeted area for entry in entities.get_entries_for_area_id(area_id) # Do not add entities which are hidden or which are config diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 1a37de217d9..2209e60a37d 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -32,6 +32,7 @@ from homeassistant.core import ( SupportsResponse, ) from homeassistant.helpers import ( + area_registry as ar, device_registry as dr, entity_registry as er, service, @@ -45,6 +46,7 @@ from tests.common import ( MockEntity, MockUser, async_mock_service, + mock_area_registry, mock_device_registry, mock_registry, ) @@ -102,12 +104,38 @@ def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]: @pytest.fixture -def area_mock(hass): - """Mock including area info.""" +def floor_area_mock(hass: HomeAssistant) -> None: + """Mock including floor and area info.""" hass.states.async_set("light.Bowl", STATE_ON) hass.states.async_set("light.Ceiling", STATE_OFF) hass.states.async_set("light.Kitchen", STATE_OFF) + area_in_floor = ar.AreaEntry( + id="test-area", + name="Test area", + aliases={}, + normalized_name="test-area", + floor_id="test-floor", + icon=None, + picture=None, + ) + area_in_floor_a = ar.AreaEntry( + id="area-a", + name="Area A", + aliases={}, + normalized_name="area-a", + floor_id="floor-a", + icon=None, + picture=None, + ) + mock_area_registry( + hass, + { + area_in_floor.id: area_in_floor, + area_in_floor_a.id: area_in_floor_a, + }, + ) + device_in_area = dr.DeviceEntry(area_id="test-area") device_no_area = dr.DeviceEntry(id="device-no-area-id") device_diff_area = dr.DeviceEntry(area_id="diff-area") @@ -264,7 +292,11 @@ async def test_service_call(hass: HomeAssistant) -> None: "effect": {"value": "{{ 'complex' }}", "simple": "simple"}, }, "data_template": {"list": ["{{ 'list' }}", "2"]}, - "target": {"area_id": "test-area-id", "entity_id": "will.be_overridden"}, + "target": { + "area_id": "test-area-id", + "entity_id": "will.be_overridden", + "floor_id": "test-floor-id", + }, } await service.async_call_from_config(hass, config) @@ -279,6 +311,7 @@ async def test_service_call(hass: HomeAssistant) -> None: "list": ["list", "2"], "entity_id": ["hello.world"], "area_id": ["test-area-id"], + "floor_id": ["test-floor-id"], } config = { @@ -287,6 +320,7 @@ async def test_service_call(hass: HomeAssistant) -> None: "area_id": ["area-42", "{{ 'area-51' }}"], "device_id": ["abcdef", "{{ 'fedcba' }}"], "entity_id": ["light.static", "{{ 'light.dynamic' }}"], + "floor_id": ["floor-first", "{{ 'floor-second' }}"], }, } @@ -297,6 +331,7 @@ async def test_service_call(hass: HomeAssistant) -> None: "area_id": ["area-42", "area-51"], "device_id": ["abcdef", "fedcba"], "entity_id": ["light.static", "light.dynamic"], + "floor_id": ["floor-first", "floor-second"], } config = { @@ -510,7 +545,9 @@ async def test_extract_entity_ids(hass: HomeAssistant) -> None: ) -async def test_extract_entity_ids_from_area(hass: HomeAssistant, area_mock) -> None: +async def test_extract_entity_ids_from_area( + hass: HomeAssistant, floor_area_mock +) -> None: """Test extract_entity_ids method with areas.""" call = ServiceCall("light", "turn_on", {"area_id": "own-area"}) @@ -541,7 +578,9 @@ async def test_extract_entity_ids_from_area(hass: HomeAssistant, area_mock) -> N ) -async def test_extract_entity_ids_from_devices(hass: HomeAssistant, area_mock) -> None: +async def test_extract_entity_ids_from_devices( + hass: HomeAssistant, floor_area_mock +) -> None: """Test extract_entity_ids method with devices.""" assert await service.async_extract_entity_ids( hass, ServiceCall("light", "turn_on", {"device_id": "device-no-area-id"}) @@ -564,6 +603,32 @@ async def test_extract_entity_ids_from_devices(hass: HomeAssistant, area_mock) - ) +@pytest.mark.usefixtures("floor_area_mock") +async def test_extract_entity_ids_from_floor(hass: HomeAssistant) -> None: + """Test extract_entity_ids method with floors.""" + call = ServiceCall("light", "turn_on", {"floor_id": "test-floor"}) + + assert { + "light.in_area", + "light.assigned_to_area", + } == await service.async_extract_entity_ids(hass, call) + + call = ServiceCall("light", "turn_on", {"floor_id": ["test-floor", "floor-a"]}) + + assert { + "light.in_area", + "light.assigned_to_area", + "light.in_area_a", + } == await service.async_extract_entity_ids(hass, call) + + assert ( + await service.async_extract_entity_ids( + hass, ServiceCall("light", "turn_on", {"floor_id": ENTITY_MATCH_NONE}) + ) + == set() + ) + + async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: """Test async_get_all_descriptions.""" group_config = {DOMAIN_GROUP: {}} @@ -1476,7 +1541,9 @@ async def test_extract_from_service_filter_out_non_existing_entities( ] -async def test_extract_from_service_area_id(hass: HomeAssistant, area_mock) -> None: +async def test_extract_from_service_area_id( + hass: HomeAssistant, floor_area_mock +) -> None: """Test the extraction using area ID as reference.""" entities = [ MockEntity(name="in_area", entity_id="light.in_area"), @@ -1522,12 +1589,14 @@ async def test_entity_service_call_warn_referenced( "area_id": "non-existent-area", "entity_id": "non.existent", "device_id": "non-existent-device", + "floor_id": "non-existent-floor", }, ) await service.entity_service_call(hass, {}, "", call) assert ( - "Referenced areas non-existent-area, devices non-existent-device, " - "entities non.existent are missing or not currently available" + "Referenced floors non-existent-floor, areas non-existent-area, " + "devices non-existent-device, entities non.existent are missing " + "or not currently available" ) in caplog.text @@ -1542,13 +1611,15 @@ async def test_async_extract_entities_warn_referenced( "area_id": "non-existent-area", "entity_id": "non.existent", "device_id": "non-existent-device", + "floor_id": "non-existent-floor", }, ) extracted = await service.async_extract_entities(hass, {}, call) assert len(extracted) == 0 assert ( - "Referenced areas non-existent-area, devices non-existent-device, " - "entities non.existent are missing or not currently available" + "Referenced floors non-existent-floor, areas non-existent-area, " + "devices non-existent-device, entities non.existent are missing " + "or not currently available" ) in caplog.text