Make get_entries_for_device_id skip disabled devices by default (#109645)

This commit is contained in:
J. Nick Koston 2024-02-04 16:11:56 -06:00 committed by GitHub
parent 2c0b897658
commit 9fef1938b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 11 deletions

View File

@ -249,8 +249,7 @@ async def async_get_device_automations(
for device_id in match_device_ids:
for entry in entity_registry.entities.get_entries_for_device_id(device_id):
if not entry.disabled_by:
device_entities_domains.setdefault(device_id, set()).add(entry.domain)
device_entities_domains.setdefault(device_id, set()).add(entry.domain)
for device_id in match_device_ids:
combined_results[device_id] = []

View File

@ -498,17 +498,24 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Get entry from id."""
return self._entry_ids.get(key)
def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]:
def get_entries_for_device_id(
self, device_id: str, include_disabled_entities: bool = False
) -> list[RegistryEntry]:
"""Get entries for device."""
return [self.data[key] for key in self._device_id_index.get(device_id, ())]
data = self.data
return [
entry
for key in self._device_id_index.get(device_id, ())
if not (entry := data[key]).disabled_by or include_disabled_entities
]
def get_entries_for_config_entry_id(
self, config_entry_id: str
) -> list[RegistryEntry]:
"""Get entries for config entry."""
data = self.data
return [
self.data[key]
for key in self._config_entry_id_index.get(config_entry_id, ())
data[key] for key in self._config_entry_id_index.get(config_entry_id, ())
]
@ -1249,11 +1256,9 @@ def async_entries_for_device(
registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False
) -> list[RegistryEntry]:
"""Return entries that match a device."""
return [
entry
for entry in registry.entities.get_entries_for_device_id(device_id)
if (not entry.disabled_by or include_disabled_entities)
]
return registry.entities.get_entries_for_device_id(
device_id, include_disabled_entities
)
@callback

View File

@ -1370,6 +1370,13 @@ async def test_disabled_entities_excluded_from_entity_list(
)
assert entries == [entry1, entry2]
ent_reg = er.async_get(hass)
assert ent_reg.entities.get_entries_for_device_id(device_entry.id) == [entry1]
assert ent_reg.entities.get_entries_for_device_id(
device_entry.id, include_disabled_entities=True
) == [entry1, entry2]
async def test_entity_max_length_exceeded(entity_registry: er.EntityRegistry) -> None:
"""Test that an exception is raised when the max character length is exceeded."""