From e86fec310b3fcff15d7a27de84ef2ee761c2b998 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 3 Apr 2024 10:24:44 -1000 Subject: [PATCH] Improve performance of extracting entities by label (#114720) --- homeassistant/helpers/entity_registry.py | 21 +++++++++-- homeassistant/helpers/service.py | 39 ++++++++++---------- tests/components/device_tracker/test_init.py | 13 +++++-- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index e19c4290a1d..27e73320841 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): """Container for entity registry items, maps entity_id -> entry. - Maintains four additional indexes: + Maintains six additional indexes: - id -> entry - (domain, platform, unique_id) -> entity_id - - config_entry_id -> list[key] - - device_id -> list[key] + - config_entry_id -> dict[key, True] + - device_id -> dict[key, True] + - area_id -> dict[key, True] + - label -> dict[key, True] """ def __init__(self) -> None: @@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {} self._device_id_index: dict[str, dict[str, Literal[True]]] = {} self._area_id_index: dict[str, dict[str, Literal[True]]] = {} + self._labels_index: dict[str, dict[str, Literal[True]]] = {} def _index_entry(self, key: str, entry: RegistryEntry) -> None: """Index an entry.""" @@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): self._device_id_index.setdefault(device_id, {})[key] = True if (area_id := entry.area_id) is not None: self._area_id_index.setdefault(area_id, {})[key] = True + for label in entry.labels: + self._labels_index.setdefault(label, {})[key] = True def _unindex_entry( self, key: str, replacement_entry: RegistryEntry | None = None @@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): self._unindex_entry_value(key, device_id, self._device_id_index) if area_id := entry.area_id: self._unindex_entry_value(key, area_id, self._area_id_index) + if labels := entry.labels: + for label in labels: + self._unindex_entry_value(key, label, self._labels_index) def get_device_ids(self) -> KeysView[str]: """Return device ids.""" @@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): data = self.data return [data[key] for key in self._area_id_index.get(area_id, ())] + def get_entries_for_label(self, label: str) -> list[RegistryEntry]: + """Get entries for label.""" + data = self.data + return [data[key] for key in self._labels_index.get(label, ())] + class EntityRegistry(BaseRegistry): """Class to hold a registry of entities.""" @@ -1317,7 +1330,7 @@ def async_entries_for_label( registry: EntityRegistry, label_id: str ) -> list[RegistryEntry]: """Return entries that match a label.""" - return [entry for entry in registry.entities.values() if label_id in entry.labels] + return registry.entities.get_entries_for_label(label_id) @callback diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index da27df9d139..43942458233 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -503,15 +503,15 @@ def async_extract_referenced_entity_ids( # noqa: C901 ): return selected - ent_reg = entity_registry.async_get(hass) + entities = entity_registry.async_get(hass).entities dev_reg = device_registry.async_get(hass) area_reg = area_registry.async_get(hass) - floor_reg = floor_registry.async_get(hass) - label_reg = label_registry.async_get(hass) - for floor_id in selector.floor_ids: - if floor_id not in floor_reg.floors: - selected.missing_floors.add(floor_id) + if selector.floor_ids: + floor_reg = floor_registry.async_get(hass) + for floor_id in selector.floor_ids: + if floor_id not in floor_reg.floors: + selected.missing_floors.add(floor_id) for area_id in selector.area_ids: if area_id not in area_reg.areas: @@ -521,12 +521,20 @@ def async_extract_referenced_entity_ids( # noqa: C901 if device_id not in dev_reg.devices: selected.missing_devices.add(device_id) - for label_id in selector.label_ids: - if label_id not in label_reg.labels: - selected.missing_labels.add(label_id) - - # Find areas, devices & entities for targeted labels if selector.label_ids: + label_reg = label_registry.async_get(hass) + for label_id in selector.label_ids: + if label_id not in label_reg.labels: + selected.missing_labels.add(label_id) + + for entity_entry in entities.get_entries_for_label(label_id): + if ( + entity_entry.entity_category is None + and entity_entry.hidden_by is None + ): + selected.indirectly_referenced.add(entity_entry.entity_id) + + # Find areas, devices & entities for targeted labels for area_entry in area_reg.areas.values(): if area_entry.labels.intersection(selector.label_ids): selected.referenced_areas.add(area_entry.id) @@ -535,14 +543,6 @@ def async_extract_referenced_entity_ids( # noqa: C901 if device_entry.labels.intersection(selector.label_ids): selected.referenced_devices.add(device_entry.id) - for entity_entry in ent_reg.entities.values(): - if ( - entity_entry.entity_category is None - and entity_entry.hidden_by is None - and entity_entry.labels.intersection(selector.label_ids) - ): - selected.indirectly_referenced.add(entity_entry.entity_id) - # Find areas for targeted floors if selector.floor_ids: for area_entry in area_reg.areas.values(): @@ -561,7 +561,6 @@ def async_extract_referenced_entity_ids( # noqa: C901 if not selected.referenced_areas and not selected.referenced_devices: return selected - entities = ent_reg.entities # Add indirectly referenced by area selected.indirectly_referenced.update( entry.entity_id diff --git a/tests/components/device_tracker/test_init.py b/tests/components/device_tracker/test_init.py index cc6cf5c1c1e..6999a99f7ba 100644 --- a/tests/components/device_tracker/test_init.py +++ b/tests/components/device_tracker/test_init.py @@ -5,7 +5,7 @@ import json import logging import os from types import ModuleType -from unittest.mock import Mock, call, patch +from unittest.mock import call, patch import pytest @@ -25,6 +25,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import discovery +from homeassistant.helpers.entity_registry import RegistryEntry from homeassistant.helpers.json import JSONEncoder from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -396,10 +397,16 @@ async def test_see_service_guard_config_entry( mock_device_tracker_conf: list[legacy.Device], ) -> None: """Test the guard if the device is registered in the entity registry.""" - mock_entry = Mock() dev_id = "test" entity_id = f"{const.DOMAIN}.{dev_id}" - mock_registry(hass, {entity_id: mock_entry}) + mock_registry( + hass, + { + entity_id: RegistryEntry( + entity_id=entity_id, unique_id=1, platform=const.DOMAIN + ) + }, + ) devices = mock_device_tracker_conf assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM) await hass.async_block_till_done()