Clean up device registry if entity registry updates (#35106)

This commit is contained in:
Paulus Schoutsen 2020-05-05 10:53:46 -07:00 committed by GitHub
parent 2ac29cf1a4
commit 4ae31bc938
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 157 additions and 18 deletions

View File

@ -21,7 +21,7 @@ class Debouncer:
"""Initialize debounce. """Initialize debounce.
immediate: indicate if the function needs to be called right away and immediate: indicate if the function needs to be called right away and
wait 0.3s until executing next invocation. wait <cooldown> until executing next invocation.
function: optional and can be instantiated later. function: optional and can be instantiated later.
""" """
self.hass = hass self.hass = hass

View File

@ -1,16 +1,21 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
import uuid import uuid
import attr import attr
from homeassistant.core import callback from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import Event, callback
from .debounce import Debouncer
from .singleton import singleton from .singleton import singleton
from .typing import HomeAssistantType from .typing import HomeAssistantType
if TYPE_CHECKING:
from . import entity_registry
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -21,6 +26,7 @@ EVENT_DEVICE_REGISTRY_UPDATED = "device_registry_updated"
STORAGE_KEY = "core.device_registry" STORAGE_KEY = "core.device_registry"
STORAGE_VERSION = 1 STORAGE_VERSION = 1
SAVE_DELAY = 10 SAVE_DELAY = 10
CLEANUP_DELAY = 10
CONNECTION_NETWORK_MAC = "mac" CONNECTION_NETWORK_MAC = "mac"
CONNECTION_UPNP = "upnp" CONNECTION_UPNP = "upnp"
@ -285,6 +291,8 @@ class DeviceRegistry:
async def async_load(self): async def async_load(self):
"""Load the device registry.""" """Load the device registry."""
async_setup_cleanup(self.hass, self)
data = await self._store.async_load() data = await self._store.async_load()
devices = OrderedDict() devices = OrderedDict()
@ -347,16 +355,8 @@ class DeviceRegistry:
@callback @callback
def async_clear_config_entry(self, config_entry_id: str) -> None: def async_clear_config_entry(self, config_entry_id: str) -> None:
"""Clear config entry from registry entries.""" """Clear config entry from registry entries."""
remove = [] for device in list(self.devices.values()):
for dev_id, device in self.devices.items(): self._async_update_device(device.id, remove_config_entry_id=config_entry_id)
if device.config_entries == {config_entry_id}:
remove.append(dev_id)
else:
self._async_update_device(
dev_id, remove_config_entry_id=config_entry_id
)
for dev_id in remove:
self.async_remove_device(dev_id)
@callback @callback
def async_clear_area_id(self, area_id: str) -> None: def async_clear_area_id(self, area_id: str) -> None:
@ -390,3 +390,69 @@ def async_entries_for_config_entry(
for device in registry.devices.values() for device in registry.devices.values()
if config_entry_id in device.config_entries if config_entry_id in device.config_entries
] ]
@callback
def async_cleanup(
hass: HomeAssistantType,
dev_reg: DeviceRegistry,
ent_reg: "entity_registry.EntityRegistry",
) -> None:
"""Clean up device registry."""
# Find all devices that are no longer referenced in the entity registry.
referenced = {entry.device_id for entry in ent_reg.entities.values()}
orphan = set(dev_reg.devices) - referenced
for dev_id in orphan:
dev_reg.async_remove_device(dev_id)
# Find all referenced config entries that no longer exist
# This shouldn't happen but have not been able to track down the bug :(
config_entry_ids = {entry.entry_id for entry in hass.config_entries.async_entries()}
for device in list(dev_reg.devices.values()):
for config_entry_id in device.config_entries:
if config_entry_id not in config_entry_ids:
dev_reg.async_update_device(
device.id, remove_config_entry_id=config_entry_id
)
@callback
def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> None:
"""Clean up device registry when entities removed."""
from . import entity_registry # pylint: disable=import-outside-toplevel
async def cleanup():
"""Cleanup."""
ent_reg = await entity_registry.async_get_registry(hass)
async_cleanup(hass, dev_reg, ent_reg)
debounced_cleanup = Debouncer(
hass, _LOGGER, cooldown=CLEANUP_DELAY, immediate=False, function=cleanup
)
async def entity_registry_changed(event: Event) -> None:
"""Handle entity updated or removed."""
if (
event.data["action"] == "update"
and "device_id" not in event.data["changes"]
) or event.data["action"] == "create":
return
await debounced_cleanup.async_call()
if hass.is_running:
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed
)
return
async def startup_clean(event: Event) -> None:
"""Clean up on startup."""
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed
)
await debounced_cleanup.async_call()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)

View File

@ -3,12 +3,12 @@ import asyncio
import pytest import pytest
from homeassistant.core import callback from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.helpers import device_registry from homeassistant.core import CoreState, callback
from homeassistant.helpers import device_registry, entity_registry
import tests.async_mock
from tests.async_mock import patch from tests.async_mock import patch
from tests.common import flush_store, mock_device_registry from tests.common import MockConfigEntry, flush_store, mock_device_registry
@pytest.fixture @pytest.fixture
@ -483,7 +483,7 @@ async def test_update_remove_config_entries(hass, registry, update_events):
async def test_loading_race_condition(hass): async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred .""" """Test only one storage load called when concurrent loading occurred ."""
with tests.async_mock.patch( with patch(
"homeassistant.helpers.device_registry.DeviceRegistry.async_load" "homeassistant.helpers.device_registry.DeviceRegistry.async_load"
) as mock_load: ) as mock_load:
results = await asyncio.gather( results = await asyncio.gather(
@ -511,3 +511,76 @@ async def test_update_sw_version(registry):
assert mock_save.call_count == 1 assert mock_save.call_count == 1
assert updated_entry != entry assert updated_entry != entry
assert updated_entry.sw_version == sw_version assert updated_entry.sw_version == sw_version
async def test_cleanup_device_registry(hass, registry):
"""Test cleanup works."""
config_entry = MockConfigEntry(domain="hue")
config_entry.add_to_hass(hass)
d1 = registry.async_get_or_create(
identifiers={("hue", "d1")}, config_entry_id=config_entry.entry_id
)
registry.async_get_or_create(
identifiers={("hue", "d2")}, config_entry_id=config_entry.entry_id
)
d3 = registry.async_get_or_create(
identifiers={("hue", "d3")}, config_entry_id=config_entry.entry_id
)
registry.async_get_or_create(
identifiers={("something", "d4")}, config_entry_id="non_existing"
)
ent_reg = await entity_registry.async_get_registry(hass)
ent_reg.async_get_or_create("light", "hue", "e1", device_id=d1.id)
ent_reg.async_get_or_create("light", "hue", "e2", device_id=d1.id)
ent_reg.async_get_or_create("light", "hue", "e3", device_id=d3.id)
device_registry.async_cleanup(hass, registry, ent_reg)
assert registry.async_get_device({("hue", "d1")}, set()) is not None
assert registry.async_get_device({("hue", "d2")}, set()) is None
assert registry.async_get_device({("hue", "d3")}, set()) is not None
assert registry.async_get_device({("something", "d4")}, set()) is None
async def test_cleanup_startup(hass):
"""Test we run a cleanup on startup."""
hass.state = CoreState.not_running
await device_registry.async_get_registry(hass)
with patch(
"homeassistant.helpers.device_registry.Debouncer.async_call"
) as mock_call:
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert len(mock_call.mock_calls) == 1
async def test_cleanup_entity_registry_change(hass):
"""Test we run a cleanup when entity registry changes."""
await device_registry.async_get_registry(hass)
ent_reg = await entity_registry.async_get_registry(hass)
with patch(
"homeassistant.helpers.device_registry.Debouncer.async_call"
) as mock_call:
entity = ent_reg.async_get_or_create("light", "hue", "e1")
await hass.async_block_till_done()
assert len(mock_call.mock_calls) == 0
# Normal update does not trigger
ent_reg.async_update_entity(entity.entity_id, name="updated")
await hass.async_block_till_done()
assert len(mock_call.mock_calls) == 0
# Device ID update triggers
ent_reg.async_get_or_create("light", "hue", "e1", device_id="bla")
await hass.async_block_till_done()
assert len(mock_call.mock_calls) == 1
# Removal also triggers
ent_reg.async_remove(entity.entity_id)
await hass.async_block_till_done()
assert len(mock_call.mock_calls) == 2