Improve performance of counting and iterating states in templates (#40250)

Co-authored-by: Anders Melchiorsen <amelchio@nogoto.net>
This commit is contained in:
J. Nick Koston 2020-09-26 11:36:47 -05:00 committed by GitHub
parent 1d41f024cf
commit 35533407fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 11 deletions

View File

@ -923,6 +923,24 @@ class StateMachine:
if state.domain in domain_filter if state.domain in domain_filter
] ]
@callback
def async_entity_ids_count(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> int:
"""Count the entity ids that are being tracked.
This method must be run in the event loop.
"""
if domain_filter is None:
return len(self._states.keys())
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
return len(
[None for state in self._states.values() if state.domain in domain_filter]
)
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]: def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe( return run_callback_threadsafe(

View File

@ -9,7 +9,7 @@ import math
from operator import attrgetter from operator import attrgetter
import random import random
import re import re
from typing import Any, Iterable, List, Optional, Union from typing import Any, Generator, Iterable, List, Optional, Union
from urllib.parse import urlencode as urllib_urlencode from urllib.parse import urlencode as urllib_urlencode
import weakref import weakref
@ -425,12 +425,12 @@ class AllStates:
def __iter__(self): def __iter__(self):
"""Return all states.""" """Return all states."""
self._collect_all() self._collect_all()
return _state_iterator(self._hass, None) return _state_generator(self._hass, None)
def __len__(self) -> int: def __len__(self) -> int:
"""Return number of states.""" """Return number of states."""
self._collect_all() self._collect_all()
return len(self._hass.states.async_entity_ids()) return self._hass.states.async_entity_ids_count()
def __call__(self, entity_id): def __call__(self, entity_id):
"""Return the states.""" """Return the states."""
@ -465,12 +465,12 @@ class DomainStates:
def __iter__(self): def __iter__(self):
"""Return the iteration over all the states.""" """Return the iteration over all the states."""
self._collect_domain() self._collect_domain()
return _state_iterator(self._hass, self._domain) return _state_generator(self._hass, self._domain)
def __len__(self) -> int: def __len__(self) -> int:
"""Return number of states.""" """Return number of states."""
self._collect_domain() self._collect_domain()
return len(self._hass.states.async_entity_ids(self._domain)) return self._hass.states.async_entity_ids_count(self._domain)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Representation of Domain States.""" """Representation of Domain States."""
@ -537,12 +537,10 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect.entities.add(entity_id) entity_collect.entities.add(entity_id)
def _state_iterator(hass: HomeAssistantType, domain: Optional[str]) -> Iterable: def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator:
"""Create an state iterator for a domain or all states.""" """State generator for a domain or all states."""
return iter( for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
TemplateState(hass, state) yield TemplateState(hass, state)
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id"))
)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]: def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:

View File

@ -2420,3 +2420,14 @@ For loop example getting 3 entity values:
assert "sensor0" in result assert "sensor0" in result
assert "sensor1" in result assert "sensor1" in result
assert "sun" in result assert "sun" in result
async def test_slice_states(hass):
"""Test iterating states with a slice."""
hass.states.async_set("sensor.test", "23")
tpl = template.Template(
"{% for states in states | slice(1) -%}{% set state = states | first %}{{ state.entity_id }}{%- endfor %}",
hass,
)
assert tpl.async_render() == "sensor.test"

View File

@ -1477,3 +1477,20 @@ async def test_async_all(hass):
assert { assert {
state.entity_id for state in hass.states.async_all(["light", "switch"]) state.entity_id for state in hass.states.async_all(["light", "switch"])
} == {"light.bowl", "light.frog", "switch.link"} } == {"light.bowl", "light.frog", "switch.link"}
async def test_async_entity_ids_count(hass):
"""Test async_entity_ids_count."""
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
hass.states.async_set("light.frog", "on")
hass.states.async_set("vacuum.floor", "on")
assert hass.states.async_entity_ids_count() == 4
assert hass.states.async_entity_ids_count("light") == 2
hass.states.async_set("light.cow", "on")
assert hass.states.async_entity_ids_count() == 5
assert hass.states.async_entity_ids_count("light") == 3