Move target selector extractor method to common module

This commit is contained in:
abmantis 2025-07-03 22:55:38 +01:00
parent 5d553e5641
commit 0ead4c033e
2 changed files with 253 additions and 237 deletions

View File

@ -3,21 +3,41 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
import dataclasses
from enum import StrEnum from enum import StrEnum
from functools import cache from functools import cache
import importlib import importlib
from typing import Any, Literal, Required, TypedDict, cast from logging import Logger
from typing import Any, Literal, Required, TypedDict, TypeGuard, cast
from uuid import UUID from uuid import UUID
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_MODE, CONF_UNIT_OF_MEASUREMENT from homeassistant.const import (
from homeassistant.core import split_entity_id, valid_entity_id ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_MODE,
CONF_UNIT_OF_MEASUREMENT,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant, split_entity_id, valid_entity_id
from homeassistant.generated.countries import COUNTRIES from homeassistant.generated.countries import COUNTRIES
from homeassistant.util import decorator from homeassistant.util import decorator
from homeassistant.util.yaml import dumper from homeassistant.util.yaml import dumper
from . import config_validation as cv from . import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
group,
label_registry as lr,
)
from .typing import ConfigType
SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry() SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry()
@ -1551,3 +1571,215 @@ dumper.add_representer(
dumper, "tag:yaml.org,2002:map", value.serialize() dumper, "tag:yaml.org,2002:map", value.serialize()
), ),
) )
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
class TargetSelectorData:
"""Class to hold data of target selector."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, config: ConfigType) -> None:
"""Extract ids from the config."""
entity_ids: str | list | None = config.get(ATTR_ENTITY_ID)
device_ids: str | list | None = config.get(ATTR_DEVICE_ID)
area_ids: str | list | None = config.get(ATTR_AREA_ID)
floor_ids: str | list | None = config.get(ATTR_FLOOR_ID)
label_ids: str | list | None = config.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str], logger: Logger) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
logger.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
def async_extract_referenced_entity_ids(
hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a target selector."""
selected = SelectedEntities()
if not selector_data.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector_data.entity_ids
if expand_group:
entity_ids = group.expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector_data.device_ids
and not selector_data.area_ids
and not selector_data.floor_ids
and not selector_data.label_ids
):
return selected
entities = er.async_get(hass).entities
dev_reg = dr.async_get(hass)
area_reg = ar.async_get(hass)
if selector_data.floor_ids:
floor_reg = fr.async_get(hass)
for floor_id in selector_data.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector_data.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector_data.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector_data.label_ids:
label_reg = lr.async_get(hass)
for label_id in selector_data.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector_data.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector_data.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector_data.area_ids)
selected.referenced_devices.update(selector_data.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected

View File

@ -4,22 +4,17 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
import dataclasses
from enum import Enum from enum import Enum
from functools import cache, partial from functools import cache, partial
import logging import logging
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, TypeGuard, cast from typing import TYPE_CHECKING, Any, TypedDict, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import ( from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ACTION, CONF_ACTION,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_SERVICE_DATA, CONF_SERVICE_DATA,
@ -54,17 +49,18 @@ from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
from . import ( from . import (
area_registry,
config_validation as cv, config_validation as cv,
device_registry, device_registry,
entity_registry, entity_registry,
floor_registry,
label_registry,
template, template,
translation, translation,
) )
from .group import expand_entity_ids from .selector import (
from .selector import TargetSelector SelectedEntities,
TargetSelector,
TargetSelectorData,
async_extract_referenced_entity_ids,
)
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -223,89 +219,6 @@ class ServiceParams(TypedDict):
target: dict | None target: dict | None
class ServiceTargetSelector:
"""Class to hold a target selector for a service."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data."""
service_call_data = service_call.data
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
@bind_hass @bind_hass
def call_from_config( def call_from_config(
hass: HomeAssistant, hass: HomeAssistant,
@ -464,7 +377,8 @@ async def async_extract_entities[_EntityT: Entity](
if data_ent_id == ENTITY_MATCH_ALL: if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available] return [entity for entity in entities if entity.available]
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = TargetSelectorData(service_call.data)
referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group)
combined = referenced.referenced | referenced.indirectly_referenced combined = referenced.referenced | referenced.indirectly_referenced
found = [] found = []
@ -480,7 +394,7 @@ async def async_extract_entities[_EntityT: Entity](
found.append(entity) found.append(entity)
referenced.log_missing(referenced.referenced & combined) referenced.log_missing(referenced.referenced & combined, _LOGGER)
return found return found
@ -493,149 +407,18 @@ async def async_extract_entity_ids(
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
""" """
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = TargetSelectorData(service_call.data)
referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group)
return referenced.referenced | referenced.indirectly_referenced return referenced.referenced | referenced.indirectly_referenced
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
@bind_hass
def async_extract_referenced_entity_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
selector = ServiceTargetSelector(service_call)
selected = SelectedEntities()
if not selector.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector.entity_ids
if expand_group:
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector.device_ids
and not selector.area_ids
and not selector.floor_ids
and not selector.label_ids
):
return selected
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
if selector.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector.area_ids)
selected.referenced_devices.update(selector.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected
@bind_hass @bind_hass
async def async_extract_config_entry_ids( async def async_extract_config_entry_ids(
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set[str]: ) -> set[str]:
"""Extract referenced config entry ids from a service call.""" """Extract referenced config entry ids from a service call."""
referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) selector_data = TargetSelectorData(service_call.data)
referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group)
ent_reg = entity_registry.async_get(hass) ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
config_entry_ids: set[str] = set() config_entry_ids: set[str] = set()
@ -950,7 +733,8 @@ async def entity_service_call(
all_referenced: set[str] | None = None all_referenced: set[str] | None = None
else: else:
# A set of entities we're trying to target. # A set of entities we're trying to target.
referenced = async_extract_referenced_entity_ids(hass, call, True) selector_data = TargetSelectorData(call.data)
referenced = async_extract_referenced_entity_ids(hass, selector_data, True)
all_referenced = referenced.referenced | referenced.indirectly_referenced all_referenced = referenced.referenced | referenced.indirectly_referenced
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
@ -975,7 +759,7 @@ async def entity_service_call(
missing = referenced.referenced.copy() missing = referenced.referenced.copy()
for entity in entity_candidates: for entity in entity_candidates:
missing.discard(entity.entity_id) missing.discard(entity.entity_id)
referenced.log_missing(missing) referenced.log_missing(missing, _LOGGER)
entities: list[Entity] = [] entities: list[Entity] = []
for entity in entity_candidates: for entity in entity_candidates: