Move target selector extractor method to common module (#148087)

This commit is contained in:
Abílio Costa 2025-07-07 13:48:48 +01:00 committed by GitHub
parent c60e06d32f
commit b71bcb002b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 855 additions and 228 deletions

View File

@ -44,11 +44,14 @@ from homeassistant.helpers.entity_component import async_update_entity
from homeassistant.helpers.issue_registry import IssueSeverity from homeassistant.helpers.issue_registry import IssueSeverity
from homeassistant.helpers.service import ( from homeassistant.helpers.service import (
async_extract_config_entry_ids, async_extract_config_entry_ids,
async_extract_referenced_entity_ids,
async_register_admin_service, async_register_admin_service,
) )
from homeassistant.helpers.signal import KEY_HA_STOP from homeassistant.helpers.signal import KEY_HA_STOP
from homeassistant.helpers.system_info import async_get_system_info from homeassistant.helpers.system_info import async_get_system_info
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.helpers.template import async_load_custom_templates from homeassistant.helpers.template import async_load_custom_templates
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -111,7 +114,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
async def async_handle_turn_service(service: ServiceCall) -> None: async def async_handle_turn_service(service: ServiceCall) -> None:
"""Handle calls to homeassistant.turn_on/off.""" """Handle calls to homeassistant.turn_on/off."""
referenced = async_extract_referenced_entity_ids(hass, service) referenced = async_extract_referenced_entity_ids(
hass, TargetSelectorData(service.data)
)
all_referenced = referenced.referenced | referenced.indirectly_referenced all_referenced = referenced.referenced | referenced.indirectly_referenced
# Generic turn on/off method requires entity id # Generic turn on/off method requires entity id

View File

@ -75,11 +75,12 @@ from homeassistant.helpers.entityfilter import (
EntityFilter, EntityFilter,
) )
from homeassistant.helpers.reload import async_integration_yaml_config from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import ( from homeassistant.helpers.service import async_register_admin_service
async_extract_referenced_entity_ids,
async_register_admin_service,
)
from homeassistant.helpers.start import async_at_started from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
@ -482,7 +483,9 @@ def _async_register_events_and_services(hass: HomeAssistant) -> None:
async def async_handle_homekit_unpair(service: ServiceCall) -> None: async def async_handle_homekit_unpair(service: ServiceCall) -> None:
"""Handle unpair HomeKit service call.""" """Handle unpair HomeKit service call."""
referenced = async_extract_referenced_entity_ids(hass, service) referenced = async_extract_referenced_entity_ids(
hass, TargetSelectorData(service.data)
)
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
for device_id in referenced.referenced_devices: for device_id in referenced.referenced_devices:
if not (dev_reg_ent := dev_reg.async_get(device_id)): if not (dev_reg_ent := dev_reg.async_get(device_id)):

View File

@ -28,7 +28,10 @@ from homeassistant.components.light import (
from homeassistant.const import ATTR_MODE from homeassistant.const import ATTR_MODE
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_extract_referenced_entity_ids from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from .const import _ATTR_COLOR_TEMP, ATTR_THEME, DOMAIN from .const import _ATTR_COLOR_TEMP, ATTR_THEME, DOMAIN
from .coordinator import LIFXConfigEntry, LIFXUpdateCoordinator from .coordinator import LIFXConfigEntry, LIFXUpdateCoordinator
@ -268,7 +271,9 @@ class LIFXManager:
async def service_handler(service: ServiceCall) -> None: async def service_handler(service: ServiceCall) -> None:
"""Apply a service, i.e. start an effect.""" """Apply a service, i.e. start an effect."""
referenced = async_extract_referenced_entity_ids(self.hass, service) referenced = async_extract_referenced_entity_ids(
self.hass, TargetSelectorData(service.data)
)
all_referenced = referenced.referenced | referenced.indirectly_referenced all_referenced = referenced.referenced | referenced.indirectly_referenced
if all_referenced: if all_referenced:
await self.start_effect(all_referenced, service.service, **service.data) await self.start_effect(all_referenced, service.service, **service.data)
@ -499,6 +504,5 @@ class LIFXManager:
if self.entry_id_to_entity_id[entry.entry_id] in entity_ids: if self.entry_id_to_entity_id[entry.entry_id] in entity_ids:
coordinators.append(entry.runtime_data) coordinators.append(entry.runtime_data)
bulbs.append(entry.runtime_data.device) bulbs.append(entry.runtime_data.device)
if start_effect_func := self._effect_dispatch.get(service): if start_effect_func := self._effect_dispatch.get(service):
await start_effect_func(self, bulbs, coordinators, **kwargs) await start_effect_func(self, bulbs, coordinators, **kwargs)

View File

@ -26,7 +26,10 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.service import async_extract_referenced_entity_ids from homeassistant.helpers.target import (
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from homeassistant.util.json import JsonValueType from homeassistant.util.json import JsonValueType
from homeassistant.util.read_only_dict import ReadOnlyDict from homeassistant.util.read_only_dict import ReadOnlyDict
@ -115,7 +118,7 @@ def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiCl
@callback @callback
def _async_get_ufp_camera(call: ServiceCall) -> Camera: def _async_get_ufp_camera(call: ServiceCall) -> Camera:
ref = async_extract_referenced_entity_ids(call.hass, call) ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data))
entity_registry = er.async_get(call.hass) entity_registry = er.async_get(call.hass)
entity_id = ref.indirectly_referenced.pop() entity_id = ref.indirectly_referenced.pop()
@ -133,7 +136,7 @@ def _async_get_protect_from_call(call: ServiceCall) -> set[ProtectApiClient]:
return { return {
_async_get_ufp_instance(call.hass, device_id) _async_get_ufp_instance(call.hass, device_id)
for device_id in async_extract_referenced_entity_ids( for device_id in async_extract_referenced_entity_ids(
call.hass, call call.hass, TargetSelectorData(call.data)
).referenced_devices ).referenced_devices
} }
@ -196,7 +199,7 @@ def _async_unique_id_to_mac(unique_id: str) -> str:
async def set_chime_paired_doorbells(call: ServiceCall) -> None: async def set_chime_paired_doorbells(call: ServiceCall) -> None:
"""Set paired doorbells on chime.""" """Set paired doorbells on chime."""
ref = async_extract_referenced_entity_ids(call.hass, call) ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data))
entity_registry = er.async_get(call.hass) entity_registry = er.async_get(call.hass)
entity_id = ref.indirectly_referenced.pop() entity_id = ref.indirectly_referenced.pop()
@ -211,7 +214,9 @@ async def set_chime_paired_doorbells(call: ServiceCall) -> None:
assert chime is not None assert chime is not None
call.data = ReadOnlyDict(call.data.get("doorbells") or {}) call.data = ReadOnlyDict(call.data.get("doorbells") or {})
doorbell_refs = async_extract_referenced_entity_ids(call.hass, call) doorbell_refs = async_extract_referenced_entity_ids(
call.hass, TargetSelectorData(call.data)
)
doorbell_ids: set[str] = set() doorbell_ids: set[str] = set()
for camera_id in doorbell_refs.referenced | doorbell_refs.indirectly_referenced: for camera_id in doorbell_refs.referenced | doorbell_refs.indirectly_referenced:
doorbell_sensor = entity_registry.async_get(camera_id) doorbell_sensor = entity_registry.async_get(camera_id)

View File

@ -9,17 +9,13 @@ from enum import Enum
from functools import cache, partial from functools import cache, partial
import logging import logging
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, TypeGuard, cast from typing import TYPE_CHECKING, Any, TypedDict, cast, override
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import ( from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ACTION, CONF_ACTION,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_SERVICE_DATA, CONF_SERVICE_DATA,
@ -54,16 +50,14 @@ from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
from . import ( from . import (
area_registry,
config_validation as cv, config_validation as cv,
device_registry, device_registry,
entity_registry, entity_registry,
floor_registry, target as target_helpers,
label_registry,
template, template,
translation, translation,
) )
from .group import expand_entity_ids from .deprecation import deprecated_class, deprecated_function
from .selector import TargetSelector from .selector import TargetSelector
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
@ -225,87 +219,31 @@ class ServiceParams(TypedDict):
target: dict | None target: dict | None
class ServiceTargetSelector: @deprecated_class(
"homeassistant.helpers.target.TargetSelectorData",
breaks_in_ha_version="2026.8",
)
class ServiceTargetSelector(target_helpers.TargetSelectorData):
"""Class to hold a target selector for a service.""" """Class to hold a target selector for a service."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, service_call: ServiceCall) -> None: def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data.""" """Extract ids from service call data."""
service_call_data = service_call.data super().__init__(service_call.data)
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
@dataclasses.dataclass(slots=True) @deprecated_class(
class SelectedEntities: "homeassistant.helpers.target.SelectedEntities",
breaks_in_ha_version="2026.8",
)
class SelectedEntities(target_helpers.SelectedEntities):
"""Class to hold the selected entities.""" """Class to hold the selected entities."""
# Entities that were explicitly mentioned. @override
referenced: set[str] = dataclasses.field(default_factory=set) def log_missing(
self, missing_entities: set[str], logger: logging.Logger | None = None
# Entities that were referenced via device/area/floor/label ID. ) -> None:
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items.""" """Log about missing items."""
parts = [] super().log_missing(missing_entities, logger or _LOGGER)
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
@bind_hass @bind_hass
@ -466,7 +404,10 @@ async def async_extract_entities[_EntityT: Entity](
if data_ent_id == ENTITY_MATCH_ALL: if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available] return [entity for entity in entities if entity.available]
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
combined = referenced.referenced | referenced.indirectly_referenced combined = referenced.referenced | referenced.indirectly_referenced
found = [] found = []
@ -482,7 +423,7 @@ async def async_extract_entities[_EntityT: Entity](
found.append(entity) found.append(entity)
referenced.log_missing(referenced.referenced & combined) referenced.log_missing(referenced.referenced & combined, _LOGGER)
return found return found
@ -495,141 +436,27 @@ async def async_extract_entity_ids(
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
""" """
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
return referenced.referenced | referenced.indirectly_referenced return referenced.referenced | referenced.indirectly_referenced
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: @deprecated_function(
"""Check if ids can match anything.""" "homeassistant.helpers.target.async_extract_referenced_entity_ids",
return ids not in (None, ENTITY_MATCH_NONE) breaks_in_ha_version="2026.8",
)
@bind_hass @bind_hass
def async_extract_referenced_entity_ids( def async_extract_referenced_entity_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> SelectedEntities: ) -> SelectedEntities:
"""Extract referenced entity IDs from a service call.""" """Extract referenced entity IDs from a service call."""
selector = ServiceTargetSelector(service_call) selector_data = target_helpers.TargetSelectorData(service_call.data)
selected = SelectedEntities() selected = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
if not selector.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector.entity_ids
if expand_group:
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector.device_ids
and not selector.area_ids
and not selector.floor_ids
and not selector.label_ids
):
return selected
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
if selector.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
) )
return SelectedEntities(**dataclasses.asdict(selected))
selected.referenced_areas.update(selector.area_ids)
selected.referenced_devices.update(selector.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected
@bind_hass @bind_hass
@ -637,7 +464,10 @@ async def async_extract_config_entry_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set[str]: ) -> set[str]:
"""Extract referenced config entry ids from a service call.""" """Extract referenced config entry ids from a service call."""
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = target_helpers.TargetSelectorData(service_call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group
)
ent_reg = entity_registry.async_get(hass) ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
config_entry_ids: set[str] = set() config_entry_ids: set[str] = set()
@ -948,11 +778,14 @@ async def entity_service_call(
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if target_all_entities: if target_all_entities:
referenced: SelectedEntities | None = None referenced: target_helpers.SelectedEntities | None = None
all_referenced: set[str] | None = None all_referenced: set[str] | None = None
else: else:
# A set of entities we're trying to target. # A set of entities we're trying to target.
referenced = async_extract_referenced_entity_ids(hass, call, True) selector_data = target_helpers.TargetSelectorData(call.data)
referenced = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, True
)
all_referenced = referenced.referenced | referenced.indirectly_referenced all_referenced = referenced.referenced | referenced.indirectly_referenced
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
@ -977,7 +810,7 @@ async def entity_service_call(
missing = referenced.referenced.copy() missing = referenced.referenced.copy()
for entity in entity_candidates: for entity in entity_candidates:
missing.discard(entity.entity_id) missing.discard(entity.entity_id)
referenced.log_missing(missing) referenced.log_missing(missing, _LOGGER)
entities: list[Entity] = [] entities: list[Entity] = []
for entity in entity_candidates: for entity in entity_candidates:

View File

@ -0,0 +1,240 @@
"""Helpers for dealing with entity targets."""
from __future__ import annotations
import dataclasses
from logging import Logger
from typing import TypeGuard
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant
from . import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
group,
label_registry as lr,
)
from .typing import ConfigType
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
class TargetSelectorData:
"""Class to hold data of target selector."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, config: ConfigType) -> None:
"""Extract ids from the config."""
entity_ids: str | list | None = config.get(ATTR_ENTITY_ID)
device_ids: str | list | None = config.get(ATTR_DEVICE_ID)
area_ids: str | list | None = config.get(ATTR_AREA_ID)
floor_ids: str | list | None = config.get(ATTR_FLOOR_ID)
label_ids: str | list | None = config.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str], logger: Logger) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
logger.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
def async_extract_referenced_entity_ids(
hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a target selector."""
selected = SelectedEntities()
if not selector_data.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector_data.entity_ids
if expand_group:
entity_ids = group.expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector_data.device_ids
and not selector_data.area_ids
and not selector_data.floor_ids
and not selector_data.label_ids
):
return selected
entities = er.async_get(hass).entities
dev_reg = dr.async_get(hass)
area_reg = ar.async_get(hass)
if selector_data.floor_ids:
floor_reg = fr.async_get(hass)
for floor_id in selector_data.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector_data.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector_data.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector_data.label_ids:
label_reg = lr.async_get(hass)
for label_id in selector_data.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector_data.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector_data.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector_data.area_ids)
selected.referenced_devices.update(selector_data.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Iterable
from copy import deepcopy from copy import deepcopy
import dataclasses
import io import io
from typing import Any from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@ -2322,3 +2323,80 @@ async def test_reload_service_helper(hass: HomeAssistant) -> None:
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"]) assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
async def test_deprecated_service_target_selector_class(hass: HomeAssistant) -> None:
"""Test that the deprecated ServiceTargetSelector class forwards correctly."""
call = ServiceCall(
hass,
"test",
"test",
{
"entity_id": ["light.test", "switch.test"],
"area_id": "kitchen",
"device_id": ["device1", "device2"],
"floor_id": "first_floor",
"label_id": ["label1", "label2"],
},
)
selector = service.ServiceTargetSelector(call)
assert selector.entity_ids == {"light.test", "switch.test"}
assert selector.area_ids == {"kitchen"}
assert selector.device_ids == {"device1", "device2"}
assert selector.floor_ids == {"first_floor"}
assert selector.label_ids == {"label1", "label2"}
assert selector.has_any_selector is True
async def test_deprecated_selected_entities_class(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that the deprecated SelectedEntities class forwards correctly."""
selected = service.SelectedEntities(
referenced={"entity.test"},
indirectly_referenced=set(),
referenced_devices=set(),
referenced_areas=set(),
missing_devices={"missing_device"},
missing_areas={"missing_area"},
missing_floors={"missing_floor"},
missing_labels={"missing_label"},
)
missing_entities = {"entity.missing"}
selected.log_missing(missing_entities)
assert (
"Referenced floors missing_floor, areas missing_area, "
"devices missing_device, entities entity.missing, "
"labels missing_label are missing or not currently available" in caplog.text
)
async def test_deprecated_async_extract_referenced_entity_ids(
hass: HomeAssistant,
) -> None:
"""Test that the deprecated async_extract_referenced_entity_ids function forwards correctly."""
from homeassistant.helpers import target # noqa: PLC0415
mock_selected = target.SelectedEntities(
referenced={"entity.test"},
indirectly_referenced={"entity.indirect"},
)
with patch(
"homeassistant.helpers.target.async_extract_referenced_entity_ids",
return_value=mock_selected,
) as mock_target_func:
call = ServiceCall(hass, "test", "test", {"entity_id": "light.test"})
result = service.async_extract_referenced_entity_ids(
hass, call, expand_group=False
)
# Verify target helper was called with correct parameters
mock_target_func.assert_called_once()
args = mock_target_func.call_args
assert args[0][0] is hass
assert args[0][1].entity_ids == {"light.test"}
assert args[0][2] is False
assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected)

View File

@ -0,0 +1,459 @@
"""Test service helpers."""
import pytest
# TODO(abmantis): is this import needed?
# To prevent circular import when running just this file
import homeassistant.components # noqa: F401
from homeassistant.components.group import Group
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
STATE_OFF,
STATE_ON,
EntityCategory,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
target,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component
from tests.common import (
RegistryEntryWithDefaults,
mock_area_registry,
mock_device_registry,
mock_registry,
)
@pytest.fixture
def registries_mock(hass: HomeAssistant) -> None:
"""Mock including floor and area info."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
area_in_floor = ar.AreaEntry(
id="test-area",
name="Test area",
aliases={},
floor_id="test-floor",
icon=None,
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
area_in_floor_a = ar.AreaEntry(
id="area-a",
name="Area A",
aliases={},
floor_id="floor-a",
icon=None,
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
area_with_labels = ar.AreaEntry(
id="area-with-labels",
name="Area with labels",
aliases={},
floor_id=None,
icon=None,
labels={"label_area"},
picture=None,
temperature_entity_id=None,
humidity_entity_id=None,
)
mock_area_registry(
hass,
{
area_in_floor.id: area_in_floor,
area_in_floor_a.id: area_in_floor_a,
area_with_labels.id: area_with_labels,
},
)
device_in_area = dr.DeviceEntry(id="device-test-area", area_id="test-area")
device_no_area = dr.DeviceEntry(id="device-no-area-id")
device_diff_area = dr.DeviceEntry(id="device-diff-area", area_id="diff-area")
device_area_a = dr.DeviceEntry(id="device-area-a-id", area_id="area-a")
device_has_label1 = dr.DeviceEntry(id="device-has-label1-id", labels={"label1"})
device_has_label2 = dr.DeviceEntry(id="device-has-label2-id", labels={"label2"})
device_has_labels = dr.DeviceEntry(
id="device-has-labels-id",
labels={"label1", "label2"},
area_id=area_with_labels.id,
)
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
device_area_a.id: device_area_a,
device_has_label1.id: device_has_label1,
device_has_label2.id: device_has_label2,
device_has_labels.id: device_has_labels,
},
)
entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.in_own_area",
unique_id="in-own-area-id",
platform="test",
area_id="own-area",
)
config_entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.config_in_own_area",
unique_id="config-in-own-area-id",
platform="test",
area_id="own-area",
entity_category=EntityCategory.CONFIG,
)
hidden_entity_in_own_area = RegistryEntryWithDefaults(
entity_id="light.hidden_in_own_area",
unique_id="hidden-in-own-area-id",
platform="test",
area_id="own-area",
hidden_by=er.RegistryEntryHider.USER,
)
entity_in_area = RegistryEntryWithDefaults(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
config_entity_in_area = RegistryEntryWithDefaults(
entity_id="light.config_in_area",
unique_id="config-in-area-id",
platform="test",
device_id=device_in_area.id,
entity_category=EntityCategory.CONFIG,
)
hidden_entity_in_area = RegistryEntryWithDefaults(
entity_id="light.hidden_in_area",
unique_id="hidden-in-area-id",
platform="test",
device_id=device_in_area.id,
hidden_by=er.RegistryEntryHider.USER,
)
entity_in_other_area = RegistryEntryWithDefaults(
entity_id="light.in_other_area",
unique_id="in-area-a-id",
platform="test",
device_id=device_in_area.id,
area_id="other-area",
)
entity_assigned_to_area = RegistryEntryWithDefaults(
entity_id="light.assigned_to_area",
unique_id="assigned-area-id",
platform="test",
device_id=device_in_area.id,
area_id="test-area",
)
entity_no_area = RegistryEntryWithDefaults(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
config_entity_no_area = RegistryEntryWithDefaults(
entity_id="light.config_no_area",
unique_id="config-no-area-id",
platform="test",
device_id=device_no_area.id,
entity_category=EntityCategory.CONFIG,
)
hidden_entity_no_area = RegistryEntryWithDefaults(
entity_id="light.hidden_no_area",
unique_id="hidden-no-area-id",
platform="test",
device_id=device_no_area.id,
hidden_by=er.RegistryEntryHider.USER,
)
entity_diff_area = RegistryEntryWithDefaults(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
entity_in_area_a = RegistryEntryWithDefaults(
entity_id="light.in_area_a",
unique_id="in-area-a-id",
platform="test",
device_id=device_area_a.id,
area_id="area-a",
)
entity_in_area_b = RegistryEntryWithDefaults(
entity_id="light.in_area_b",
unique_id="in-area-b-id",
platform="test",
device_id=device_area_a.id,
area_id="area-b",
)
entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.with_my_label",
unique_id="with_my_label",
platform="test",
labels={"my-label"},
)
hidden_entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.hidden_with_my_label",
unique_id="hidden_with_my_label",
platform="test",
labels={"my-label"},
hidden_by=er.RegistryEntryHider.USER,
)
config_entity_with_my_label = RegistryEntryWithDefaults(
entity_id="light.config_with_my_label",
unique_id="config_with_my_label",
platform="test",
labels={"my-label"},
entity_category=EntityCategory.CONFIG,
)
entity_with_label1_from_device = RegistryEntryWithDefaults(
entity_id="light.with_label1_from_device",
unique_id="with_label1_from_device",
platform="test",
device_id=device_has_label1.id,
)
entity_with_label1_from_device_and_different_area = RegistryEntryWithDefaults(
entity_id="light.with_label1_from_device_diff_area",
unique_id="with_label1_from_device_diff_area",
platform="test",
device_id=device_has_label1.id,
area_id=area_in_floor_a.id,
)
entity_with_label1_and_label2_from_device = RegistryEntryWithDefaults(
entity_id="light.with_label1_and_label2_from_device",
unique_id="with_label1_and_label2_from_device",
platform="test",
labels={"label1"},
device_id=device_has_label2.id,
)
entity_with_labels_from_device = RegistryEntryWithDefaults(
entity_id="light.with_labels_from_device",
unique_id="with_labels_from_device",
platform="test",
device_id=device_has_labels.id,
)
mock_registry(
hass,
{
entity_in_own_area.entity_id: entity_in_own_area,
config_entity_in_own_area.entity_id: config_entity_in_own_area,
hidden_entity_in_own_area.entity_id: hidden_entity_in_own_area,
entity_in_area.entity_id: entity_in_area,
config_entity_in_area.entity_id: config_entity_in_area,
hidden_entity_in_area.entity_id: hidden_entity_in_area,
entity_in_other_area.entity_id: entity_in_other_area,
entity_assigned_to_area.entity_id: entity_assigned_to_area,
entity_no_area.entity_id: entity_no_area,
config_entity_no_area.entity_id: config_entity_no_area,
hidden_entity_no_area.entity_id: hidden_entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
entity_in_area_a.entity_id: entity_in_area_a,
entity_in_area_b.entity_id: entity_in_area_b,
config_entity_with_my_label.entity_id: config_entity_with_my_label,
entity_with_label1_and_label2_from_device.entity_id: entity_with_label1_and_label2_from_device,
entity_with_label1_from_device.entity_id: entity_with_label1_from_device,
entity_with_label1_from_device_and_different_area.entity_id: entity_with_label1_from_device_and_different_area,
entity_with_labels_from_device.entity_id: entity_with_labels_from_device,
entity_with_my_label.entity_id: entity_with_my_label,
hidden_entity_with_my_label.entity_id: hidden_entity_with_my_label,
},
)
@pytest.mark.parametrize(
("selector_config", "expand_group", "expected_selected"),
[
(
{
ATTR_ENTITY_ID: ENTITY_MATCH_NONE,
ATTR_AREA_ID: ENTITY_MATCH_NONE,
ATTR_FLOOR_ID: ENTITY_MATCH_NONE,
ATTR_LABEL_ID: ENTITY_MATCH_NONE,
},
False,
target.SelectedEntities(),
),
(
{ATTR_ENTITY_ID: "light.bowl"},
False,
target.SelectedEntities(referenced={"light.bowl"}),
),
(
{ATTR_ENTITY_ID: "group.test"},
True,
target.SelectedEntities(referenced={"light.ceiling", "light.kitchen"}),
),
(
{ATTR_ENTITY_ID: "group.test"},
False,
target.SelectedEntities(referenced={"group.test"}),
),
(
{ATTR_AREA_ID: "own-area"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_own_area"},
referenced_areas={"own-area"},
missing_areas={"own-area"},
),
),
(
{ATTR_AREA_ID: "test-area"},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.assigned_to_area",
},
referenced_areas={"test-area"},
referenced_devices={"device-test-area"},
),
),
(
{ATTR_AREA_ID: ["test-area", "diff-area"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.diff_area",
"light.assigned_to_area",
},
referenced_areas={"test-area", "diff-area"},
referenced_devices={"device-diff-area", "device-test-area"},
missing_areas={"diff-area"},
),
),
(
{ATTR_DEVICE_ID: "device-no-area-id"},
False,
target.SelectedEntities(
indirectly_referenced={"light.no_area"},
referenced_devices={"device-no-area-id"},
),
),
(
{ATTR_DEVICE_ID: "device-area-a-id"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_area_a", "light.in_area_b"},
referenced_devices={"device-area-a-id"},
),
),
(
{ATTR_FLOOR_ID: "test-floor"},
False,
target.SelectedEntities(
indirectly_referenced={"light.in_area", "light.assigned_to_area"},
referenced_devices={"device-test-area"},
referenced_areas={"test-area"},
missing_floors={"test-floor"},
),
),
(
{ATTR_FLOOR_ID: ["test-floor", "floor-a"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.in_area",
"light.assigned_to_area",
"light.in_area_a",
"light.with_label1_from_device_diff_area",
},
referenced_devices={"device-area-a-id", "device-test-area"},
referenced_areas={"area-a", "test-area"},
missing_floors={"floor-a", "test-floor"},
),
),
(
{ATTR_LABEL_ID: "my-label"},
False,
target.SelectedEntities(
indirectly_referenced={"light.with_my_label"},
missing_labels={"my-label"},
),
),
(
{ATTR_LABEL_ID: "label1"},
False,
target.SelectedEntities(
indirectly_referenced={
"light.with_label1_from_device",
"light.with_label1_from_device_diff_area",
"light.with_labels_from_device",
"light.with_label1_and_label2_from_device",
},
referenced_devices={"device-has-label1-id", "device-has-labels-id"},
missing_labels={"label1"},
),
),
(
{ATTR_LABEL_ID: ["label2"]},
False,
target.SelectedEntities(
indirectly_referenced={
"light.with_labels_from_device",
"light.with_label1_and_label2_from_device",
},
referenced_devices={"device-has-label2-id", "device-has-labels-id"},
missing_labels={"label2"},
),
),
(
{ATTR_LABEL_ID: ["label_area"]},
False,
target.SelectedEntities(
indirectly_referenced={"light.with_labels_from_device"},
referenced_devices={"device-has-labels-id"},
referenced_areas={"area-with-labels"},
missing_labels={"label_area"},
),
),
],
)
@pytest.mark.usefixtures("registries_mock")
async def test_extract_referenced_entity_ids(
hass: HomeAssistant,
selector_config: ConfigType,
expand_group: bool,
expected_selected: target.SelectedEntities,
) -> None:
"""Test extract_entity_ids method."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
assert await async_setup_component(hass, "group", {})
await hass.async_block_till_done()
await Group.async_create_group(
hass,
"test",
created_by_service=False,
entity_ids=["light.Ceiling", "light.Kitchen"],
icon=None,
mode=None,
object_id=None,
order=None,
)
target_data = target.TargetSelectorData(selector_config)
assert (
target.async_extract_referenced_entity_ids(
hass, target_data, expand_group=expand_group
)
== expected_selected
)