Add floors to service target (#110850)

This commit is contained in:
Franck Nijhof 2024-03-14 19:02:23 +01:00 committed by GitHub
parent 20626947db
commit 2aadd643ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 21 deletions

View File

@ -510,6 +510,9 @@ ATTR_AREA_ID: Final = "area_id"
# Contains one string, the device ID # Contains one string, the device ID
ATTR_DEVICE_ID: Final = "device_id" ATTR_DEVICE_ID: Final = "device_id"
# Contains one string or a list of strings, each being an floor id
ATTR_FLOOR_ID: Final = "floor_id"
# String with a friendly name for the entity # String with a friendly name for the entity
ATTR_FRIENDLY_NAME: Final = "friendly_name" ATTR_FRIENDLY_NAME: Final = "friendly_name"

View File

@ -29,6 +29,7 @@ from homeassistant.const import (
ATTR_AREA_ID, ATTR_AREA_ID,
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
CONF_ABOVE, CONF_ABOVE,
CONF_ALIAS, CONF_ALIAS,
CONF_ATTRIBUTE, CONF_ATTRIBUTE,
@ -1216,6 +1217,9 @@ ENTITY_SERVICE_FIELDS = {
vol.Optional(ATTR_AREA_ID): vol.Any( vol.Optional(ATTR_AREA_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
), ),
vol.Optional(ATTR_FLOOR_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
} }
TARGET_SERVICE_FIELDS = { TARGET_SERVICE_FIELDS = {
@ -1233,6 +1237,9 @@ TARGET_SERVICE_FIELDS = {
vol.Optional(ATTR_AREA_ID): vol.Any( vol.Optional(ATTR_AREA_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)]) ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
), ),
vol.Optional(ATTR_FLOOR_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
} }

View File

@ -18,6 +18,7 @@ from homeassistant.const import (
ATTR_AREA_ID, ATTR_AREA_ID,
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_SERVICE, CONF_SERVICE,
CONF_SERVICE_DATA, CONF_SERVICE_DATA,
@ -53,6 +54,7 @@ from . import (
config_validation as cv, config_validation as cv,
device_registry, device_registry,
entity_registry, entity_registry,
floor_registry,
template, template,
translation, translation,
) )
@ -194,7 +196,7 @@ class ServiceParams(TypedDict):
class ServiceTargetSelector: class ServiceTargetSelector:
"""Class to hold a target selector for a service.""" """Class to hold a target selector for a service."""
__slots__ = ("entity_ids", "device_ids", "area_ids") __slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids")
def __init__(self, service_call: ServiceCall) -> None: def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data.""" """Extract ids from service call data."""
@ -202,6 +204,7 @@ class ServiceTargetSelector:
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID) entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID) device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID) area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
self.entity_ids = ( self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
@ -210,11 +213,16 @@ class ServiceTargetSelector:
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
) )
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
@property @property
def has_any_selector(self) -> bool: def has_any_selector(self) -> bool:
"""Determine if any selectors are present.""" """Determine if any selectors are present."""
return bool(self.entity_ids or self.device_ids or self.area_ids) return bool(
self.entity_ids or self.device_ids or self.area_ids or self.floor_ids
)
@dataclasses.dataclass(slots=True) @dataclasses.dataclass(slots=True)
@ -224,21 +232,24 @@ class SelectedEntities:
# Entities that were explicitly mentioned. # Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set) referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area ID. # Entities that were referenced via device/area/floor ID.
# Should not trigger a warning when they don't exist. # Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set) indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found. # Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set) missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
# Referenced devices # Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set) referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None: def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items.""" """Log about missing items."""
parts = [] parts = []
for label, items in ( for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas), ("areas", self.missing_areas),
("devices", self.missing_devices), ("devices", self.missing_devices),
("entities", missing_entities), ("entities", missing_entities),
@ -472,37 +483,49 @@ def async_extract_referenced_entity_ids(
selected.referenced.update(entity_ids) selected.referenced.update(entity_ids)
if not selector.device_ids and not selector.area_ids: if not selector.device_ids and not selector.area_ids and not selector.floor_ids:
return selected return selected
ent_reg = entity_registry.async_get(hass) ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass) area_reg = area_registry.async_get(hass)
floor_reg = floor_registry.async_get(hass)
for device_id in selector.device_ids: for floor_id in selector.floor_ids:
if device_id not in dev_reg.devices: if floor_id not in floor_reg.floors:
selected.missing_devices.add(device_id) selected.missing_floors.add(floor_id)
for area_id in selector.area_ids: for area_id in selector.area_ids:
if area_id not in area_reg.areas: if area_id not in area_reg.areas:
selected.missing_areas.add(area_id) selected.missing_areas.add(area_id)
for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
if area_entry.id and area_entry.floor_id in selector.floor_ids:
selected.referenced_areas.add(area_entry.id)
# Find devices for targeted areas # Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids) selected.referenced_devices.update(selector.device_ids)
if selector.area_ids: selected.referenced_areas.update(selector.area_ids)
if selected.referenced_areas:
for device_entry in dev_reg.devices.values(): for device_entry in dev_reg.devices.values():
if device_entry.area_id in selector.area_ids: if device_entry.area_id in selected.referenced_areas:
selected.referenced_devices.add(device_entry.id) selected.referenced_devices.add(device_entry.id)
if not selector.area_ids and not selected.referenced_devices: if not selected.referenced_areas and not selected.referenced_devices:
return selected return selected
entities = ent_reg.entities entities = ent_reg.entities
# Add indirectly referenced by area # Add indirectly referenced by area
selected.indirectly_referenced.update( selected.indirectly_referenced.update(
entry.entity_id entry.entity_id
for area_id in selector.area_ids for area_id in selected.referenced_areas
# The entity's area matches a targeted area # The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id) for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config # Do not add entities which are hidden or which are config

View File

@ -32,6 +32,7 @@ from homeassistant.core import (
SupportsResponse, SupportsResponse,
) )
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar,
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
service, service,
@ -45,6 +46,7 @@ from tests.common import (
MockEntity, MockEntity,
MockUser, MockUser,
async_mock_service, async_mock_service,
mock_area_registry,
mock_device_registry, mock_device_registry,
mock_registry, mock_registry,
) )
@ -102,12 +104,38 @@ def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]:
@pytest.fixture @pytest.fixture
def area_mock(hass): def floor_area_mock(hass: HomeAssistant) -> None:
"""Mock including area info.""" """Mock including floor and area info."""
hass.states.async_set("light.Bowl", STATE_ON) hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF) hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF) hass.states.async_set("light.Kitchen", STATE_OFF)
area_in_floor = ar.AreaEntry(
id="test-area",
name="Test area",
aliases={},
normalized_name="test-area",
floor_id="test-floor",
icon=None,
picture=None,
)
area_in_floor_a = ar.AreaEntry(
id="area-a",
name="Area A",
aliases={},
normalized_name="area-a",
floor_id="floor-a",
icon=None,
picture=None,
)
mock_area_registry(
hass,
{
area_in_floor.id: area_in_floor,
area_in_floor_a.id: area_in_floor_a,
},
)
device_in_area = dr.DeviceEntry(area_id="test-area") device_in_area = dr.DeviceEntry(area_id="test-area")
device_no_area = dr.DeviceEntry(id="device-no-area-id") device_no_area = dr.DeviceEntry(id="device-no-area-id")
device_diff_area = dr.DeviceEntry(area_id="diff-area") device_diff_area = dr.DeviceEntry(area_id="diff-area")
@ -264,7 +292,11 @@ async def test_service_call(hass: HomeAssistant) -> None:
"effect": {"value": "{{ 'complex' }}", "simple": "simple"}, "effect": {"value": "{{ 'complex' }}", "simple": "simple"},
}, },
"data_template": {"list": ["{{ 'list' }}", "2"]}, "data_template": {"list": ["{{ 'list' }}", "2"]},
"target": {"area_id": "test-area-id", "entity_id": "will.be_overridden"}, "target": {
"area_id": "test-area-id",
"entity_id": "will.be_overridden",
"floor_id": "test-floor-id",
},
} }
await service.async_call_from_config(hass, config) await service.async_call_from_config(hass, config)
@ -279,6 +311,7 @@ async def test_service_call(hass: HomeAssistant) -> None:
"list": ["list", "2"], "list": ["list", "2"],
"entity_id": ["hello.world"], "entity_id": ["hello.world"],
"area_id": ["test-area-id"], "area_id": ["test-area-id"],
"floor_id": ["test-floor-id"],
} }
config = { config = {
@ -287,6 +320,7 @@ async def test_service_call(hass: HomeAssistant) -> None:
"area_id": ["area-42", "{{ 'area-51' }}"], "area_id": ["area-42", "{{ 'area-51' }}"],
"device_id": ["abcdef", "{{ 'fedcba' }}"], "device_id": ["abcdef", "{{ 'fedcba' }}"],
"entity_id": ["light.static", "{{ 'light.dynamic' }}"], "entity_id": ["light.static", "{{ 'light.dynamic' }}"],
"floor_id": ["floor-first", "{{ 'floor-second' }}"],
}, },
} }
@ -297,6 +331,7 @@ async def test_service_call(hass: HomeAssistant) -> None:
"area_id": ["area-42", "area-51"], "area_id": ["area-42", "area-51"],
"device_id": ["abcdef", "fedcba"], "device_id": ["abcdef", "fedcba"],
"entity_id": ["light.static", "light.dynamic"], "entity_id": ["light.static", "light.dynamic"],
"floor_id": ["floor-first", "floor-second"],
} }
config = { config = {
@ -510,7 +545,9 @@ async def test_extract_entity_ids(hass: HomeAssistant) -> None:
) )
async def test_extract_entity_ids_from_area(hass: HomeAssistant, area_mock) -> None: async def test_extract_entity_ids_from_area(
hass: HomeAssistant, floor_area_mock
) -> None:
"""Test extract_entity_ids method with areas.""" """Test extract_entity_ids method with areas."""
call = ServiceCall("light", "turn_on", {"area_id": "own-area"}) call = ServiceCall("light", "turn_on", {"area_id": "own-area"})
@ -541,7 +578,9 @@ async def test_extract_entity_ids_from_area(hass: HomeAssistant, area_mock) -> N
) )
async def test_extract_entity_ids_from_devices(hass: HomeAssistant, area_mock) -> None: async def test_extract_entity_ids_from_devices(
hass: HomeAssistant, floor_area_mock
) -> None:
"""Test extract_entity_ids method with devices.""" """Test extract_entity_ids method with devices."""
assert await service.async_extract_entity_ids( assert await service.async_extract_entity_ids(
hass, ServiceCall("light", "turn_on", {"device_id": "device-no-area-id"}) hass, ServiceCall("light", "turn_on", {"device_id": "device-no-area-id"})
@ -564,6 +603,32 @@ async def test_extract_entity_ids_from_devices(hass: HomeAssistant, area_mock) -
) )
@pytest.mark.usefixtures("floor_area_mock")
async def test_extract_entity_ids_from_floor(hass: HomeAssistant) -> None:
"""Test extract_entity_ids method with floors."""
call = ServiceCall("light", "turn_on", {"floor_id": "test-floor"})
assert {
"light.in_area",
"light.assigned_to_area",
} == await service.async_extract_entity_ids(hass, call)
call = ServiceCall("light", "turn_on", {"floor_id": ["test-floor", "floor-a"]})
assert {
"light.in_area",
"light.assigned_to_area",
"light.in_area_a",
} == await service.async_extract_entity_ids(hass, call)
assert (
await service.async_extract_entity_ids(
hass, ServiceCall("light", "turn_on", {"floor_id": ENTITY_MATCH_NONE})
)
== set()
)
async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
"""Test async_get_all_descriptions.""" """Test async_get_all_descriptions."""
group_config = {DOMAIN_GROUP: {}} group_config = {DOMAIN_GROUP: {}}
@ -1476,7 +1541,9 @@ async def test_extract_from_service_filter_out_non_existing_entities(
] ]
async def test_extract_from_service_area_id(hass: HomeAssistant, area_mock) -> None: async def test_extract_from_service_area_id(
hass: HomeAssistant, floor_area_mock
) -> None:
"""Test the extraction using area ID as reference.""" """Test the extraction using area ID as reference."""
entities = [ entities = [
MockEntity(name="in_area", entity_id="light.in_area"), MockEntity(name="in_area", entity_id="light.in_area"),
@ -1522,12 +1589,14 @@ async def test_entity_service_call_warn_referenced(
"area_id": "non-existent-area", "area_id": "non-existent-area",
"entity_id": "non.existent", "entity_id": "non.existent",
"device_id": "non-existent-device", "device_id": "non-existent-device",
"floor_id": "non-existent-floor",
}, },
) )
await service.entity_service_call(hass, {}, "", call) await service.entity_service_call(hass, {}, "", call)
assert ( assert (
"Referenced areas non-existent-area, devices non-existent-device, " "Referenced floors non-existent-floor, areas non-existent-area, "
"entities non.existent are missing or not currently available" "devices non-existent-device, entities non.existent are missing "
"or not currently available"
) in caplog.text ) in caplog.text
@ -1542,13 +1611,15 @@ async def test_async_extract_entities_warn_referenced(
"area_id": "non-existent-area", "area_id": "non-existent-area",
"entity_id": "non.existent", "entity_id": "non.existent",
"device_id": "non-existent-device", "device_id": "non-existent-device",
"floor_id": "non-existent-floor",
}, },
) )
extracted = await service.async_extract_entities(hass, {}, call) extracted = await service.async_extract_entities(hass, {}, call)
assert len(extracted) == 0 assert len(extracted) == 0
assert ( assert (
"Referenced areas non-existent-area, devices non-existent-device, " "Referenced floors non-existent-floor, areas non-existent-area, "
"entities non.existent are missing or not currently available" "devices non-existent-device, entities non.existent are missing "
"or not currently available"
) in caplog.text ) in caplog.text