Index the entity registry (#37994)

This commit is contained in:
J. Nick Koston 2020-07-19 19:52:41 -10:00 committed by GitHub
parent 41421b56a4
commit 890562e3ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 22 deletions

View File

@ -124,6 +124,7 @@ class EntityRegistry:
"""Initialize the registry.""" """Initialize the registry."""
self.hass = hass self.hass = hass
self.entities: Dict[str, RegistryEntry] self.entities: Dict[str, RegistryEntry]
self._index: Dict[Tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen( self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
@ -160,14 +161,7 @@ class EntityRegistry:
self, domain: str, platform: str, unique_id: str self, domain: str, platform: str, unique_id: str
) -> Optional[str]: ) -> Optional[str]:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
for entity in self.entities.values(): return self._index.get((domain, platform, unique_id))
if (
entity.domain == domain
and entity.platform == platform
and entity.unique_id == unique_id
):
return entity.entity_id
return None
@callback @callback
def async_generate_entity_id( def async_generate_entity_id(
@ -270,7 +264,7 @@ class EntityRegistry:
original_name=original_name, original_name=original_name,
original_icon=original_icon, original_icon=original_icon,
) )
self.entities[entity_id] = entity self._register_entry(entity)
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save() self.async_schedule_save()
@ -283,7 +277,7 @@ class EntityRegistry:
@callback @callback
def async_remove(self, entity_id: str) -> None: def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry.""" """Remove an entity from registry."""
self.entities.pop(entity_id) self._unregister_entry(self.entities[entity_id])
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id} EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
) )
@ -380,27 +374,22 @@ class EntityRegistry:
entity_id = changes["entity_id"] = new_entity_id entity_id = changes["entity_id"] = new_entity_id
if new_unique_id is not _UNDEF: if new_unique_id is not _UNDEF:
conflict = next( conflict_entity_id = self.async_get_entity_id(
( old.domain, old.platform, new_unique_id
entity
for entity in self.entities.values()
if entity.unique_id == new_unique_id
and entity.domain == old.domain
and entity.platform == old.platform
),
None,
) )
if conflict: if conflict_entity_id:
raise ValueError( raise ValueError(
f"Unique id '{new_unique_id}' is already in use by " f"Unique id '{new_unique_id}' is already in use by "
f"'{conflict.entity_id}'" f"'{conflict_entity_id}'"
) )
changes["unique_id"] = new_unique_id changes["unique_id"] = new_unique_id
if not changes: if not changes:
return old return old
new = self.entities[entity_id] = attr.evolve(old, **changes) self._remove_index(old)
new = attr.evolve(old, **changes)
self._register_entry(new)
self.async_schedule_save() self.async_schedule_save()
@ -451,6 +440,7 @@ class EntityRegistry:
) )
self.entities = entities self.entities = entities
self._rebuild_index()
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -494,6 +484,25 @@ class EntityRegistry:
]: ]:
self.async_remove(entity_id) self.async_remove(entity_id)
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry
self._add_index(entry)
def _add_index(self, entry: RegistryEntry) -> None:
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def _unregister_entry(self, entry: RegistryEntry) -> None:
self._remove_index(entry)
del self.entities[entry.entity_id]
def _remove_index(self, entry: RegistryEntry) -> None:
del self._index[(entry.domain, entry.platform, entry.unique_id)]
def _rebuild_index(self) -> None:
self._index = {}
for entry in self.entities.values():
self._add_index(entry)
@singleton(DATA_REGISTRY) @singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:

View File

@ -351,6 +351,7 @@ def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry.""" """Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass) registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or OrderedDict() registry.entities = mock_entries or OrderedDict()
registry._rebuild_index()
hass.data[entity_registry.DATA_REGISTRY] = registry hass.data[entity_registry.DATA_REGISTRY] = registry
return registry return registry

View File

@ -428,6 +428,8 @@ async def test_update_entity_unique_id(registry):
entry = registry.async_get_or_create( entry = registry.async_get_or_create(
"light", "hue", "5678", config_entry=mock_config "light", "hue", "5678", config_entry=mock_config
) )
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
new_unique_id = "1234" new_unique_id = "1234"
with patch.object(registry, "async_schedule_save") as mock_schedule_save: with patch.object(registry, "async_schedule_save") as mock_schedule_save:
updated_entry = registry.async_update_entity( updated_entry = registry.async_update_entity(
@ -437,6 +439,9 @@ async def test_update_entity_unique_id(registry):
assert updated_entry.unique_id == new_unique_id assert updated_entry.unique_id == new_unique_id
assert mock_schedule_save.call_count == 1 assert mock_schedule_save.call_count == 1
assert registry.async_get_entity_id("light", "hue", "5678") is None
assert registry.async_get_entity_id("light", "hue", "1234") == entry.entity_id
async def test_update_entity_unique_id_conflict(registry): async def test_update_entity_unique_id_conflict(registry):
"""Test migration raises when unique_id already in use.""" """Test migration raises when unique_id already in use."""
@ -452,6 +457,8 @@ async def test_update_entity_unique_id_conflict(registry):
) as mock_schedule_save, pytest.raises(ValueError): ) as mock_schedule_save, pytest.raises(ValueError):
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id) registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
assert mock_schedule_save.call_count == 0 assert mock_schedule_save.call_count == 0
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
async def test_update_entity(registry): async def test_update_entity(registry):
@ -473,6 +480,10 @@ async def test_update_entity(registry):
assert getattr(updated_entry, attr_name) == new_value assert getattr(updated_entry, attr_name) == new_value
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name) assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
assert (
registry.async_get_entity_id("light", "hue", "5678")
== updated_entry.entity_id
)
entry = updated_entry entry = updated_entry