Move target selector extractor method to common module (#148087)

This commit is contained in:
Abílio Costa 2025-07-07 13:48:48 +01:00 committed by GitHub
parent c60e06d32f
commit b71bcb002b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 855 additions and 228 deletions

View File

@ -44,11 +44,14 @@ from homeassistant.helpers.entity_component import async_update_entity
from homeassistant.helpers.issue_registry import IssueSeverity
from homeassistant.helpers.service import (
async_extract_config_entry_ids,
async_extract_referenced_entity_ids,
async_register_admin_service,
)
from homeassistant.helpers.signal import KEY_HA_STOP
from homeassistant.helpers.system_info import async_get_system_info
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.helpers.template import async_load_custom_templates
from homeassistant.helpers.typing import ConfigType
@ -111,7 +114,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
async def async_handle_turn_service(service: ServiceCall) -> None:
"""Handle calls to homeassistant.turn_on/off."""
referenced = async_extract_referenced_entity_ids(hass, service)
referenced = async_extract_referenced_entity_ids(
hass, TargetSelectorData(service.data)
)
all_referenced = referenced.referenced | referenced.indirectly_referenced
# Generic turn on/off method requires entity id

View File

@ -75,11 +75,12 @@ from homeassistant.helpers.entityfilter import (
EntityFilter,
)
from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import (
async_extract_referenced_entity_ids,
async_register_admin_service,
)
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import create_eager_task
@ -482,7 +483,9 @@ def _async_register_events_and_services(hass: HomeAssistant) -> None:
async def async_handle_homekit_unpair(service: ServiceCall) -> None:
"""Handle unpair HomeKit service call."""
referenced = async_extract_referenced_entity_ids(hass, service)
referenced = async_extract_referenced_entity_ids(
hass, TargetSelectorData(service.data)
)
dev_reg = dr.async_get(hass)
for device_id in referenced.referenced_devices:
if not (dev_reg_ent := dev_reg.async_get(device_id)):

View File

@ -28,7 +28,10 @@ from homeassistant.components.light import (
from homeassistant.const import ATTR_MODE
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_extract_referenced_entity_ids
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from .const import _ATTR_COLOR_TEMP, ATTR_THEME, DOMAIN
from .coordinator import LIFXConfigEntry, LIFXUpdateCoordinator
@ -268,7 +271,9 @@ class LIFXManager:
async def service_handler(service: ServiceCall) -> None:
"""Apply a service, i.e. start an effect."""
referenced = async_extract_referenced_entity_ids(self.hass, service)
referenced = async_extract_referenced_entity_ids(
self.hass, TargetSelectorData(service.data)
)
all_referenced = referenced.referenced | referenced.indirectly_referenced
if all_referenced:
await self.start_effect(all_referenced, service.service, **service.data)
@ -499,6 +504,5 @@ class LIFXManager:
if self.entry_id_to_entity_id[entry.entry_id] in entity_ids:
coordinators.append(entry.runtime_data)
bulbs.append(entry.runtime_data.device)
if start_effect_func := self._effect_dispatch.get(service):
await start_effect_func(self, bulbs, coordinators, **kwargs)

View File

@ -26,7 +26,10 @@ from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.service import async_extract_referenced_entity_ids
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.util.json import JsonValueType
from homeassistant.util.read_only_dict import ReadOnlyDict
@ -115,7 +118,7 @@ def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiCl
@callback
def _async_get_ufp_camera(call: ServiceCall) -> Camera:
ref = async_extract_referenced_entity_ids(call.hass, call)
ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data))
entity_registry = er.async_get(call.hass)
entity_id = ref.indirectly_referenced.pop()
@ -133,7 +136,7 @@ def _async_get_protect_from_call(call: ServiceCall) -> set[ProtectApiClient]:
return {
_async_get_ufp_instance(call.hass, device_id)
for device_id in async_extract_referenced_entity_ids(
call.hass, call
call.hass, TargetSelectorData(call.data)
).referenced_devices
}
@ -196,7 +199,7 @@ def _async_unique_id_to_mac(unique_id: str) -> str:
async def set_chime_paired_doorbells(call: ServiceCall) -> None:
"""Set paired doorbells on chime."""
ref = async_extract_referenced_entity_ids(call.hass, call)
ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data))
entity_registry = er.async_get(call.hass)
entity_id = ref.indirectly_referenced.pop()
@ -211,7 +214,9 @@ async def set_chime_paired_doorbells(call: ServiceCall) -> None:
assert chime is not None
call.data = ReadOnlyDict(call.data.get("doorbells") or {})
doorbell_refs = async_extract_referenced_entity_ids(call.hass, call)
doorbell_refs = async_extract_referenced_entity_ids(
call.hass, TargetSelectorData(call.data)
)
doorbell_ids: set[str] = set()
for camera_id in doorbell_refs.referenced | doorbell_refs.indirectly_referenced:
doorbell_sensor = entity_registry.async_get(camera_id)

View File

@ -9,17 +9,13 @@ from enum import Enum
from functools import cache, partial
import logging
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, TypeGuard, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast, override
import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ACTION,
CONF_ENTITY_ID,
CONF_SERVICE_DATA,
@ -54,16 +50,14 @@ from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
from . import (
area_registry,
config_validation as cv,
device_registry,
entity_registry,
floor_registry,
label_registry,
target as target_helpers,
template,
translation,
)
from .group import expand_entity_ids
from .deprecation import deprecated_class, deprecated_function
from .selector import TargetSelector
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
@ -225,87 +219,31 @@ class ServiceParams(TypedDict):
target: dict | None
class ServiceTargetSelector:
@deprecated_class(
"homeassistant.helpers.target.TargetSelectorData",
breaks_in_ha_version="2026.8",
)
class ServiceTargetSelector(target_helpers.TargetSelectorData):
"""Class to hold a target selector for a service."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data."""
service_call_data = service_call.data
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
super().__init__(service_call.data)
@dataclasses.dataclass(slots=True)
class SelectedEntities:
@deprecated_class(
"homeassistant.helpers.target.SelectedEntities",
breaks_in_ha_version="2026.8",
)
class SelectedEntities(target_helpers.SelectedEntities):
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
@override
def log_missing(
self, missing_entities: set[str], logger: logging.Logger | None = None
) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
super().log_missing(missing_entities, logger or _LOGGER)
@bind_hass
@ -466,7 +404,10 @@ async def async_extract_entities[_EntityT: Entity](
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available]
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group)
selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
combined = referenced.referenced | referenced.indirectly_referenced
found = []
@ -482,7 +423,7 @@ async def async_extract_entities[_EntityT: Entity](
found.append(entity)
referenced.log_missing(referenced.referenced & combined)
referenced.log_missing(referenced.referenced & combined, _LOGGER)
return found
@ -495,141 +436,27 @@ async def async_extract_entity_ids(
Will convert group entity ids to the entity ids it represents.
"""
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group)
selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
return referenced.referenced | referenced.indirectly_referenced
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
@deprecated_function(
"homeassistant.helpers.target.async_extract_referenced_entity_ids",
breaks_in_ha_version="2026.8",
)
@bind_hass
def async_extract_referenced_entity_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
selector = ServiceTargetSelector(service_call)
selected = SelectedEntities()
if not selector.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector.entity_ids
if expand_group:
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector.device_ids
and not selector.area_ids
and not selector.floor_ids
and not selector.label_ids
):
return selected
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
if selector.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector.area_ids)
selected.referenced_devices.update(selector.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
selector_data = target_helpers.TargetSelectorData(service_call.data)
selected = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected
return SelectedEntities(**dataclasses.asdict(selected))
@bind_hass
@ -637,7 +464,10 @@ async def async_extract_config_entry_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set[str]:
"""Extract referenced config entry ids from a service call."""
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group)
selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)
config_entry_ids: set[str] = set()
@ -948,11 +778,14 @@ async def entity_service_call(
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if target_all_entities:
referenced: SelectedEntities | None = None
referenced: target_helpers.SelectedEntities | None = None
all_referenced: set[str] | None = None
else:
# A set of entities we're trying to target.
referenced = async_extract_referenced_entity_ids(hass, call, True)
selector_data = target_helpers.TargetSelectorData(call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, True
)
all_referenced = referenced.referenced | referenced.indirectly_referenced
# If the service function is a string, we'll pass it the service call data
@ -977,7 +810,7 @@ async def entity_service_call(
missing = referenced.referenced.copy()
for entity in entity_candidates:
missing.discard(entity.entity_id)
referenced.log_missing(missing)
referenced.log_missing(missing, _LOGGER)
entities: list[Entity] = []
for entity in entity_candidates:

View File

@ -0,0 +1,240 @@
"""Helpers for dealing with entity targets."""
from __future__ import annotations
import dataclasses
from logging import Logger
from typing import TypeGuard
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant
from . import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
group,
label_registry as lr,
)
from .typing import ConfigType
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
class TargetSelectorData:
"""Class to hold data of target selector."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, config: ConfigType) -> None:
"""Extract ids from the config."""
entity_ids: str | list | None = config.get(ATTR_ENTITY_ID)
device_ids: str | list | None = config.get(ATTR_DEVICE_ID)
area_ids: str | list | None = config.get(ATTR_AREA_ID)
floor_ids: str | list | None = config.get(ATTR_FLOOR_ID)
label_ids: str | list | None = config.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str], logger: Logger) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
logger.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
def async_extract_referenced_entity_ids(
hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a target selector."""
selected = SelectedEntities()
if not selector_data.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector_data.entity_ids
if expand_group:
entity_ids = group.expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector_data.device_ids
and not selector_data.area_ids
and not selector_data.floor_ids
and not selector_data.label_ids
):
return selected
entities = er.async_get(hass).entities
dev_reg = dr.async_get(hass)
area_reg = ar.async_get(hass)
if selector_data.floor_ids:
floor_reg = fr.async_get(hass)
for floor_id in selector_data.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector_data.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector_data.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector_data.label_ids:
label_reg = lr.async_get(hass)
for label_id in selector_data.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector_data.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector_data.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector_data.area_ids)
selected.referenced_devices.update(selector_data.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected

View File

@ -3,6 +3,7 @@
import asyncio
from collections.abc import Iterable
from copy import deepcopy
import dataclasses
import io
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
@ -2322,3 +2323,80 @@ async def test_reload_service_helper(hass: HomeAssistant) -> None:
]
await asyncio.gather(*tasks)
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
async def test_deprecated_service_target_selector_class(hass: HomeAssistant) -> None:
"""Test that the deprecated ServiceTargetSelector class forwards correctly."""
call = ServiceCall(
hass,
"test",
"test",
{
"entity_id": ["light.test", "switch.test"],
"area_id": "kitchen",
"device_id": ["device1", "device2"],
"floor_id": "first_floor",
"label_id": ["label1", "label2"],
},
)
selector = service.ServiceTargetSelector(call)
assert selector.entity_ids == {"light.test", "switch.test"}
assert selector.area_ids == {"kitchen"}
assert selector.device_ids == {"device1", "device2"}
assert selector.floor_ids == {"first_floor"}
assert selector.label_ids == {"label1", "label2"}
assert selector.has_any_selector is True
async def test_deprecated_selected_entities_class(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that the deprecated SelectedEntities class forwards correctly."""
selected = service.SelectedEntities(
referenced={"entity.test"},
indirectly_referenced=set(),
referenced_devices=set(),
referenced_areas=set(),
missing_devices={"missing_device"},
missing_areas={"missing_area"},
missing_floors={"missing_floor"},
missing_labels={"missing_label"},
)
missing_entities = {"entity.missing"}
selected.log_missing(missing_entities)
assert (
"Referenced floors missing_floor, areas missing_area, "
"devices missing_device, entities entity.missing, "
"labels missing_label are missing or not currently available" in caplog.text
)
async def test_deprecated_async_extract_referenced_entity_ids(
hass: HomeAssistant,
) -> None:
"""Test that the deprecated async_extract_referenced_entity_ids function forwards correctly."""
from homeassistant.helpers import target # noqa: PLC0415
mock_selected = target.SelectedEntities(
referenced={"entity.test"},
indirectly_referenced={"entity.indirect"},
)
with patch(
"homeassistant.helpers.target.async_extract_referenced_entity_ids",
return_value=mock_selected,
) as mock_target_func:
call = ServiceCall(hass, "test", "test", {"entity_id": "light.test"})
result = service.async_extract_referenced_entity_ids(
hass, call, expand_group=False
)
# Verify target helper was called with correct parameters
mock_target_func.assert_called_once()
args = mock_target_func.call_args
assert args[0][0] is hass
assert args[0][1].entity_ids == {"light.test"}
assert args[0][2] is False
assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected)

View File

@ -0,0 +1,459 @@
"""Test service helpers."""
import pytest
# TODO(abmantis): is this import needed?
# To prevent circular import when running just this file
import homeassistant.components # noqa: F401
from homeassistant.components.group import Group
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
STATE_OFF,
STATE_ON,
EntityCategory,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
target,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component
from tests.common import (
RegistryEntryWithDefaults,
mock_area_registry,
mock_device_registry,
mock_registry,
)
@pytest.fixture
def registries_mock(hass: HomeAssistant) -> None:
"""Mock including floor and area info."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
area_in_floor = ar.AreaEntry(
id="test-area",
name="Test area",
aliases={},
floor_id="test-floor",
icon=None,
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
area_in_floor_a = ar.AreaEntry(
id="area-a",
name="Area A",
aliases={},
floor_id="floor-a",
icon=None,
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
area_with_labels = ar.AreaEntry(
id="area-with-labels",
name="Area with labels",
aliases={},
floor_id=None,
icon=None,
labels={"label_area"},
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
mock_area_registry(
hass,
{
area_in_floor.id: area_in_floor,
area_in_floor_a.id: area_in_floor_a,
area_with_labels.id: area_with_labels,
},
)
device_in_area = dr.DeviceEntry(id="device-test-area", area_id="test-area")
device_no_area = dr.DeviceEntry(id="device-no-area-id")
device_diff_area = dr.DeviceEntry(id="device-diff-area", area_id="diff-area")
device_area_a = dr.DeviceEntry(id="device-area-a-id", area_id="area-a")
device_has_label1 = dr.DeviceEntry(id="device-has-label1-id", labels={"label1"})
device_has_label2 = dr.DeviceEntry(id="device-has-label2-id", labels={"label2"})
device_has_labels = dr.DeviceEntry(
id="device-has-labels-id",
labels={"label1", "label2"},
area_id=area_with_labels.id,
)
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
device_area_a.id: device_area_a,
device_has_label1.id: device_has_label1,
device_has_label2.id: device_has_label2,
device_has_labels.id: device_has_labels,
},
)
entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.in_own_area",
unique_id="in-own-area-id",
platform="test",
area_id="own-area",
)
config_entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.config_in_own_area",
unique_id="config-in-own-area-id",
platform="test",
area_id="own-area",
entity_category=EntityCategory.CONFIG,
)
hidden_entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.hidden_in_own_area",
unique_id="hidden-in-own-area-id",
platform="test",
area_id="own-area",
hidden_by=er.RegistryEntryHider.USER,
)
entity_in_area = RegistryEntryWithDefaults(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
config_entity_in_area = RegistryEntryWithDefaults(
entity_id="light.config_in_area",
unique_id="config-in-area-id",
platform="test",
device_id=device_in_area.id,
entity_category=EntityCategory.CONFIG,
)
hidden_entity_in_area = RegistryEntryWithDefaults(
entity_id="light.hidden_in_area",
unique_id="hidden-in-area-id",
platform="test",
device_id=device_in_area.id,
hidden_by=er.RegistryEntryHider.USER,
)
entity_in_other_area = RegistryEntryWithDefaults(
entity_id="light.in_other_area",
unique_id="in-area-a-id",
platform="test",
device_id=device_in_area.id,
area_id="other-area",
)
entity_assigned_to_area = RegistryEntryWithDefaults(
entity_id="light.assigned_to_area",
unique_id="assigned-area-id",
platform="test",
device_id=device_in_area.id,
area_id="test-area",
)
entity_no_area = RegistryEntryWithDefaults(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
config_entity_no_area = RegistryEntryWithDefaults(
entity_id="light.config_no_area",
unique_id="config-no-area-id",
platform="test",
device_id=device_no_area.id,
entity_category=EntityCategory.CONFIG,
)
hidden_entity_no_area = RegistryEntryWithDefaults(
entity_id="light.hidden_no_area",
unique_id="hidden-no-area-id",
platform="test",
device_id=device_no_area.id,
hidden_by=er.RegistryEntryHider.USER,
)
entity_diff_area = RegistryEntryWithDefaults(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
entity_in_area_a = RegistryEntryWithDefaults(
entity_id="light.in_area_a",
unique_id="in-area-a-id",
platform="test",
device_id=device_area_a.id,
area_id="area-a",
)
entity_in_area_b = RegistryEntryWithDefaults(
entity_id="light.in_area_b",
unique_id="in-area-b-id",
platform="test",
device_id=device_area_a.id,
area_id="area-b",
)
entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.with_my_label",
unique_id="with_my_label",
platform="test",
labels={"my-label"},
)
hidden_entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.hidden_with_my_label",
unique_id="hidden_with_my_label",
platform="test",
labels={"my-label"},
hidden_by=er.RegistryEntryHider.USER,
)
config_entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.config_with_my_label",
unique_id="config_with_my_label",
platform="test",
labels={"my-label"},
entity_category=EntityCategory.CONFIG,
)
entity_with_label1_from_device = RegistryEntryWithDefaults(
entity_id="light.with_label1_from_device",
unique_id="with_label1_from_device",
platform="test",
device_id=device_has_label1.id,
)
entity_with_label1_from_device_and_different_area = RegistryEntryWithDefaults(
entity_id="light.with_label1_from_device_diff_area",
unique_id="with_label1_from_device_diff_area",
platform="test",
device_id=device_has_label1.id,
area_id=area_in_floor_a.id,
)
entity_with_label1_and_label2_from_device = RegistryEntryWithDefaults(
entity_id="light.with_label1_and_label2_from_device",
unique_id="with_label1_and_label2_from_device",
platform="test",
labels={"label1"},
device_id=device_has_label2.id,
)
entity_with_labels_from_device = RegistryEntryWithDefaults(
entity_id="light.with_labels_from_device",
unique_id="with_labels_from_device",
platform="test",
device_id=device_has_labels.id,
)
mock_registry(
hass,
{
entity_in_own_area.entity_id: entity_in_own_area,
config_entity_in_own_area.entity_id: config_entity_in_own_area,
hidden_entity_in_own_area.entity_id: hidden_entity_in_own_area,
entity_in_area.entity_id: entity_in_area,
config_entity_in_area.entity_id: config_entity_in_area,
hidden_entity_in_area.entity_id: hidden_entity_in_area,
entity_in_other_area.entity_id: entity_in_other_area,
entity_assigned_to_area.entity_id: entity_assigned_to_area,
entity_no_area.entity_id: entity_no_area,
config_entity_no_area.entity_id: config_entity_no_area,
hidden_entity_no_area.entity_id: hidden_entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
entity_in_area_a.entity_id: entity_in_area_a,
entity_in_area_b.entity_id: entity_in_area_b,
config_entity_with_my_label.entity_id: config_entity_with_my_label,
entity_with_label1_and_label2_from_device.entity_id: entity_with_label1_and_label2_from_device,
entity_with_label1_from_device.entity_id: entity_with_label1_from_device,
entity_with_label1_from_device_and_different_area.entity_id: entity_with_label1_from_device_and_different_area,
entity_with_labels_from_device.entity_id: entity_with_labels_from_device,
entity_with_my_label.entity_id: entity_with_my_label,
hidden_entity_with_my_label.entity_id: hidden_entity_with_my_label,
},
)
@pytest.mark.parametrize(
("selector_config", "expand_group", "expected_selected"),
[
(
{
ATTR_ENTITY_ID: ENTITY_MATCH_NONE,
ATTR_AREA_ID: ENTITY_MATCH_NONE,
ATTR_FLOOR_ID: ENTITY_MATCH_NONE,
ATTR_LABEL_ID: ENTITY_MATCH_NONE,
},
False,
target.SelectedEntities(),
),
(
{ATTR_ENTITY_ID: "light.bowl"},
False,
target.SelectedEntities(referenced={"light.bowl"}),
),
(
{ATTR_ENTITY_ID: "group.test"},
True,
target.SelectedEntities(referenced={"light.ceiling", "light.kitchen"}),
),
(
{ATTR_ENTITY_ID: "group.test"},
False,
target.SelectedEntities(referenced={"group.test"}),
),
(
{ATTR_AREA_ID: "own-area"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_own_area"},
referenced_areas={"own-area"},
missing_areas={"own-area"},
),
),
(
{ATTR_AREA_ID: "test-area"},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.assigned_to_area",
},
referenced_areas={"test-area"},
referenced_devices={"device-test-area"},
),
),
(
{ATTR_AREA_ID: ["test-area", "diff-area"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.diff_area",
"light.assigned_to_area",
},
referenced_areas={"test-area", "diff-area"},
referenced_devices={"device-diff-area", "device-test-area"},
missing_areas={"diff-area"},
),
),
(
{ATTR_DEVICE_ID: "device-no-area-id"},
False,
target.SelectedEntities(
indirectly_referenced={"light.no_area"},
referenced_devices={"device-no-area-id"},
),
),
(
{ATTR_DEVICE_ID: "device-area-a-id"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_area_a", "light.in_area_b"},
referenced_devices={"device-area-a-id"},
),
),
(
{ATTR_FLOOR_ID: "test-floor"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_area", "light.assigned_to_area"},
referenced_devices={"device-test-area"},
referenced_areas={"test-area"},
missing_floors={"test-floor"},
),
),
(
{ATTR_FLOOR_ID: ["test-floor", "floor-a"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.assigned_to_area",
"light.in_area_a",
"light.with_label1_from_device_diff_area",
},
referenced_devices={"device-area-a-id", "device-test-area"},
referenced_areas={"area-a", "test-area"},
missing_floors={"floor-a", "test-floor"},
),
),
(
{ATTR_LABEL_ID: "my-label"},
False,
target.SelectedEntities(
indirectly_referenced={"light.with_my_label"},
missing_labels={"my-label"},
),
),
(
{ATTR_LABEL_ID: "label1"},
False,
target.SelectedEntities(
indirectly_referenced={
"light.with_label1_from_device",
"light.with_label1_from_device_diff_area",
"light.with_labels_from_device",
"light.with_label1_and_label2_from_device",
},
referenced_devices={"device-has-label1-id", "device-has-labels-id"},
missing_labels={"label1"},
),
),
(
{ATTR_LABEL_ID: ["label2"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.with_labels_from_device",
"light.with_label1_and_label2_from_device",
},
referenced_devices={"device-has-label2-id", "device-has-labels-id"},
missing_labels={"label2"},
),
),
(
{ATTR_LABEL_ID: ["label_area"]},
False,
target.SelectedEntities(
indirectly_referenced={"light.with_labels_from_device"},
referenced_devices={"device-has-labels-id"},
referenced_areas={"area-with-labels"},
missing_labels={"label_area"},
),
),
],
)
@pytest.mark.usefixtures("registries_mock")
async def test_extract_referenced_entity_ids(
hass: HomeAssistant,
selector_config: ConfigType,
expand_group: bool,
expected_selected: target.SelectedEntities,
) -> None:
"""Test extract_entity_ids method."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
assert await async_setup_component(hass, "group", {})
await hass.async_block_till_done()
await Group.async_create_group(
hass,
"test",
created_by_service=False,
entity_ids=["light.Ceiling", "light.Kitchen"],
icon=None,
mode=None,
object_id=None,
order=None,
)
target_data = target.TargetSelectorData(selector_config)
assert (
target.async_extract_referenced_entity_ids(
hass, target_data, expand_group=expand_group
)
== expected_selected
)