From cf5be049b375cf00df9c9c77529fd79d63223a81 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 1 Dec 2020 08:01:27 +0100 Subject: [PATCH] Warn when referencing missing devices/areas (#43787) --- homeassistant/helpers/service.py | 180 +++++++++++++++++++++---------- tests/helpers/test_service.py | 37 +++++++ 2 files changed, 162 insertions(+), 55 deletions(-) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 47918f31514..25a88bb59cb 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1,5 +1,6 @@ """Service calling related helpers.""" import asyncio +import dataclasses from functools import partial, wraps import logging from typing import ( @@ -37,8 +38,13 @@ from homeassistant.exceptions import ( Unauthorized, UnknownUser, ) -from homeassistant.helpers import device_registry, entity_registry, template -import homeassistant.helpers.config_validation as cv +from homeassistant.helpers import ( + area_registry, + config_validation as cv, + device_registry, + entity_registry, + template, +) from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType from homeassistant.loader import ( MAX_LOAD_CONCURRENTLY, @@ -64,6 +70,38 @@ _LOGGER = logging.getLogger(__name__) SERVICE_DESCRIPTION_CACHE = "service_description_cache" +@dataclasses.dataclass +class SelectedEntities: + """Class to hold the selected entities.""" + + # Entities that were explicitly mentioned. + referenced: Set[str] = dataclasses.field(default_factory=set) + + # Entities that were referenced via device/area ID. + # Should not trigger a warning when they don't exist. + indirectly_referenced: Set[str] = dataclasses.field(default_factory=set) + + # Referenced items that could not be found. + missing_devices: Set[str] = dataclasses.field(default_factory=set) + missing_areas: Set[str] = dataclasses.field(default_factory=set) + + def log_missing(self, missing_entities: Set[str]) -> None: + """Log about missing items.""" + parts = [] + for label, items in ( + ("areas", self.missing_areas), + ("devices", self.missing_devices), + ("entities", missing_entities), + ): + if items: + parts.append(f"{label} {', '.join(sorted(items))}") + + if not parts: + return + + _LOGGER.warning("Unable to find referenced %s", ", ".join(parts)) + + @bind_hass def call_from_config( hass: HomeAssistantType, @@ -186,25 +224,25 @@ async def async_extract_entities( if data_ent_id == ENTITY_MATCH_ALL: return [entity for entity in entities if entity.available] - entity_ids = await async_extract_entity_ids(hass, service_call, expand_group) + referenced = await async_extract_referenced_entity_ids( + hass, service_call, expand_group + ) + combined = referenced.referenced | referenced.indirectly_referenced found = [] for entity in entities: - if entity.entity_id not in entity_ids: + if entity.entity_id not in combined: continue - entity_ids.remove(entity.entity_id) + combined.remove(entity.entity_id) if not entity.available: continue found.append(entity) - if entity_ids: - _LOGGER.warning( - "Unable to find referenced entities %s", ", ".join(sorted(entity_ids)) - ) + referenced.log_missing(referenced.referenced & combined) return found @@ -213,10 +251,21 @@ async def async_extract_entities( async def async_extract_entity_ids( hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True ) -> Set[str]: - """Extract a list of entity ids from a service call. + """Extract a set of entity ids from a service call. Will convert group entity ids to the entity ids it represents. """ + referenced = await async_extract_referenced_entity_ids( + hass, service_call, expand_group + ) + return referenced.referenced | referenced.indirectly_referenced + + +@bind_hass +async def async_extract_referenced_entity_ids( + hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True +) -> SelectedEntities: + """Extract referenced entity IDs from a service call.""" entity_ids = service_call.data.get(ATTR_ENTITY_ID) device_ids = service_call.data.get(ATTR_DEVICE_ID) area_ids = service_call.data.get(ATTR_AREA_ID) @@ -225,12 +274,14 @@ async def async_extract_entity_ids( selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE) selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE) - extracted: Set[str] = set() + selected = SelectedEntities() if not selects_entity_ids and not selects_device_ids and not selects_area_ids: - return extracted + return selected if selects_entity_ids: + assert entity_ids is not None + # Entity ID attr can be a list or a string if isinstance(entity_ids, str): entity_ids = [entity_ids] @@ -238,58 +289,68 @@ async def async_extract_entity_ids( if expand_group: entity_ids = hass.components.group.expand_entity_ids(entity_ids) - extracted.update(entity_ids) + selected.referenced.update(entity_ids) if not selects_device_ids and not selects_area_ids: - return extracted + return selected - dev_reg, ent_reg = cast( - Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry], + area_reg, dev_reg, ent_reg = cast( + Tuple[ + area_registry.AreaRegistry, + device_registry.DeviceRegistry, + entity_registry.EntityRegistry, + ], await asyncio.gather( + area_registry.async_get_registry(hass), device_registry.async_get_registry(hass), entity_registry.async_get_registry(hass), ), ) - if not selects_device_ids: - picked_devices = set() - elif isinstance(device_ids, str): - picked_devices = {device_ids} - else: - assert isinstance(device_ids, list) - picked_devices = set(device_ids) + picked_devices = set() + + if selects_device_ids: + if isinstance(device_ids, str): + picked_devices = {device_ids} + else: + assert isinstance(device_ids, list) + picked_devices = set(device_ids) + + for device_id in picked_devices: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) if selects_area_ids: - if isinstance(area_ids, str): - area_ids = [area_ids] + assert area_ids is not None - assert isinstance(area_ids, list) + if isinstance(area_ids, str): + area_lookup = {area_ids} + else: + area_lookup = set(area_ids) + + for area_id in area_lookup: + if area_id not in area_reg.areas: + selected.missing_areas.add(area_id) + continue # Find entities tied to an area - extracted.update( - entry.entity_id - for area_id in area_ids - for entry in entity_registry.async_entries_for_area(ent_reg, area_id) - ) + for entity_entry in ent_reg.entities.values(): + if entity_entry.area_id in area_lookup: + selected.indirectly_referenced.add(entity_entry.entity_id) - picked_devices.update( - [ - device.id - for area_id in area_ids - for device in device_registry.async_entries_for_area(dev_reg, area_id) - ] - ) + # Find devices for this area + for device_entry in dev_reg.devices.values(): + if device_entry.area_id in area_lookup: + picked_devices.add(device_entry.id) if not picked_devices: - return extracted + return selected - extracted.update( - entity_entry.entity_id - for entity_entry in ent_reg.entities.values() - if not entity_entry.area_id and entity_entry.device_id in picked_devices - ) + for entity_entry in ent_reg.entities.values(): + if not entity_entry.area_id and entity_entry.device_id in picked_devices: + selected.indirectly_referenced.add(entity_entry.entity_id) - return extracted + return selected def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE: @@ -416,9 +477,13 @@ async def entity_service_call( target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL - if not target_all_entities: + if target_all_entities: + referenced: Optional[SelectedEntities] = None + all_referenced: Optional[Set[str]] = None + else: # A set of entities we're trying to target. - entity_ids = await async_extract_entity_ids(hass, call, True) + referenced = await async_extract_referenced_entity_ids(hass, call, True) + all_referenced = referenced.referenced | referenced.indirectly_referenced # If the service function is a string, we'll pass it the service call data if isinstance(func, str): @@ -441,11 +506,12 @@ async def entity_service_call( if target_all_entities: entity_candidates.extend(platform.entities.values()) else: + assert all_referenced is not None entity_candidates.extend( [ entity for entity in platform.entities.values() - if entity.entity_id in entity_ids + if entity.entity_id in all_referenced ] ) @@ -462,11 +528,13 @@ async def entity_service_call( ) else: + assert all_referenced is not None + for platform in platforms: platform_entities = [] for entity in platform.entities.values(): - if entity.entity_id not in entity_ids: + if entity.entity_id not in all_referenced: continue if not entity_perms(entity.entity_id, POLICY_CONTROL): @@ -481,13 +549,15 @@ async def entity_service_call( entity_candidates.extend(platform_entities) if not target_all_entities: - for entity in entity_candidates: - entity_ids.remove(entity.entity_id) + assert referenced is not None - if entity_ids: - _LOGGER.warning( - "Unable to find referenced entities %s", ", ".join(sorted(entity_ids)) - ) + # Only report on explicit referenced entities + missing = set(referenced.referenced) + + for entity in entity_candidates: + missing.discard(entity.entity_id) + + referenced.log_missing(missing) entities = [] diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index f9b09b259ca..a75593ddd40 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -960,3 +960,40 @@ async def test_extract_from_service_area_id(hass, area_mock): "light.in_area", "light.no_area", ] + + +async def test_entity_service_call_warn_referenced(hass, caplog): + """Test we only warn for referenced entities in entity_service_call.""" + call = ha.ServiceCall( + "light", + "turn_on", + { + "area_id": "non-existent-area", + "entity_id": "non.existent", + "device_id": "non-existent-device", + }, + ) + await service.entity_service_call(hass, {}, "", call) + assert ( + "Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent" + in caplog.text + ) + + +async def test_async_extract_entities_warn_referenced(hass, caplog): + """Test we only warn for referenced entities in async_extract_entities.""" + call = ha.ServiceCall( + "light", + "turn_on", + { + "area_id": "non-existent-area", + "entity_id": "non.existent", + "device_id": "non-existent-device", + }, + ) + extracted = await service.async_extract_entities(hass, {}, call) + assert len(extracted) == 0 + assert ( + "Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent" + in caplog.text + )