From 08db262972c33c7a466b93b9dc90f460a255fb33 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 17 Mar 2021 18:27:21 -1000 Subject: [PATCH] Add a service to reload config entries that can easily be called though automations (#46762) --- .../components/homeassistant/__init__.py | 39 ++++- .../components/homeassistant/services.yaml | 16 ++ homeassistant/helpers/service.py | 158 ++++++++++-------- tests/components/homeassistant/test_init.py | 62 +++++++ tests/helpers/test_service.py | 25 +++ 5 files changed, 228 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/homeassistant/__init__.py b/homeassistant/components/homeassistant/__init__.py index c2ee40b7d43..309f98e6095 100644 --- a/homeassistant/components/homeassistant/__init__.py +++ b/homeassistant/components/homeassistant/__init__.py @@ -21,15 +21,30 @@ from homeassistant.const import ( import homeassistant.core as ha from homeassistant.exceptions import HomeAssistantError, Unauthorized, UnknownUser from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.service import async_extract_referenced_entity_ids +from homeassistant.helpers.service import ( + async_extract_config_entry_ids, + async_extract_referenced_entity_ids, +) + +ATTR_ENTRY_ID = "entry_id" _LOGGER = logging.getLogger(__name__) DOMAIN = ha.DOMAIN SERVICE_RELOAD_CORE_CONFIG = "reload_core_config" +SERVICE_RELOAD_CONFIG_ENTRY = "reload_config_entry" SERVICE_CHECK_CONFIG = "check_config" SERVICE_UPDATE_ENTITY = "update_entity" SERVICE_SET_LOCATION = "set_location" SCHEMA_UPDATE_ENTITY = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids}) +SCHEMA_RELOAD_CONFIG_ENTRY = vol.All( + vol.Schema( + { + vol.Optional(ATTR_ENTRY_ID): str, + **cv.ENTITY_SERVICE_FIELDS, + }, + ), + cv.has_at_least_one_key(ATTR_ENTRY_ID, *cv.ENTITY_SERVICE_FIELDS), +) async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool: @@ -203,4 +218,26 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool: vol.Schema({ATTR_LATITUDE: cv.latitude, ATTR_LONGITUDE: cv.longitude}), ) + async def async_handle_reload_config_entry(call): + """Service handler for reloading a config entry.""" + reload_entries = set() + if ATTR_ENTRY_ID in call.data: + reload_entries.add(call.data[ATTR_ENTRY_ID]) + reload_entries.update(await async_extract_config_entry_ids(hass, call)) + if not reload_entries: + raise ValueError("There were no matching config entries to reload") + await asyncio.gather( + *[ + hass.config_entries.async_reload(config_entry_id) + for config_entry_id in reload_entries + ] + ) + + hass.helpers.service.async_register_admin_service( + ha.DOMAIN, + SERVICE_RELOAD_CONFIG_ENTRY, + async_handle_reload_config_entry, + schema=SCHEMA_RELOAD_CONFIG_ENTRY, + ) + return True diff --git a/homeassistant/components/homeassistant/services.yaml b/homeassistant/components/homeassistant/services.yaml index 38814d9f902..251ee171b6a 100644 --- a/homeassistant/components/homeassistant/services.yaml +++ b/homeassistant/components/homeassistant/services.yaml @@ -58,3 +58,19 @@ update_entity: description: Force one or more entities to update its data target: entity: {} + +reload_config_entry: + name: Reload config entry + description: Reload a config entry that matches a target. + target: + entity: {} + device: {} + fields: + entry_id: + advanced: true + name: Config entry id + description: A configuration entry id + required: false + example: 8955375327824e14ba89e4b29cc3ec9a + selector: + text: diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index f1ab245d38a..f43b85575b8 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -11,9 +11,9 @@ from typing import ( Awaitable, Callable, Iterable, - Tuple, + Optional, TypedDict, - cast, + Union, ) import voluptuous as vol @@ -78,6 +78,29 @@ class ServiceParams(TypedDict): target: dict | None +class ServiceTargetSelector: + """Class to hold a target selector for a service.""" + + def __init__(self, service_call: ha.ServiceCall): + """Extract ids from service call data.""" + entity_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_ENTITY_ID) + device_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_DEVICE_ID) + area_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_AREA_ID) + + self.entity_ids = ( + set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() + ) + self.device_ids = ( + set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() + ) + self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() + + @property + def has_any_selector(self) -> bool: + """Determine if any selectors are present.""" + return bool(self.entity_ids or self.device_ids or self.area_ids) + + @dataclasses.dataclass class SelectedEntities: """Class to hold the selected entities.""" @@ -93,6 +116,9 @@ class SelectedEntities: missing_devices: set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set) + # Referenced devices + referenced_devices: set[str] = dataclasses.field(default_factory=set) + def log_missing(self, missing_entities: set[str]) -> None: """Log about missing items.""" parts = [] @@ -293,98 +319,88 @@ async def async_extract_entity_ids( return referenced.referenced | referenced.indirectly_referenced +def _has_match(ids: Optional[Union[str, list]]) -> bool: + """Check if ids can match anything.""" + return ids not in (None, ENTITY_MATCH_NONE) + + @bind_hass async def async_extract_referenced_entity_ids( hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True ) -> SelectedEntities: """Extract referenced entity IDs from a service call.""" - entity_ids = service_call.data.get(ATTR_ENTITY_ID) - device_ids = service_call.data.get(ATTR_DEVICE_ID) - area_ids = service_call.data.get(ATTR_AREA_ID) - - selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE) - selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE) - selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE) - + selector = ServiceTargetSelector(service_call) selected = SelectedEntities() - if not selects_entity_ids and not selects_device_ids and not selects_area_ids: + if not selector.has_any_selector: return selected - if selects_entity_ids: - assert entity_ids is not None + entity_ids = selector.entity_ids + if expand_group: + entity_ids = hass.components.group.expand_entity_ids(entity_ids) - # Entity ID attr can be a list or a string - if isinstance(entity_ids, str): - entity_ids = [entity_ids] + selected.referenced.update(entity_ids) - if expand_group: - entity_ids = hass.components.group.expand_entity_ids(entity_ids) - - selected.referenced.update(entity_ids) - - if not selects_device_ids and not selects_area_ids: + if not selector.device_ids and not selector.area_ids: return selected - area_reg, dev_reg, ent_reg = cast( - Tuple[ - area_registry.AreaRegistry, - device_registry.DeviceRegistry, - entity_registry.EntityRegistry, - ], - await asyncio.gather( - area_registry.async_get_registry(hass), - device_registry.async_get_registry(hass), - entity_registry.async_get_registry(hass), - ), - ) + ent_reg = entity_registry.async_get(hass) + dev_reg = device_registry.async_get(hass) + area_reg = area_registry.async_get(hass) - picked_devices = set() + for device_id in selector.device_ids: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) - if selects_device_ids: - if isinstance(device_ids, str): - picked_devices = {device_ids} - else: - assert isinstance(device_ids, list) - picked_devices = set(device_ids) + for area_id in selector.area_ids: + if area_id not in area_reg.areas: + selected.missing_areas.add(area_id) - for device_id in picked_devices: - if device_id not in dev_reg.devices: - selected.missing_devices.add(device_id) + # Find devices for this area + selected.referenced_devices.update(selector.device_ids) + for device_entry in dev_reg.devices.values(): + if device_entry.area_id in selector.area_ids: + selected.referenced_devices.add(device_entry.id) - if selects_area_ids: - assert area_ids is not None - - if isinstance(area_ids, str): - area_lookup = {area_ids} - else: - area_lookup = set(area_ids) - - for area_id in area_lookup: - if area_id not in area_reg.areas: - selected.missing_areas.add(area_id) - continue - - # Find entities tied to an area - for entity_entry in ent_reg.entities.values(): - if entity_entry.area_id in area_lookup: - selected.indirectly_referenced.add(entity_entry.entity_id) - - # Find devices for this area - for device_entry in dev_reg.devices.values(): - if device_entry.area_id in area_lookup: - picked_devices.add(device_entry.id) - - if not picked_devices: + if not selector.area_ids and not selected.referenced_devices: return selected - for entity_entry in ent_reg.entities.values(): - if not entity_entry.area_id and entity_entry.device_id in picked_devices: - selected.indirectly_referenced.add(entity_entry.entity_id) + for ent_entry in ent_reg.entities.values(): + if ent_entry.area_id in selector.area_ids or ( + not ent_entry.area_id and ent_entry.device_id in selected.referenced_devices + ): + selected.indirectly_referenced.add(ent_entry.entity_id) return selected +@bind_hass +async def async_extract_config_entry_ids( + hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True +) -> set: + """Extract referenced config entry ids from a service call.""" + referenced = await async_extract_referenced_entity_ids( + hass, service_call, expand_group + ) + ent_reg = entity_registry.async_get(hass) + dev_reg = device_registry.async_get(hass) + config_entry_ids: set[str] = set() + + # Some devices may have no entities + for device_id in referenced.referenced_devices: + if device_id in dev_reg.devices: + device = dev_reg.async_get(device_id) + if device is not None: + config_entry_ids.update(device.config_entries) + + for entity_id in referenced.referenced | referenced.indirectly_referenced: + entry = ent_reg.async_get(entity_id) + if entry is not None and entry.config_entry_id is not None: + config_entry_ids.add(entry.config_entry_id) + + return config_entry_ids + + def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE: """Load services file for an integration.""" try: diff --git a/tests/components/homeassistant/test_init.py b/tests/components/homeassistant/test_init.py index ef830c7ee77..51646ae7139 100644 --- a/tests/components/homeassistant/test_init.py +++ b/tests/components/homeassistant/test_init.py @@ -11,6 +11,7 @@ import yaml from homeassistant import config import homeassistant.components as comps from homeassistant.components.homeassistant import ( + ATTR_ENTRY_ID, SERVICE_CHECK_CONFIG, SERVICE_RELOAD_CORE_CONFIG, SERVICE_SET_LOCATION, @@ -34,9 +35,11 @@ from homeassistant.helpers import entity from homeassistant.setup import async_setup_component from tests.common import ( + MockConfigEntry, async_capture_events, async_mock_service, get_test_home_assistant, + mock_registry, mock_service, patch_yaml_files, ) @@ -385,3 +388,62 @@ async def test_not_allowing_recursion(hass, caplog): f"Called service homeassistant.{service} with invalid entities homeassistant.light" in caplog.text ), service + + +async def test_reload_config_entry_by_entity_id(hass): + """Test being able to reload a config entry by entity_id.""" + await async_setup_component(hass, "homeassistant", {}) + entity_reg = mock_registry(hass) + entry1 = MockConfigEntry(domain="mockdomain") + entry1.add_to_hass(hass) + entry2 = MockConfigEntry(domain="mockdomain") + entry2.add_to_hass(hass) + reg_entity1 = entity_reg.async_get_or_create( + "binary_sensor", "powerwall", "battery_charging", config_entry=entry1 + ) + reg_entity2 = entity_reg.async_get_or_create( + "binary_sensor", "powerwall", "battery_status", config_entry=entry2 + ) + with patch( + "homeassistant.config_entries.ConfigEntries.async_reload", + return_value=None, + ) as mock_reload: + await hass.services.async_call( + "homeassistant", + "reload_config_entry", + {"entity_id": f"{reg_entity1.entity_id},{reg_entity2.entity_id}"}, + blocking=True, + ) + + assert len(mock_reload.mock_calls) == 2 + assert {mock_reload.mock_calls[0][1][0], mock_reload.mock_calls[1][1][0]} == { + entry1.entry_id, + entry2.entry_id, + } + + with pytest.raises(ValueError): + await hass.services.async_call( + "homeassistant", + "reload_config_entry", + {"entity_id": "unknown.entity_id"}, + blocking=True, + ) + + +async def test_reload_config_entry_by_entry_id(hass): + """Test being able to reload a config entry by config entry id.""" + await async_setup_component(hass, "homeassistant", {}) + + with patch( + "homeassistant.config_entries.ConfigEntries.async_reload", + return_value=None, + ) as mock_reload: + await hass.services.async_call( + "homeassistant", + "reload_config_entry", + {ATTR_ENTRY_ID: "8955375327824e14ba89e4b29cc3ec9a"}, + blocking=True, + ) + + assert len(mock_reload.mock_calls) == 1 + assert mock_reload.mock_calls[0][1][0] == "8955375327824e14ba89e4b29cc3ec9a" diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 92cbd5514e6..7a084fed9dd 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -1015,3 +1015,28 @@ async def test_async_extract_entities_warn_referenced(hass, caplog): "Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent" in caplog.text ) + + +async def test_async_extract_config_entry_ids(hass): + """Test we can find devices that have no entities.""" + + device_no_entities = dev_reg.DeviceEntry( + id="device-no-entities", config_entries={"abc"} + ) + + call = ha.ServiceCall( + "homeassistant", + "reload_config_entry", + { + "device_id": "device-no-entities", + }, + ) + + mock_device_registry( + hass, + { + device_no_entities.id: device_no_entities, + }, + ) + + assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}