diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 219f4ff1709..0dae0359d8a 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1462,6 +1462,11 @@ class ConfigEntries: """Return entry with matching entry_id.""" return self._entries.data.get(entry_id) + @callback + def async_entry_ids(self) -> list[str]: + """Return entry ids.""" + return list(self._entries.data) + @callback def async_entries( self, diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 826a4cc200e..5cbc40e209f 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -1090,7 +1090,7 @@ def async_cleanup( ) -> None: """Clean up device registry.""" # Find all devices that are referenced by a config_entry. - config_entry_ids = {entry.entry_id for entry in hass.config_entries.async_entries()} + config_entry_ids = set(hass.config_entries.async_entry_ids()) references_config_entries = { device.id for device in dev_reg.devices.values() @@ -1099,9 +1099,13 @@ def async_cleanup( } # Find all devices that are referenced in the entity registry. - references_entities = {entry.device_id for entry in ent_reg.entities.values()} + device_ids_referenced_by_entities = set(ent_reg.entities.get_device_ids()) - orphan = set(dev_reg.devices) - references_entities - references_config_entries + orphan = ( + set(dev_reg.devices) + - device_ids_referenced_by_entities + - references_config_entries + ) for dev_id in orphan: dev_reg.async_remove_device(dev_id) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 50ecbc1fb59..51afe4fc740 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -10,7 +10,7 @@ timer. from __future__ import annotations from collections import UserDict -from collections.abc import Callable, Iterable, Mapping, ValuesView +from collections.abc import Callable, Iterable, KeysView, Mapping, ValuesView from datetime import datetime, timedelta from enum import StrEnum import logging @@ -511,6 +511,14 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self._unindex_entry(key) super().__delitem__(key) + def get_entity_ids(self) -> ValuesView[str]: + """Return entity ids.""" + return self._index.values() + + def get_device_ids(self) -> KeysView[str]: + """Return device ids.""" + return self._device_id_index.keys() + def get_entity_id(self, key: tuple[str, str, str]) -> str | None: """Get entity_id from (domain, platform, unique_id).""" return self._index.get(key) @@ -612,6 +620,16 @@ class EntityRegistry: """Check if an entity_id is currently registered.""" return self.entities.get_entity_id((domain, platform, unique_id)) + @callback + def async_entity_ids(self) -> list[str]: + """Return entity ids.""" + return list(self.entities.get_entity_ids()) + + @callback + def async_device_ids(self) -> list[str]: + """Return known device ids.""" + return list(self.entities.get_device_ids()) + def _entity_id_available( self, entity_id: str, known_object_ids: Iterable[str] | None ) -> bool: diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 0178e4fcd11..1b0fbe51147 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -90,6 +90,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None: unit_of_measurement="initial-unit_of_measurement", ) + assert set(entity_registry.async_device_ids()) == {"mock-dev-id"} + assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"} + assert orig_entry == er.RegistryEntry( "light.hue_5678", "5678", @@ -159,6 +162,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None: unit_of_measurement="updated-unit_of_measurement", ) + assert set(entity_registry.async_device_ids()) == {"new-mock-dev-id"} + assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"} + new_entry = entity_registry.async_get_or_create( "light", "hue", @@ -203,6 +209,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None: unit_of_measurement=None, ) + assert set(entity_registry.async_device_ids()) == set() + assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"} + def test_get_or_create_suggested_object_id_conflict_register( entity_registry: er.EntityRegistry, @@ -446,6 +455,8 @@ def test_async_get_entity_id(entity_registry: er.EntityRegistry) -> None: ) assert entity_registry.async_get_entity_id("light", "hue", "123") is None + assert set(entity_registry.async_entity_ids()) == {"light.hue_1234"} + async def test_updating_config_entry_id( hass: HomeAssistant, entity_registry: er.EntityRegistry @@ -1469,6 +1480,7 @@ def test_entity_registry_items() -> None: entities = er.EntityRegistryItems() assert entities.get_entity_id(("a", "b", "c")) is None assert entities.get_entry("abc") is None + assert set(entities.get_entity_ids()) == set() entry1 = er.RegistryEntry("test.entity1", "1234", "hue") entry2 = er.RegistryEntry("test.entity2", "2345", "hue") @@ -1482,6 +1494,7 @@ def test_entity_registry_items() -> None: assert entities.get_entry(entry1.id) is entry1 assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id assert entities.get_entry(entry2.id) is entry2 + assert set(entities.get_entity_ids()) == {"test.entity2", "test.entity1"} entities.pop("test.entity1") del entities["test.entity2"] @@ -1491,6 +1504,8 @@ def test_entity_registry_items() -> None: assert entities.get_entity_id(("test", "hue", "2345")) is None assert entities.get_entry(entry2.id) is None + assert set(entities.get_entity_ids()) == set() + async def test_disabled_by_str_not_allowed( hass: HomeAssistant, entity_registry: er.EntityRegistry diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 247a34c078b..d6c5d8bdc5c 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -378,7 +378,7 @@ async def test_remove_entry( MockConfigEntry(domain="test_other", entry_id="test3").add_to_manager(manager) # Check all config entries exist - assert [item.entry_id for item in manager.async_entries()] == [ + assert manager.async_entry_ids() == [ "test1", "test2", "test3", @@ -408,7 +408,7 @@ async def test_remove_entry( assert mock_remove_entry.call_count == 1 # Check that config entry was removed. - assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] + assert manager.async_entry_ids() == ["test1", "test3"] # Check that entity state has been removed assert hass.states.get("light.test_entity") is None @@ -469,7 +469,7 @@ async def test_remove_entry_handles_callback_error( entry = MockConfigEntry(domain="test", entry_id="test1") entry.add_to_manager(manager) # Check all config entries exist - assert [item.entry_id for item in manager.async_entries()] == ["test1"] + assert manager.async_entry_ids() == ["test1"] # Setup entry await entry.async_setup(hass) await hass.async_block_till_done() @@ -482,7 +482,7 @@ async def test_remove_entry_handles_callback_error( # Check the remove callback was invoked. assert mock_remove_entry.call_count == 1 # Check that config entry was removed. - assert [item.entry_id for item in manager.async_entries()] == [] + assert manager.async_entry_ids() == [] async def test_remove_entry_raises( @@ -502,7 +502,7 @@ async def test_remove_entry_raises( ).add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) - assert [item.entry_id for item in manager.async_entries()] == [ + assert manager.async_entry_ids() == [ "test1", "test2", "test3", @@ -511,7 +511,7 @@ async def test_remove_entry_raises( result = await manager.async_remove("test2") assert result == {"require_restart": True} - assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] + assert manager.async_entry_ids() == ["test1", "test3"] async def test_remove_entry_if_not_loaded( @@ -526,7 +526,7 @@ async def test_remove_entry_if_not_loaded( MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) - assert [item.entry_id for item in manager.async_entries()] == [ + assert manager.async_entry_ids() == [ "test1", "test2", "test3", @@ -535,7 +535,7 @@ async def test_remove_entry_if_not_loaded( result = await manager.async_remove("test2") assert result == {"require_restart": False} - assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] + assert manager.async_entry_ids() == ["test1", "test3"] assert len(mock_unload_entry.mock_calls) == 0 @@ -550,7 +550,7 @@ async def test_remove_entry_if_integration_deleted( MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) - assert [item.entry_id for item in manager.async_entries()] == [ + assert manager.async_entry_ids() == [ "test1", "test2", "test3", @@ -559,7 +559,7 @@ async def test_remove_entry_if_integration_deleted( result = await manager.async_remove("test2") assert result == {"require_restart": False} - assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] + assert manager.async_entry_ids() == ["test1", "test3"] assert len(mock_unload_entry.mock_calls) == 0