From 9fef1938b425d06b47443bb4521a0cb3a7da6c28 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 4 Feb 2024 16:11:56 -0600 Subject: [PATCH] Make get_entries_for_device_id skip disabled devices by default (#109645) --- .../components/device_automation/__init__.py | 3 +-- homeassistant/helpers/entity_registry.py | 23 +++++++++++-------- tests/helpers/test_entity_registry.py | 7 ++++++ 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 979c82acfe2..2bf87343c72 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -249,8 +249,7 @@ async def async_get_device_automations( for device_id in match_device_ids: for entry in entity_registry.entities.get_entries_for_device_id(device_id): - if not entry.disabled_by: - device_entities_domains.setdefault(device_id, set()).add(entry.domain) + device_entities_domains.setdefault(device_id, set()).add(entry.domain) for device_id in match_device_ids: combined_results[device_id] = [] diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 8ae9256754d..3f2e8a94b7c 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -498,17 +498,24 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): """Get entry from id.""" return self._entry_ids.get(key) - def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]: + def get_entries_for_device_id( + self, device_id: str, include_disabled_entities: bool = False + ) -> list[RegistryEntry]: """Get entries for device.""" - return [self.data[key] for key in self._device_id_index.get(device_id, ())] + data = self.data + return [ + entry + for key in self._device_id_index.get(device_id, ()) + if not (entry := data[key]).disabled_by or include_disabled_entities + ] def get_entries_for_config_entry_id( self, config_entry_id: str ) -> list[RegistryEntry]: """Get entries for config entry.""" + data = self.data return [ - self.data[key] - for key in self._config_entry_id_index.get(config_entry_id, ()) + data[key] for key in self._config_entry_id_index.get(config_entry_id, ()) ] @@ -1249,11 +1256,9 @@ def async_entries_for_device( registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False ) -> list[RegistryEntry]: """Return entries that match a device.""" - return [ - entry - for entry in registry.entities.get_entries_for_device_id(device_id) - if (not entry.disabled_by or include_disabled_entities) - ] + return registry.entities.get_entries_for_device_id( + device_id, include_disabled_entities + ) @callback diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 1c13da1192f..9e86b0279de 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -1370,6 +1370,13 @@ async def test_disabled_entities_excluded_from_entity_list( ) assert entries == [entry1, entry2] + ent_reg = er.async_get(hass) + assert ent_reg.entities.get_entries_for_device_id(device_entry.id) == [entry1] + + assert ent_reg.entities.get_entries_for_device_id( + device_entry.id, include_disabled_entities=True + ) == [entry1, entry2] + async def test_entity_max_length_exceeded(entity_registry: er.EntityRegistry) -> None: """Test that an exception is raised when the max character length is exceeded."""