From 35533407fe8c9900328ec9f9fec9345821dafb5b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Sep 2020 11:36:47 -0500 Subject: [PATCH] Improve performance of counting and iterating states in templates (#40250) Co-authored-by: Anders Melchiorsen --- homeassistant/core.py | 18 ++++++++++++++++++ homeassistant/helpers/template.py | 20 +++++++++----------- tests/helpers/test_template.py | 11 +++++++++++ tests/test_core.py | 17 +++++++++++++++++ 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index fd34032112b..779d0f975a7 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -923,6 +923,24 @@ class StateMachine: if state.domain in domain_filter ] + @callback + def async_entity_ids_count( + self, domain_filter: Optional[Union[str, Iterable]] = None + ) -> int: + """Count the entity ids that are being tracked. + + This method must be run in the event loop. + """ + if domain_filter is None: + return len(self._states.keys()) + + if isinstance(domain_filter, str): + domain_filter = (domain_filter.lower(),) + + return len( + [None for state in self._states.values() if state.domain in domain_filter] + ) + def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]: """Create a list of all states.""" return run_callback_threadsafe( diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index bef6323d10c..aefdacbeeaa 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -9,7 +9,7 @@ import math from operator import attrgetter import random import re -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Generator, Iterable, List, Optional, Union from urllib.parse import urlencode as urllib_urlencode import weakref @@ -425,12 +425,12 @@ class AllStates: def __iter__(self): """Return all states.""" self._collect_all() - return _state_iterator(self._hass, None) + return _state_generator(self._hass, None) def __len__(self) -> int: """Return number of states.""" self._collect_all() - return len(self._hass.states.async_entity_ids()) + return self._hass.states.async_entity_ids_count() def __call__(self, entity_id): """Return the states.""" @@ -465,12 +465,12 @@ class DomainStates: def __iter__(self): """Return the iteration over all the states.""" self._collect_domain() - return _state_iterator(self._hass, self._domain) + return _state_generator(self._hass, self._domain) def __len__(self) -> int: """Return number of states.""" self._collect_domain() - return len(self._hass.states.async_entity_ids(self._domain)) + return self._hass.states.async_entity_ids_count(self._domain) def __repr__(self) -> str: """Representation of Domain States.""" @@ -537,12 +537,10 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None: entity_collect.entities.add(entity_id) -def _state_iterator(hass: HomeAssistantType, domain: Optional[str]) -> Iterable: - """Create an state iterator for a domain or all states.""" - return iter( - TemplateState(hass, state) - for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")) - ) +def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator: + """State generator for a domain or all states.""" + for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")): + yield TemplateState(hass, state) def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]: diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index ca36d8612d4..63a1e9de7c2 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -2420,3 +2420,14 @@ For loop example getting 3 entity values: assert "sensor0" in result assert "sensor1" in result assert "sun" in result + + +async def test_slice_states(hass): + """Test iterating states with a slice.""" + hass.states.async_set("sensor.test", "23") + + tpl = template.Template( + "{% for states in states | slice(1) -%}{% set state = states | first %}{{ state.entity_id }}{%- endfor %}", + hass, + ) + assert tpl.async_render() == "sensor.test" diff --git a/tests/test_core.py b/tests/test_core.py index f5de9c5f1a1..6c684ae1eac 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1477,3 +1477,20 @@ async def test_async_all(hass): assert { state.entity_id for state in hass.states.async_all(["light", "switch"]) } == {"light.bowl", "light.frog", "switch.link"} + + +async def test_async_entity_ids_count(hass): + """Test async_entity_ids_count.""" + + hass.states.async_set("switch.link", "on") + hass.states.async_set("light.bowl", "on") + hass.states.async_set("light.frog", "on") + hass.states.async_set("vacuum.floor", "on") + + assert hass.states.async_entity_ids_count() == 4 + assert hass.states.async_entity_ids_count("light") == 2 + + hass.states.async_set("light.cow", "on") + + assert hass.states.async_entity_ids_count() == 5 + assert hass.states.async_entity_ids_count("light") == 3