From 504e5b77ca966ff153fb68a314bf556b5b815d42 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 3 Mar 2021 19:12:37 +0100 Subject: [PATCH] Improve behaviour when disabling or enabling config entries (#47301) --- homeassistant/config_entries.py | 32 ++++++--- homeassistant/const.py | 1 - homeassistant/helpers/device_registry.py | 71 +++++++++---------- homeassistant/helpers/entity_registry.py | 87 +++++++++++------------- tests/helpers/test_entity_registry.py | 2 +- 5 files changed, 97 insertions(+), 96 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d3e8f66abc4..f74acede507 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -11,10 +11,9 @@ import weakref import attr from homeassistant import data_entry_flow, loader -from homeassistant.const import EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError -from homeassistant.helpers import entity_registry +from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers.event import Event from homeassistant.helpers.typing import UNDEFINED, UndefinedType from homeassistant.setup import async_process_deps_reqs, async_setup_component @@ -807,12 +806,21 @@ class ConfigEntries: entry.disabled_by = disabled_by self._async_schedule_save() - # Unload the config entry, then fire an event + dev_reg = device_registry.async_get(self.hass) + ent_reg = entity_registry.async_get(self.hass) + + if not entry.disabled_by: + # The config entry will no longer be disabled, enable devices and entities + device_registry.async_config_entry_disabled_by_changed(dev_reg, entry) + entity_registry.async_config_entry_disabled_by_changed(ent_reg, entry) + + # Load or unload the config entry reload_result = await self.async_reload(entry_id) - self.hass.bus.async_fire( - EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, {"config_entry_id": entry_id} - ) + if entry.disabled_by: + # The config entry has been disabled, disable devices and entities + device_registry.async_config_entry_disabled_by_changed(dev_reg, entry) + entity_registry.async_config_entry_disabled_by_changed(ent_reg, entry) return reload_result @@ -1250,8 +1258,16 @@ class EntityRegistryDisabledHandler: @callback def _handle_entry_updated_filter(event: Event) -> bool: - """Handle entity registry entry update filter.""" - if event.data["action"] != "update" or "disabled_by" not in event.data["changes"]: + """Handle entity registry entry update filter. + + Only handle changes to "disabled_by". + If "disabled_by" was DISABLED_CONFIG_ENTRY, reload is not needed. + """ + if ( + event.data["action"] != "update" + or "disabled_by" not in event.data["changes"] + or event.data["changes"]["disabled_by"] == entity_registry.DISABLED_CONFIG_ENTRY + ): return False return True diff --git a/homeassistant/const.py b/homeassistant/const.py index 712f7ede0d3..1076b962f2a 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -202,7 +202,6 @@ CONF_ZONE = "zone" # #### EVENTS #### EVENT_CALL_SERVICE = "call_service" EVENT_COMPONENT_LOADED = "component_loaded" -EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED = "config_entry_disabled_by_updated" EVENT_CORE_CONFIG_UPDATE = "core_config_updated" EVENT_HOMEASSISTANT_CLOSE = "homeassistant_close" EVENT_HOMEASSISTANT_START = "homeassistant_start" diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 705f6cdd89a..d311538f27f 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -6,10 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, import attr -from homeassistant.const import ( - EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, - EVENT_HOMEASSISTANT_STARTED, -) +from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import Event, callback from homeassistant.loader import bind_hass import homeassistant.util.uuid as uuid_util @@ -20,6 +17,8 @@ from .typing import UNDEFINED, HomeAssistantType, UndefinedType # mypy: disallow_any_generics if TYPE_CHECKING: + from homeassistant.config_entries import ConfigEntry + from . import entity_registry _LOGGER = logging.getLogger(__name__) @@ -143,10 +142,6 @@ class DeviceRegistry: self.hass = hass self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._clear_index() - self.hass.bus.async_listen( - EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, - self.async_config_entry_disabled_by_changed, - ) @callback def async_get(self, device_id: str) -> Optional[DeviceEntry]: @@ -618,38 +613,6 @@ class DeviceRegistry: if area_id == device.area_id: self._async_update_device(dev_id, area_id=None) - @callback - def async_config_entry_disabled_by_changed(self, event: Event) -> None: - """Handle a config entry being disabled or enabled. - - Disable devices in the registry that are associated to a config entry when - the config entry is disabled. - """ - config_entry = self.hass.config_entries.async_get_entry( - event.data["config_entry_id"] - ) - - # The config entry may be deleted already if the event handling is late - if not config_entry: - return - - if not config_entry.disabled_by: - devices = async_entries_for_config_entry( - self, event.data["config_entry_id"] - ) - for device in devices: - if device.disabled_by != DISABLED_CONFIG_ENTRY: - continue - self.async_update_device(device.id, disabled_by=None) - return - - devices = async_entries_for_config_entry(self, event.data["config_entry_id"]) - for device in devices: - if device.disabled: - # Entity already disabled, do not overwrite - continue - self.async_update_device(device.id, disabled_by=DISABLED_CONFIG_ENTRY) - @callback def async_get(hass: HomeAssistantType) -> DeviceRegistry: @@ -691,6 +654,34 @@ def async_entries_for_config_entry( ] +@callback +def async_config_entry_disabled_by_changed( + registry: DeviceRegistry, config_entry: "ConfigEntry" +) -> None: + """Handle a config entry being disabled or enabled. + + Disable devices in the registry that are associated with a config entry when + the config entry is disabled, enable devices in the registry that are associated + with a config entry when the config entry is enabled and the devices are marked + DISABLED_CONFIG_ENTRY. + """ + + devices = async_entries_for_config_entry(registry, config_entry.entry_id) + + if not config_entry.disabled_by: + for device in devices: + if device.disabled_by != DISABLED_CONFIG_ENTRY: + continue + registry.async_update_device(device.id, disabled_by=None) + return + + for device in devices: + if device.disabled: + # Device already disabled, do not overwrite + continue + registry.async_update_device(device.id, disabled_by=DISABLED_CONFIG_ENTRY) + + @callback def async_cleanup( hass: HomeAssistantType, diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index f18dda71529..36b010c82a0 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -31,7 +31,6 @@ from homeassistant.const import ( ATTR_RESTORED, ATTR_SUPPORTED_FEATURES, ATTR_UNIT_OF_MEASUREMENT, - EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE, ) @@ -158,10 +157,6 @@ class EntityRegistry: self.hass.bus.async_listen( EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified ) - self.hass.bus.async_listen( - EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, - self.async_config_entry_disabled_by_changed, - ) @callback def async_get_device_class_lookup(self, domain_device_classes: set) -> dict: @@ -363,40 +358,6 @@ class EntityRegistry: for entity in entities: self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE) - @callback - def async_config_entry_disabled_by_changed(self, event: Event) -> None: - """Handle a config entry being disabled or enabled. - - Disable entities in the registry that are associated to a config entry when - the config entry is disabled. - """ - config_entry = self.hass.config_entries.async_get_entry( - event.data["config_entry_id"] - ) - - # The config entry may be deleted already if the event handling is late - if not config_entry: - return - - if not config_entry.disabled_by: - entities = async_entries_for_config_entry( - self, event.data["config_entry_id"] - ) - for entity in entities: - if entity.disabled_by != DISABLED_CONFIG_ENTRY: - continue - self.async_update_entity(entity.entity_id, disabled_by=None) - return - - entities = async_entries_for_config_entry(self, event.data["config_entry_id"]) - for entity in entities: - if entity.disabled: - # Entity already disabled, do not overwrite - continue - self.async_update_entity( - entity.entity_id, disabled_by=DISABLED_CONFIG_ENTRY - ) - @callback def async_update_entity( self, @@ -443,7 +404,8 @@ class EntityRegistry: """Private facing update properties method.""" old = self.entities[entity_id] - changes = {} + new_values = {} # Dict with new key/value pairs + old_values = {} # Dict with old key/value pairs for attr_name, value in ( ("name", name), @@ -460,7 +422,8 @@ class EntityRegistry: ("original_icon", original_icon), ): if value is not UNDEFINED and value != getattr(old, attr_name): - changes[attr_name] = value + new_values[attr_name] = value + old_values[attr_name] = getattr(old, attr_name) if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id: if self.async_is_registered(new_entity_id): @@ -473,7 +436,8 @@ class EntityRegistry: raise ValueError("New entity ID should be same domain") self.entities.pop(entity_id) - entity_id = changes["entity_id"] = new_entity_id + entity_id = new_values["entity_id"] = new_entity_id + old_values["entity_id"] = old.entity_id if new_unique_id is not UNDEFINED: conflict_entity_id = self.async_get_entity_id( @@ -484,18 +448,19 @@ class EntityRegistry: f"Unique id '{new_unique_id}' is already in use by " f"'{conflict_entity_id}'" ) - changes["unique_id"] = new_unique_id + new_values["unique_id"] = new_unique_id + old_values["unique_id"] = old.unique_id - if not changes: + if not new_values: return old self._remove_index(old) - new = attr.evolve(old, **changes) + new = attr.evolve(old, **new_values) self._register_entry(new) self.async_schedule_save() - data = {"action": "update", "entity_id": entity_id, "changes": list(changes)} + data = {"action": "update", "entity_id": entity_id, "changes": old_values} if old.entity_id != entity_id: data["old_entity_id"] = old.entity_id @@ -670,6 +635,36 @@ def async_entries_for_config_entry( ] +@callback +def async_config_entry_disabled_by_changed( + registry: EntityRegistry, config_entry: "ConfigEntry" +) -> None: + """Handle a config entry being disabled or enabled. + + Disable entities in the registry that are associated with a config entry when + the config entry is disabled, enable entities in the registry that are associated + with a config entry when the config entry is enabled and the entities are marked + DISABLED_CONFIG_ENTRY. + """ + + entities = async_entries_for_config_entry(registry, config_entry.entry_id) + + if not config_entry.disabled_by: + for entity in entities: + if entity.disabled_by != DISABLED_CONFIG_ENTRY: + continue + registry.async_update_entity(entity.entity_id, disabled_by=None) + return + + for entity in entities: + if entity.disabled: + # Entity already disabled, do not overwrite + continue + registry.async_update_entity( + entity.entity_id, disabled_by=DISABLED_CONFIG_ENTRY + ) + + async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]: """Migrate the YAML config file to storage helper format.""" return { diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 86cdab82238..0a1a27efef5 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -313,7 +313,7 @@ async def test_updating_config_entry_id(hass, registry, update_events): assert update_events[0]["entity_id"] == entry.entity_id assert update_events[1]["action"] == "update" assert update_events[1]["entity_id"] == entry.entity_id - assert update_events[1]["changes"] == ["config_entry_id"] + assert update_events[1]["changes"] == {"config_entry_id": "mock-id-1"} async def test_removing_config_entry_id(hass, registry, update_events):