Make get_entries_for_device_id skip disabled devices by default (#109645)

This commit is contained in:
J. Nick Koston 2024-02-04 16:11:56 -06:00 committed by GitHub
parent 2c0b897658
commit 9fef1938b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 11 deletions

View File

@ -249,7 +249,6 @@ async def async_get_device_automations(
for device_id in match_device_ids: for device_id in match_device_ids:
for entry in entity_registry.entities.get_entries_for_device_id(device_id): for entry in entity_registry.entities.get_entries_for_device_id(device_id):
if not entry.disabled_by:
device_entities_domains.setdefault(device_id, set()).add(entry.domain) device_entities_domains.setdefault(device_id, set()).add(entry.domain)
for device_id in match_device_ids: for device_id in match_device_ids:

View File

@ -498,17 +498,24 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Get entry from id.""" """Get entry from id."""
return self._entry_ids.get(key) return self._entry_ids.get(key)
def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]: def get_entries_for_device_id(
self, device_id: str, include_disabled_entities: bool = False
) -> list[RegistryEntry]:
"""Get entries for device.""" """Get entries for device."""
return [self.data[key] for key in self._device_id_index.get(device_id, ())] data = self.data
return [
entry
for key in self._device_id_index.get(device_id, ())
if not (entry := data[key]).disabled_by or include_disabled_entities
]
def get_entries_for_config_entry_id( def get_entries_for_config_entry_id(
self, config_entry_id: str self, config_entry_id: str
) -> list[RegistryEntry]: ) -> list[RegistryEntry]:
"""Get entries for config entry.""" """Get entries for config entry."""
data = self.data
return [ return [
self.data[key] data[key] for key in self._config_entry_id_index.get(config_entry_id, ())
for key in self._config_entry_id_index.get(config_entry_id, ())
] ]
@ -1249,11 +1256,9 @@ def async_entries_for_device(
registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False
) -> list[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match a device.""" """Return entries that match a device."""
return [ return registry.entities.get_entries_for_device_id(
entry device_id, include_disabled_entities
for entry in registry.entities.get_entries_for_device_id(device_id) )
if (not entry.disabled_by or include_disabled_entities)
]
@callback @callback

View File

@ -1370,6 +1370,13 @@ async def test_disabled_entities_excluded_from_entity_list(
) )
assert entries == [entry1, entry2] assert entries == [entry1, entry2]
ent_reg = er.async_get(hass)
assert ent_reg.entities.get_entries_for_device_id(device_entry.id) == [entry1]
assert ent_reg.entities.get_entries_for_device_id(
device_entry.id, include_disabled_entities=True
) == [entry1, entry2]
async def test_entity_max_length_exceeded(entity_registry: er.EntityRegistry) -> None: async def test_entity_max_length_exceeded(entity_registry: er.EntityRegistry) -> None:
"""Test that an exception is raised when the max character length is exceeded.""" """Test that an exception is raised when the max character length is exceeded."""