Seperate state change tracking from async_track_template_result into async_track_state_change_filtered (#40803)

This commit is contained in:
J. Nick Koston 2020-10-01 03:19:20 -05:00 committed by Paulus Schoutsen
parent 1c534ea027
commit b9931aabe7
2 changed files with 349 additions and 123 deletions

View File

@ -61,13 +61,27 @@ TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener"
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
_TEMPLATE_ALL_LISTENER = "all"
_TEMPLATE_DOMAINS_LISTENER = "domains"
_TEMPLATE_ENTITIES_LISTENER = "entities"
_ALL_LISTENER = "all"
_DOMAINS_LISTENER = "domains"
_ENTITIES_LISTENER = "entities"
_LOGGER = logging.getLogger(__name__)
@dataclass
class TrackStates:
"""Class for keeping track of states being tracked.
all_states: All states on the system are being tracked
entities: Entities to track
domains: Domains to track
"""
all_states: bool
entities: Set
domains: Set
@dataclass
class TrackTemplate:
"""Class for keeping track of a template with variables.
@ -452,6 +466,158 @@ def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]:
return [mstr.lower() for mstr in instr]
class _TrackStateChangeFiltered:
"""Handle removal / refresh of tracker."""
def __init__(
self,
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event], Any],
):
"""Handle removal / refresh of tracker init."""
self.hass = hass
self._action = action
self._listeners: Dict[str, Callable] = {}
self._last_track_states: TrackStates = track_states
@callback
def async_setup(self) -> None:
"""Create listeners to track states."""
track_states = self._last_track_states
if (
not track_states.all_states
and not track_states.domains
and not track_states.entities
):
return
if track_states.all_states:
self._setup_all_listener()
return
self._setup_domains_listener(track_states.domains)
self._setup_entities_listener(track_states.domains, track_states.entities)
@property
def listeners(self) -> Dict:
"""State changes that will cause a re-render."""
track_states = self._last_track_states
return {
_ALL_LISTENER: track_states.all_states,
_ENTITIES_LISTENER: track_states.entities,
_DOMAINS_LISTENER: track_states.domains,
}
@callback
def async_update_listeners(self, new_track_states: TrackStates) -> None:
"""Update the listeners based on the new TrackStates."""
last_track_states = self._last_track_states
self._last_track_states = new_track_states
had_all_listener = last_track_states.all_states
if new_track_states.all_states:
if had_all_listener:
return
self._cancel_listener(_DOMAINS_LISTENER)
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_all_listener()
return
if had_all_listener:
self._cancel_listener(_ALL_LISTENER)
domains_changed = new_track_states.domains != last_track_states.domains
if had_all_listener or domains_changed:
domains_changed = True
self._cancel_listener(_DOMAINS_LISTENER)
self._setup_domains_listener(new_track_states.domains)
if (
had_all_listener
or domains_changed
or new_track_states.entities != last_track_states.entities
):
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_entities_listener(
new_track_states.domains, new_track_states.entities
)
@callback
def async_remove(self) -> None:
"""Cancel the listeners."""
for key in list(self._listeners):
self._listeners.pop(key)()
@callback
def _cancel_listener(self, listener_name: str) -> None:
if listener_name not in self._listeners:
return
self._listeners.pop(listener_name)()
@callback
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
if domains:
entities = entities.copy()
entities.update(self.hass.states.async_entity_ids(domains))
# Entities has changed to none
if not entities:
return
self._listeners[_ENTITIES_LISTENER] = async_track_state_change_event(
self.hass, entities, self._action
)
@callback
def _setup_domains_listener(self, domains: Set) -> None:
if not domains:
return
self._listeners[_DOMAINS_LISTENER] = async_track_state_added_domain(
self.hass, domains, self._action
)
@callback
def _setup_all_listener(self) -> None:
self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen(
EVENT_STATE_CHANGED, self._action
)
@callback
@bind_hass
def async_track_state_change_filtered(
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event], Any],
) -> _TrackStateChangeFiltered:
"""Track state changes with a TrackStates filter that can be updated.
Parameters
----------
hass
Home assistant object.
track_states
A TrackStates data class.
action
Callable to call with results.
Returns
-------
Object used to update the listeners (async_update_listeners) with a new TrackStates or
cancel the tracking (async_remove).
"""
tracker = _TrackStateChangeFiltered(hass, track_states, action)
tracker.async_setup()
return tracker
@callback
@bind_hass
def async_track_template(
@ -557,12 +723,9 @@ class _TrackTemplateResultInfo:
track_template_.template.hass = hass
self._track_templates = track_templates
self._listeners: Dict[str, Callable] = {}
self._last_result: Dict[Template, Union[str, TemplateError]] = {}
self._info: Dict[Template, RenderInfo] = {}
self._last_domains: Set = set()
self._last_entities: Set = set()
self._track_state_changes: Optional[_TrackStateChangeFiltered] = None
def async_setup(self, raise_on_template_error: bool) -> None:
"""Activation of template tracking."""
@ -580,7 +743,9 @@ class _TrackTemplateResultInfo:
exc_info=self._info[template].exception,
)
self._create_listeners()
self._track_state_changes = async_track_state_change_filtered(
self.hass, _render_infos_to_track_states(self._info.values()), self._refresh
)
_LOGGER.debug(
"Template group %s listens for %s",
self._track_templates,
@ -590,123 +755,14 @@ class _TrackTemplateResultInfo:
@property
def listeners(self) -> Dict:
"""State changes that will cause a re-render."""
return {
"all": _TEMPLATE_ALL_LISTENER in self._listeners,
"entities": self._last_entities,
"domains": self._last_domains,
}
@property
def _needs_all_listener(self) -> bool:
for info in self._info.values():
# Tracking all states
if info.all_states or info.all_states_lifecycle:
return True
# Previous call had an exception
# so we do not know which states
# to track
if info.exception:
return True
return False
@property
def _all_templates_are_static(self) -> bool:
for info in self._info.values():
if not info.is_static:
return False
return True
@callback
def _create_listeners(self) -> None:
if self._all_templates_are_static:
return
if self._needs_all_listener:
self._setup_all_listener()
return
self._last_entities, self._last_domains = _entities_domains_from_info(
self._info.values()
)
self._setup_domains_listener(self._last_domains)
self._setup_entities_listener(self._last_domains, self._last_entities)
@callback
def _cancel_listener(self, listener_name: str) -> None:
if listener_name not in self._listeners:
return
self._listeners.pop(listener_name)()
@callback
def _update_listeners(self) -> None:
had_all_listener = _TEMPLATE_ALL_LISTENER in self._listeners
if self._needs_all_listener:
if had_all_listener:
return
self._last_domains = set()
self._last_entities = set()
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
self._setup_all_listener()
return
if had_all_listener:
self._cancel_listener(_TEMPLATE_ALL_LISTENER)
entities, domains = _entities_domains_from_info(self._info.values())
domains_changed = domains != self._last_domains
if had_all_listener or domains_changed:
domains_changed = True
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
self._setup_domains_listener(domains)
if had_all_listener or domains_changed or entities != self._last_entities:
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
self._setup_entities_listener(domains, entities)
self._last_domains = domains
self._last_entities = entities
@callback
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
if domains:
entities = entities.copy()
entities.update(self.hass.states.async_entity_ids(domains))
# Entities has changed to none
if not entities:
return
self._listeners[_TEMPLATE_ENTITIES_LISTENER] = async_track_state_change_event(
self.hass, entities, self._refresh
)
@callback
def _setup_domains_listener(self, domains: Set) -> None:
if not domains:
return
self._listeners[_TEMPLATE_DOMAINS_LISTENER] = async_track_state_added_domain(
self.hass, domains, self._refresh
)
@callback
def _setup_all_listener(self) -> None:
self._listeners[_TEMPLATE_ALL_LISTENER] = self.hass.bus.async_listen(
EVENT_STATE_CHANGED, self._refresh
)
assert self._track_state_changes
return self._track_state_changes.listeners
@callback
def async_remove(self) -> None:
"""Cancel the listener."""
for key in list(self._listeners):
self._listeners.pop(key)()
assert self._track_state_changes
self._track_state_changes.async_remove()
@callback
def async_refresh(self) -> None:
@ -765,7 +821,10 @@ class _TrackTemplateResultInfo:
updates.append(TrackTemplateResult(template, last_result, result))
if info_changed:
self._update_listeners()
assert self._track_state_changes
self._track_state_changes.async_update_listeners(
_render_infos_to_track_states(self._info.values()),
)
_LOGGER.debug(
"Template group %s listens for %s",
self._track_templates,
@ -1229,7 +1288,10 @@ def process_state_match(
return lambda state: state in parameter_set
def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set, Set]:
@callback
def _entities_domains_from_render_infos(
render_infos: Iterable[RenderInfo],
) -> Tuple[Set, Set]:
"""Combine from multiple RenderInfo."""
entities = set()
domains = set()
@ -1242,3 +1304,29 @@ def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set
if render_info.domains_lifecycle:
domains.update(render_info.domains_lifecycle)
return entities, domains
@callback
def _render_infos_needs_all_listener(render_infos: Iterable[RenderInfo]) -> bool:
"""Determine if an all listener is needed from RenderInfo."""
for render_info in render_infos:
# Tracking all states
if render_info.all_states or render_info.all_states_lifecycle:
return True
# Previous call had an exception
# so we do not know which states
# to track
if render_info.exception:
return True
return False
@callback
def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackStates:
"""Create a TrackStates dataclass from the latest RenderInfo."""
if _render_infos_needs_all_listener(render_infos):
return TrackStates(True, set(), set())
return TrackStates(False, *_entities_domains_from_render_infos(render_infos))

View File

@ -14,6 +14,7 @@ from homeassistant.core import callback
from homeassistant.exceptions import TemplateError
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.event import (
TrackStates,
TrackTemplate,
TrackTemplateResult,
async_call_later,
@ -23,6 +24,7 @@ from homeassistant.helpers.event import (
async_track_state_added_domain,
async_track_state_change,
async_track_state_change_event,
async_track_state_change_filtered,
async_track_state_removed_domain,
async_track_sunrise,
async_track_sunset,
@ -255,6 +257,142 @@ async def test_track_state_change(hass):
assert len(wildercard_runs) == 6
async def test_async_track_state_change_filtered(hass):
"""Test async_track_state_change_filtered."""
single_entity_id_tracker = []
multiple_entity_id_tracker = []
@ha.callback
def single_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
single_entity_id_tracker.append((old_state, new_state))
@ha.callback
def multiple_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
multiple_entity_id_tracker.append((old_state, new_state))
@ha.callback
def callback_that_throws(event):
raise ValueError
track_single = async_track_state_change_filtered(
hass, TrackStates(False, {"light.bowl"}, None), single_run_callback
)
assert track_single.listeners == {
"all": False,
"domains": None,
"entities": {"light.bowl"},
}
track_multi = async_track_state_change_filtered(
hass, TrackStates(False, {"light.bowl"}, {"switch"}), multiple_run_callback
)
assert track_multi.listeners == {
"all": False,
"domains": {"switch"},
"entities": {"light.bowl"},
}
track_throws = async_track_state_change_filtered(
hass, TrackStates(False, {"light.bowl"}, {"switch"}), callback_that_throws
)
assert track_throws.listeners == {
"all": False,
"domains": {"switch"},
"entities": {"light.bowl"},
}
# Adding state to state machine
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert single_entity_id_tracker[-1][0] is None
assert single_entity_id_tracker[-1][1] is not None
assert len(multiple_entity_id_tracker) == 1
assert multiple_entity_id_tracker[-1][0] is None
assert multiple_entity_id_tracker[-1][1] is not None
# Set same state should not trigger a state change/listener
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1
# State change off -> on
hass.states.async_set("light.Bowl", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 2
assert len(multiple_entity_id_tracker) == 2
# State change off -> off
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 3
assert len(multiple_entity_id_tracker) == 3
# State change off -> on
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 4
assert len(multiple_entity_id_tracker) == 4
hass.states.async_remove("light.bowl")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 5
assert single_entity_id_tracker[-1][0] is not None
assert single_entity_id_tracker[-1][1] is None
assert len(multiple_entity_id_tracker) == 5
assert multiple_entity_id_tracker[-1][0] is not None
assert multiple_entity_id_tracker[-1][1] is None
# Set state for different entity id
hass.states.async_set("switch.kitchen", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 5
assert len(multiple_entity_id_tracker) == 6
track_single.async_remove()
# Ensure unsubing the listener works
hass.states.async_set("light.Bowl", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 5
assert len(multiple_entity_id_tracker) == 7
assert track_multi.listeners == {
"all": False,
"domains": {"switch"},
"entities": {"light.bowl"},
}
track_multi.async_update_listeners(TrackStates(False, {"light.bowl"}, None))
assert track_multi.listeners == {
"all": False,
"domains": None,
"entities": {"light.bowl"},
}
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(multiple_entity_id_tracker) == 8
hass.states.async_set("switch.kitchen", "off")
await hass.async_block_till_done()
assert len(multiple_entity_id_tracker) == 8
track_multi.async_update_listeners(TrackStates(True, None, None))
hass.states.async_set("switch.kitchen", "off")
await hass.async_block_till_done()
assert len(multiple_entity_id_tracker) == 8
hass.states.async_set("switch.any", "off")
await hass.async_block_till_done()
assert len(multiple_entity_id_tracker) == 9
track_multi.async_remove()
track_throws.async_remove()
async def test_async_track_state_change_event(hass):
"""Test async_track_state_change_event."""
single_entity_id_tracker = []