diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index b9b9af6f5c1..c4c445b2be9 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -124,6 +124,7 @@ class EntityRegistry: """Initialize the registry.""" self.hass = hass self.entities: Dict[str, RegistryEntry] + self._index: Dict[Tuple[str, str, str], str] = {} self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self.hass.bus.async_listen( EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed @@ -160,14 +161,7 @@ class EntityRegistry: self, domain: str, platform: str, unique_id: str ) -> Optional[str]: """Check if an entity_id is currently registered.""" - for entity in self.entities.values(): - if ( - entity.domain == domain - and entity.platform == platform - and entity.unique_id == unique_id - ): - return entity.entity_id - return None + return self._index.get((domain, platform, unique_id)) @callback def async_generate_entity_id( @@ -270,7 +264,7 @@ class EntityRegistry: original_name=original_name, original_icon=original_icon, ) - self.entities[entity_id] = entity + self._register_entry(entity) _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) self.async_schedule_save() @@ -283,7 +277,7 @@ class EntityRegistry: @callback def async_remove(self, entity_id: str) -> None: """Remove an entity from registry.""" - self.entities.pop(entity_id) + self._unregister_entry(self.entities[entity_id]) self.hass.bus.async_fire( EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id} ) @@ -380,27 +374,22 @@ class EntityRegistry: entity_id = changes["entity_id"] = new_entity_id if new_unique_id is not _UNDEF: - conflict = next( - ( - entity - for entity in self.entities.values() - if entity.unique_id == new_unique_id - and entity.domain == old.domain - and entity.platform == old.platform - ), - None, + conflict_entity_id = self.async_get_entity_id( + old.domain, old.platform, new_unique_id ) - if conflict: + if conflict_entity_id: raise ValueError( f"Unique id '{new_unique_id}' is already in use by " - f"'{conflict.entity_id}'" + f"'{conflict_entity_id}'" ) changes["unique_id"] = new_unique_id if not changes: return old - new = self.entities[entity_id] = attr.evolve(old, **changes) + self._remove_index(old) + new = attr.evolve(old, **changes) + self._register_entry(new) self.async_schedule_save() @@ -451,6 +440,7 @@ class EntityRegistry: ) self.entities = entities + self._rebuild_index() @callback def async_schedule_save(self) -> None: @@ -494,6 +484,25 @@ class EntityRegistry: ]: self.async_remove(entity_id) + def _register_entry(self, entry: RegistryEntry) -> None: + self.entities[entry.entity_id] = entry + self._add_index(entry) + + def _add_index(self, entry: RegistryEntry) -> None: + self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id + + def _unregister_entry(self, entry: RegistryEntry) -> None: + self._remove_index(entry) + del self.entities[entry.entity_id] + + def _remove_index(self, entry: RegistryEntry) -> None: + del self._index[(entry.domain, entry.platform, entry.unique_id)] + + def _rebuild_index(self) -> None: + self._index = {} + for entry in self.entities.values(): + self._add_index(entry) + @singleton(DATA_REGISTRY) async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: diff --git a/tests/common.py b/tests/common.py index 1436b0f5a8a..db060bc6b91 100644 --- a/tests/common.py +++ b/tests/common.py @@ -351,6 +351,7 @@ def mock_registry(hass, mock_entries=None): """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) registry.entities = mock_entries or OrderedDict() + registry._rebuild_index() hass.data[entity_registry.DATA_REGISTRY] = registry return registry diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 285f43b6d4d..97d8af7d0ee 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -428,6 +428,8 @@ async def test_update_entity_unique_id(registry): entry = registry.async_get_or_create( "light", "hue", "5678", config_entry=mock_config ) + assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + new_unique_id = "1234" with patch.object(registry, "async_schedule_save") as mock_schedule_save: updated_entry = registry.async_update_entity( @@ -437,6 +439,9 @@ async def test_update_entity_unique_id(registry): assert updated_entry.unique_id == new_unique_id assert mock_schedule_save.call_count == 1 + assert registry.async_get_entity_id("light", "hue", "5678") is None + assert registry.async_get_entity_id("light", "hue", "1234") == entry.entity_id + async def test_update_entity_unique_id_conflict(registry): """Test migration raises when unique_id already in use.""" @@ -452,6 +457,8 @@ async def test_update_entity_unique_id_conflict(registry): ) as mock_schedule_save, pytest.raises(ValueError): registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id) assert mock_schedule_save.call_count == 0 + assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id async def test_update_entity(registry): @@ -473,6 +480,10 @@ async def test_update_entity(registry): assert getattr(updated_entry, attr_name) == new_value assert getattr(updated_entry, attr_name) != getattr(entry, attr_name) + assert ( + registry.async_get_entity_id("light", "hue", "5678") + == updated_entry.entity_id + ) entry = updated_entry