mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Index the entity registry (#37994)
This commit is contained in:
parent
41421b56a4
commit
890562e3ae
@ -124,6 +124,7 @@ class EntityRegistry:
|
|||||||
"""Initialize the registry."""
|
"""Initialize the registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.entities: Dict[str, RegistryEntry]
|
self.entities: Dict[str, RegistryEntry]
|
||||||
|
self._index: Dict[Tuple[str, str, str], str] = {}
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
self.hass.bus.async_listen(
|
self.hass.bus.async_listen(
|
||||||
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
|
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
|
||||||
@ -160,14 +161,7 @@ class EntityRegistry:
|
|||||||
self, domain: str, platform: str, unique_id: str
|
self, domain: str, platform: str, unique_id: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Check if an entity_id is currently registered."""
|
"""Check if an entity_id is currently registered."""
|
||||||
for entity in self.entities.values():
|
return self._index.get((domain, platform, unique_id))
|
||||||
if (
|
|
||||||
entity.domain == domain
|
|
||||||
and entity.platform == platform
|
|
||||||
and entity.unique_id == unique_id
|
|
||||||
):
|
|
||||||
return entity.entity_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_generate_entity_id(
|
def async_generate_entity_id(
|
||||||
@ -270,7 +264,7 @@ class EntityRegistry:
|
|||||||
original_name=original_name,
|
original_name=original_name,
|
||||||
original_icon=original_icon,
|
original_icon=original_icon,
|
||||||
)
|
)
|
||||||
self.entities[entity_id] = entity
|
self._register_entry(entity)
|
||||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
@ -283,7 +277,7 @@ class EntityRegistry:
|
|||||||
@callback
|
@callback
|
||||||
def async_remove(self, entity_id: str) -> None:
|
def async_remove(self, entity_id: str) -> None:
|
||||||
"""Remove an entity from registry."""
|
"""Remove an entity from registry."""
|
||||||
self.entities.pop(entity_id)
|
self._unregister_entry(self.entities[entity_id])
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
||||||
)
|
)
|
||||||
@ -380,27 +374,22 @@ class EntityRegistry:
|
|||||||
entity_id = changes["entity_id"] = new_entity_id
|
entity_id = changes["entity_id"] = new_entity_id
|
||||||
|
|
||||||
if new_unique_id is not _UNDEF:
|
if new_unique_id is not _UNDEF:
|
||||||
conflict = next(
|
conflict_entity_id = self.async_get_entity_id(
|
||||||
(
|
old.domain, old.platform, new_unique_id
|
||||||
entity
|
|
||||||
for entity in self.entities.values()
|
|
||||||
if entity.unique_id == new_unique_id
|
|
||||||
and entity.domain == old.domain
|
|
||||||
and entity.platform == old.platform
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
if conflict:
|
if conflict_entity_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unique id '{new_unique_id}' is already in use by "
|
f"Unique id '{new_unique_id}' is already in use by "
|
||||||
f"'{conflict.entity_id}'"
|
f"'{conflict_entity_id}'"
|
||||||
)
|
)
|
||||||
changes["unique_id"] = new_unique_id
|
changes["unique_id"] = new_unique_id
|
||||||
|
|
||||||
if not changes:
|
if not changes:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
self._remove_index(old)
|
||||||
|
new = attr.evolve(old, **changes)
|
||||||
|
self._register_entry(new)
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
@ -451,6 +440,7 @@ class EntityRegistry:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
|
self._rebuild_index()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_save(self) -> None:
|
def async_schedule_save(self) -> None:
|
||||||
@ -494,6 +484,25 @@ class EntityRegistry:
|
|||||||
]:
|
]:
|
||||||
self.async_remove(entity_id)
|
self.async_remove(entity_id)
|
||||||
|
|
||||||
|
def _register_entry(self, entry: RegistryEntry) -> None:
|
||||||
|
self.entities[entry.entity_id] = entry
|
||||||
|
self._add_index(entry)
|
||||||
|
|
||||||
|
def _add_index(self, entry: RegistryEntry) -> None:
|
||||||
|
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
||||||
|
|
||||||
|
def _unregister_entry(self, entry: RegistryEntry) -> None:
|
||||||
|
self._remove_index(entry)
|
||||||
|
del self.entities[entry.entity_id]
|
||||||
|
|
||||||
|
def _remove_index(self, entry: RegistryEntry) -> None:
|
||||||
|
del self._index[(entry.domain, entry.platform, entry.unique_id)]
|
||||||
|
|
||||||
|
def _rebuild_index(self) -> None:
|
||||||
|
self._index = {}
|
||||||
|
for entry in self.entities.values():
|
||||||
|
self._add_index(entry)
|
||||||
|
|
||||||
|
|
||||||
@singleton(DATA_REGISTRY)
|
@singleton(DATA_REGISTRY)
|
||||||
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
|
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
|
||||||
|
@ -351,6 +351,7 @@ def mock_registry(hass, mock_entries=None):
|
|||||||
"""Mock the Entity Registry."""
|
"""Mock the Entity Registry."""
|
||||||
registry = entity_registry.EntityRegistry(hass)
|
registry = entity_registry.EntityRegistry(hass)
|
||||||
registry.entities = mock_entries or OrderedDict()
|
registry.entities = mock_entries or OrderedDict()
|
||||||
|
registry._rebuild_index()
|
||||||
|
|
||||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||||
return registry
|
return registry
|
||||||
|
@ -428,6 +428,8 @@ async def test_update_entity_unique_id(registry):
|
|||||||
entry = registry.async_get_or_create(
|
entry = registry.async_get_or_create(
|
||||||
"light", "hue", "5678", config_entry=mock_config
|
"light", "hue", "5678", config_entry=mock_config
|
||||||
)
|
)
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||||
|
|
||||||
new_unique_id = "1234"
|
new_unique_id = "1234"
|
||||||
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
|
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
|
||||||
updated_entry = registry.async_update_entity(
|
updated_entry = registry.async_update_entity(
|
||||||
@ -437,6 +439,9 @@ async def test_update_entity_unique_id(registry):
|
|||||||
assert updated_entry.unique_id == new_unique_id
|
assert updated_entry.unique_id == new_unique_id
|
||||||
assert mock_schedule_save.call_count == 1
|
assert mock_schedule_save.call_count == 1
|
||||||
|
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") is None
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "1234") == entry.entity_id
|
||||||
|
|
||||||
|
|
||||||
async def test_update_entity_unique_id_conflict(registry):
|
async def test_update_entity_unique_id_conflict(registry):
|
||||||
"""Test migration raises when unique_id already in use."""
|
"""Test migration raises when unique_id already in use."""
|
||||||
@ -452,6 +457,8 @@ async def test_update_entity_unique_id_conflict(registry):
|
|||||||
) as mock_schedule_save, pytest.raises(ValueError):
|
) as mock_schedule_save, pytest.raises(ValueError):
|
||||||
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
|
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
|
||||||
assert mock_schedule_save.call_count == 0
|
assert mock_schedule_save.call_count == 0
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||||
|
|
||||||
|
|
||||||
async def test_update_entity(registry):
|
async def test_update_entity(registry):
|
||||||
@ -473,6 +480,10 @@ async def test_update_entity(registry):
|
|||||||
assert getattr(updated_entry, attr_name) == new_value
|
assert getattr(updated_entry, attr_name) == new_value
|
||||||
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
|
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
registry.async_get_entity_id("light", "hue", "5678")
|
||||||
|
== updated_entry.entity_id
|
||||||
|
)
|
||||||
entry = updated_entry
|
entry = updated_entry
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user