Warn when referencing missing devices/areas (#43787)

This commit is contained in:
Paulus Schoutsen 2020-12-01 08:01:27 +01:00 committed by GitHub
parent cf9598fe4f
commit cf5be049b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 55 deletions

View File

@ -1,5 +1,6 @@
"""Service calling related helpers.""" """Service calling related helpers."""
import asyncio import asyncio
import dataclasses
from functools import partial, wraps from functools import partial, wraps
import logging import logging
from typing import ( from typing import (
@ -37,8 +38,13 @@ from homeassistant.exceptions import (
Unauthorized, Unauthorized,
UnknownUser, UnknownUser,
) )
from homeassistant.helpers import device_registry, entity_registry, template from homeassistant.helpers import (
import homeassistant.helpers.config_validation as cv area_registry,
config_validation as cv,
device_registry,
entity_registry,
template,
)
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.loader import ( from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY, MAX_LOAD_CONCURRENTLY,
@ -64,6 +70,38 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache" SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: Set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: Set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: Set[str] = dataclasses.field(default_factory=set)
missing_areas: Set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: Set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning("Unable to find referenced %s", ", ".join(parts))
@bind_hass @bind_hass
def call_from_config( def call_from_config(
hass: HomeAssistantType, hass: HomeAssistantType,
@ -186,25 +224,25 @@ async def async_extract_entities(
if data_ent_id == ENTITY_MATCH_ALL: if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available] return [entity for entity in entities if entity.available]
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group) referenced = await async_extract_referenced_entity_ids(
hass, service_call, expand_group
)
combined = referenced.referenced | referenced.indirectly_referenced
found = [] found = []
for entity in entities: for entity in entities:
if entity.entity_id not in entity_ids: if entity.entity_id not in combined:
continue continue
entity_ids.remove(entity.entity_id) combined.remove(entity.entity_id)
if not entity.available: if not entity.available:
continue continue
found.append(entity) found.append(entity)
if entity_ids: referenced.log_missing(referenced.referenced & combined)
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
return found return found
@ -213,10 +251,21 @@ async def async_extract_entities(
async def async_extract_entity_ids( async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]: ) -> Set[str]:
"""Extract a list of entity ids from a service call. """Extract a set of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
""" """
referenced = await async_extract_referenced_entity_ids(
hass, service_call, expand_group
)
return referenced.referenced | referenced.indirectly_referenced
@bind_hass
async def async_extract_referenced_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
entity_ids = service_call.data.get(ATTR_ENTITY_ID) entity_ids = service_call.data.get(ATTR_ENTITY_ID)
device_ids = service_call.data.get(ATTR_DEVICE_ID) device_ids = service_call.data.get(ATTR_DEVICE_ID)
area_ids = service_call.data.get(ATTR_AREA_ID) area_ids = service_call.data.get(ATTR_AREA_ID)
@ -225,12 +274,14 @@ async def async_extract_entity_ids(
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE) selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE) selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
extracted: Set[str] = set() selected = SelectedEntities()
if not selects_entity_ids and not selects_device_ids and not selects_area_ids: if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
return extracted return selected
if selects_entity_ids: if selects_entity_ids:
assert entity_ids is not None
# Entity ID attr can be a list or a string # Entity ID attr can be a list or a string
if isinstance(entity_ids, str): if isinstance(entity_ids, str):
entity_ids = [entity_ids] entity_ids = [entity_ids]
@ -238,58 +289,68 @@ async def async_extract_entity_ids(
if expand_group: if expand_group:
entity_ids = hass.components.group.expand_entity_ids(entity_ids) entity_ids = hass.components.group.expand_entity_ids(entity_ids)
extracted.update(entity_ids) selected.referenced.update(entity_ids)
if not selects_device_ids and not selects_area_ids: if not selects_device_ids and not selects_area_ids:
return extracted return selected
dev_reg, ent_reg = cast( area_reg, dev_reg, ent_reg = cast(
Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry], Tuple[
area_registry.AreaRegistry,
device_registry.DeviceRegistry,
entity_registry.EntityRegistry,
],
await asyncio.gather( await asyncio.gather(
area_registry.async_get_registry(hass),
device_registry.async_get_registry(hass), device_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass), entity_registry.async_get_registry(hass),
), ),
) )
if not selects_device_ids:
picked_devices = set() picked_devices = set()
elif isinstance(device_ids, str):
if selects_device_ids:
if isinstance(device_ids, str):
picked_devices = {device_ids} picked_devices = {device_ids}
else: else:
assert isinstance(device_ids, list) assert isinstance(device_ids, list)
picked_devices = set(device_ids) picked_devices = set(device_ids)
if selects_area_ids: for device_id in picked_devices:
if isinstance(area_ids, str): if device_id not in dev_reg.devices:
area_ids = [area_ids] selected.missing_devices.add(device_id)
assert isinstance(area_ids, list) if selects_area_ids:
assert area_ids is not None
if isinstance(area_ids, str):
area_lookup = {area_ids}
else:
area_lookup = set(area_ids)
for area_id in area_lookup:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
continue
# Find entities tied to an area # Find entities tied to an area
extracted.update( for entity_entry in ent_reg.entities.values():
entry.entity_id if entity_entry.area_id in area_lookup:
for area_id in area_ids selected.indirectly_referenced.add(entity_entry.entity_id)
for entry in entity_registry.async_entries_for_area(ent_reg, area_id)
)
picked_devices.update( # Find devices for this area
[ for device_entry in dev_reg.devices.values():
device.id if device_entry.area_id in area_lookup:
for area_id in area_ids picked_devices.add(device_entry.id)
for device in device_registry.async_entries_for_area(dev_reg, area_id)
]
)
if not picked_devices: if not picked_devices:
return extracted return selected
extracted.update( for entity_entry in ent_reg.entities.values():
entity_entry.entity_id if not entity_entry.area_id and entity_entry.device_id in picked_devices:
for entity_entry in ent_reg.entities.values() selected.indirectly_referenced.add(entity_entry.entity_id)
if not entity_entry.area_id and entity_entry.device_id in picked_devices
)
return extracted return selected
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE: def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
@ -416,9 +477,13 @@ async def entity_service_call(
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if not target_all_entities: if target_all_entities:
referenced: Optional[SelectedEntities] = None
all_referenced: Optional[Set[str]] = None
else:
# A set of entities we're trying to target. # A set of entities we're trying to target.
entity_ids = await async_extract_entity_ids(hass, call, True) referenced = await async_extract_referenced_entity_ids(hass, call, True)
all_referenced = referenced.referenced | referenced.indirectly_referenced
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
if isinstance(func, str): if isinstance(func, str):
@ -441,11 +506,12 @@ async def entity_service_call(
if target_all_entities: if target_all_entities:
entity_candidates.extend(platform.entities.values()) entity_candidates.extend(platform.entities.values())
else: else:
assert all_referenced is not None
entity_candidates.extend( entity_candidates.extend(
[ [
entity entity
for entity in platform.entities.values() for entity in platform.entities.values()
if entity.entity_id in entity_ids if entity.entity_id in all_referenced
] ]
) )
@ -462,11 +528,13 @@ async def entity_service_call(
) )
else: else:
assert all_referenced is not None
for platform in platforms: for platform in platforms:
platform_entities = [] platform_entities = []
for entity in platform.entities.values(): for entity in platform.entities.values():
if entity.entity_id not in entity_ids: if entity.entity_id not in all_referenced:
continue continue
if not entity_perms(entity.entity_id, POLICY_CONTROL): if not entity_perms(entity.entity_id, POLICY_CONTROL):
@ -481,13 +549,15 @@ async def entity_service_call(
entity_candidates.extend(platform_entities) entity_candidates.extend(platform_entities)
if not target_all_entities: if not target_all_entities:
for entity in entity_candidates: assert referenced is not None
entity_ids.remove(entity.entity_id)
if entity_ids: # Only report on explicit referenced entities
_LOGGER.warning( missing = set(referenced.referenced)
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
) for entity in entity_candidates:
missing.discard(entity.entity_id)
referenced.log_missing(missing)
entities = [] entities = []

View File

@ -960,3 +960,40 @@ async def test_extract_from_service_area_id(hass, area_mock):
"light.in_area", "light.in_area",
"light.no_area", "light.no_area",
] ]
async def test_entity_service_call_warn_referenced(hass, caplog):
"""Test we only warn for referenced entities in entity_service_call."""
call = ha.ServiceCall(
"light",
"turn_on",
{
"area_id": "non-existent-area",
"entity_id": "non.existent",
"device_id": "non-existent-device",
},
)
await service.entity_service_call(hass, {}, "", call)
assert (
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
in caplog.text
)
async def test_async_extract_entities_warn_referenced(hass, caplog):
"""Test we only warn for referenced entities in async_extract_entities."""
call = ha.ServiceCall(
"light",
"turn_on",
{
"area_id": "non-existent-area",
"entity_id": "non.existent",
"device_id": "non-existent-device",
},
)
extracted = await service.async_extract_entities(hass, {}, call)
assert len(extracted) == 0
assert (
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
in caplog.text
)