diff --git a/homeassistant/components/humidifier/intent.py b/homeassistant/components/humidifier/intent.py index ee257cc7123..fafbb0a494a 100644 --- a/homeassistant/components/humidifier/intent.py +++ b/homeassistant/components/humidifier/intent.py @@ -41,8 +41,7 @@ class HumidityHandler(intent.IntentHandler): hass = intent_obj.hass slots = self.async_validate_slots(intent_obj.slots) state = hass.helpers.intent.async_match_state( - slots["name"]["value"], - [state for state in hass.states.async_all() if state.domain == DOMAIN], + slots["name"]["value"], hass.states.async_all(DOMAIN) ) service_data = {ATTR_ENTITY_ID: state.entity_id} @@ -87,7 +86,7 @@ class SetModeHandler(intent.IntentHandler): slots = self.async_validate_slots(intent_obj.slots) state = hass.helpers.intent.async_match_state( slots["name"]["value"], - [state for state in hass.states.async_all() if state.domain == DOMAIN], + hass.states.async_all(DOMAIN), ) service_data = {ATTR_ENTITY_ID: state.entity_id} diff --git a/homeassistant/components/light/intent.py b/homeassistant/components/light/intent.py index 58f74d8a422..be9346cf85b 100644 --- a/homeassistant/components/light/intent.py +++ b/homeassistant/components/light/intent.py @@ -39,8 +39,7 @@ class SetIntentHandler(intent.IntentHandler): hass = intent_obj.hass slots = self.async_validate_slots(intent_obj.slots) state = hass.helpers.intent.async_match_state( - slots["name"]["value"], - [state for state in hass.states.async_all() if state.domain == DOMAIN], + slots["name"]["value"], hass.states.async_all(DOMAIN) ) service_data = {ATTR_ENTITY_ID: state.entity_id} diff --git a/homeassistant/components/owntracks/__init__.py b/homeassistant/components/owntracks/__init__.py index cf034950154..24dc99de71c 100644 --- a/homeassistant/components/owntracks/__init__.py +++ b/homeassistant/components/owntracks/__init__.py @@ -183,10 +183,7 @@ async def handle_webhook(hass, webhook_id, request): response = [] - for person in hass.states.async_all(): - if person.domain != "person": - continue - + for person in hass.states.async_all("person"): if "latitude" in person.attributes and "longitude" in person.attributes: response.append( { diff --git a/homeassistant/core.py b/homeassistant/core.py index ad083f60574..8f3809bbd4c 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -918,17 +918,29 @@ class StateMachine: if state.domain in domain_filter ] - def all(self) -> List[State]: + def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]: """Create a list of all states.""" - return run_callback_threadsafe(self._loop, self.async_all).result() + return run_callback_threadsafe( + self._loop, self.async_all, domain_filter + ).result() @callback - def async_all(self) -> List[State]: - """Create a list of all states. + def async_all( + self, domain_filter: Optional[Union[str, Iterable]] = None + ) -> List[State]: + """Create a list of all states matching the filter. This method must be run in the event loop. """ - return list(self._states.values()) + if domain_filter is None: + return list(self._states.values()) + + if isinstance(domain_filter, str): + domain_filter = (domain_filter.lower(),) + + return [ + state for state in self._states.values() if state.domain in domain_filter + ] def get(self, entity_id: str) -> Optional[State]: """Retrieve state of entity_id or None if not found. diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index eeb43fb8756..405d8588532 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -459,8 +459,7 @@ class DomainStates: sorted( ( _wrap_state(self._hass, state) - for state in self._hass.states.async_all() - if state.domain == self._domain + for state in self._hass.states.async_all(self._domain) ), key=lambda state: state.entity_id, ) diff --git a/tests/test_core.py b/tests/test_core.py index 22f1e779061..f5de9c5f1a1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1454,3 +1454,26 @@ async def test_chained_logging_misses_log_timeout(hass, caplog): await hass.async_block_till_done() assert "_task_chain_" not in caplog.text + + +async def test_async_all(hass): + """Test async_all.""" + + hass.states.async_set("switch.link", "on") + hass.states.async_set("light.bowl", "on") + hass.states.async_set("light.frog", "on") + hass.states.async_set("vacuum.floor", "on") + + assert {state.entity_id for state in hass.states.async_all()} == { + "switch.link", + "light.bowl", + "light.frog", + "vacuum.floor", + } + assert {state.entity_id for state in hass.states.async_all("light")} == { + "light.bowl", + "light.frog", + } + assert { + state.entity_id for state in hass.states.async_all(["light", "switch"]) + } == {"light.bowl", "light.frog", "switch.link"}