Add domain filter support to async_all to match async_entity_ids (#39725)

This avoids copying all the states before applying
the filter
This commit is contained in:
J. Nick Koston
2020-09-06 16:20:32 -05:00
committed by GitHub
parent 19818d96b7
commit 251d8919ea
6 changed files with 45 additions and 16 deletions

View File

@@ -918,17 +918,29 @@ class StateMachine:
if state.domain in domain_filter
]
def all(self) -> List[State]:
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
"""Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result()
return run_callback_threadsafe(
self._loop, self.async_all, domain_filter
).result()
@callback
def async_all(self) -> List[State]:
"""Create a list of all states.
def async_all(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> List[State]:
"""Create a list of all states matching the filter.
This method must be run in the event loop.
"""
return list(self._states.values())
if domain_filter is None:
return list(self._states.values())
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
return [
state for state in self._states.values() if state.domain in domain_filter
]
def get(self, entity_id: str) -> Optional[State]:
"""Retrieve state of entity_id or None if not found.