From b9931aabe72adc47c9aa1a0d69379b26ff37339d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 1 Oct 2020 03:19:20 -0500 Subject: [PATCH] Seperate state change tracking from async_track_template_result into async_track_state_change_filtered (#40803) --- homeassistant/helpers/event.py | 334 +++++++++++++++++++++------------ tests/helpers/test_event.py | 138 ++++++++++++++ 2 files changed, 349 insertions(+), 123 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 9af781b7abd..b396ebb1d91 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -61,13 +61,27 @@ TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener" TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" -_TEMPLATE_ALL_LISTENER = "all" -_TEMPLATE_DOMAINS_LISTENER = "domains" -_TEMPLATE_ENTITIES_LISTENER = "entities" +_ALL_LISTENER = "all" +_DOMAINS_LISTENER = "domains" +_ENTITIES_LISTENER = "entities" _LOGGER = logging.getLogger(__name__) +@dataclass +class TrackStates: + """Class for keeping track of states being tracked. + + all_states: All states on the system are being tracked + entities: Entities to track + domains: Domains to track + """ + + all_states: bool + entities: Set + domains: Set + + @dataclass class TrackTemplate: """Class for keeping track of a template with variables. @@ -452,6 +466,158 @@ def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]: return [mstr.lower() for mstr in instr] +class _TrackStateChangeFiltered: + """Handle removal / refresh of tracker.""" + + def __init__( + self, + hass: HomeAssistant, + track_states: TrackStates, + action: Callable[[Event], Any], + ): + """Handle removal / refresh of tracker init.""" + self.hass = hass + self._action = action + self._listeners: Dict[str, Callable] = {} + self._last_track_states: TrackStates = track_states + + @callback + def async_setup(self) -> None: + """Create listeners to track states.""" + track_states = self._last_track_states + + if ( + not track_states.all_states + and not track_states.domains + and not track_states.entities + ): + return + + if track_states.all_states: + self._setup_all_listener() + return + + self._setup_domains_listener(track_states.domains) + self._setup_entities_listener(track_states.domains, track_states.entities) + + @property + def listeners(self) -> Dict: + """State changes that will cause a re-render.""" + track_states = self._last_track_states + return { + _ALL_LISTENER: track_states.all_states, + _ENTITIES_LISTENER: track_states.entities, + _DOMAINS_LISTENER: track_states.domains, + } + + @callback + def async_update_listeners(self, new_track_states: TrackStates) -> None: + """Update the listeners based on the new TrackStates.""" + last_track_states = self._last_track_states + self._last_track_states = new_track_states + + had_all_listener = last_track_states.all_states + + if new_track_states.all_states: + if had_all_listener: + return + self._cancel_listener(_DOMAINS_LISTENER) + self._cancel_listener(_ENTITIES_LISTENER) + self._setup_all_listener() + return + + if had_all_listener: + self._cancel_listener(_ALL_LISTENER) + + domains_changed = new_track_states.domains != last_track_states.domains + + if had_all_listener or domains_changed: + domains_changed = True + self._cancel_listener(_DOMAINS_LISTENER) + self._setup_domains_listener(new_track_states.domains) + + if ( + had_all_listener + or domains_changed + or new_track_states.entities != last_track_states.entities + ): + self._cancel_listener(_ENTITIES_LISTENER) + self._setup_entities_listener( + new_track_states.domains, new_track_states.entities + ) + + @callback + def async_remove(self) -> None: + """Cancel the listeners.""" + for key in list(self._listeners): + self._listeners.pop(key)() + + @callback + def _cancel_listener(self, listener_name: str) -> None: + if listener_name not in self._listeners: + return + + self._listeners.pop(listener_name)() + + @callback + def _setup_entities_listener(self, domains: Set, entities: Set) -> None: + if domains: + entities = entities.copy() + entities.update(self.hass.states.async_entity_ids(domains)) + + # Entities has changed to none + if not entities: + return + + self._listeners[_ENTITIES_LISTENER] = async_track_state_change_event( + self.hass, entities, self._action + ) + + @callback + def _setup_domains_listener(self, domains: Set) -> None: + if not domains: + return + + self._listeners[_DOMAINS_LISTENER] = async_track_state_added_domain( + self.hass, domains, self._action + ) + + @callback + def _setup_all_listener(self) -> None: + self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen( + EVENT_STATE_CHANGED, self._action + ) + + +@callback +@bind_hass +def async_track_state_change_filtered( + hass: HomeAssistant, + track_states: TrackStates, + action: Callable[[Event], Any], +) -> _TrackStateChangeFiltered: + """Track state changes with a TrackStates filter that can be updated. + + Parameters + ---------- + hass + Home assistant object. + track_states + A TrackStates data class. + action + Callable to call with results. + + Returns + ------- + Object used to update the listeners (async_update_listeners) with a new TrackStates or + cancel the tracking (async_remove). + + """ + tracker = _TrackStateChangeFiltered(hass, track_states, action) + tracker.async_setup() + return tracker + + @callback @bind_hass def async_track_template( @@ -557,12 +723,9 @@ class _TrackTemplateResultInfo: track_template_.template.hass = hass self._track_templates = track_templates - self._listeners: Dict[str, Callable] = {} - self._last_result: Dict[Template, Union[str, TemplateError]] = {} self._info: Dict[Template, RenderInfo] = {} - self._last_domains: Set = set() - self._last_entities: Set = set() + self._track_state_changes: Optional[_TrackStateChangeFiltered] = None def async_setup(self, raise_on_template_error: bool) -> None: """Activation of template tracking.""" @@ -580,7 +743,9 @@ class _TrackTemplateResultInfo: exc_info=self._info[template].exception, ) - self._create_listeners() + self._track_state_changes = async_track_state_change_filtered( + self.hass, _render_infos_to_track_states(self._info.values()), self._refresh + ) _LOGGER.debug( "Template group %s listens for %s", self._track_templates, @@ -590,123 +755,14 @@ class _TrackTemplateResultInfo: @property def listeners(self) -> Dict: """State changes that will cause a re-render.""" - return { - "all": _TEMPLATE_ALL_LISTENER in self._listeners, - "entities": self._last_entities, - "domains": self._last_domains, - } - - @property - def _needs_all_listener(self) -> bool: - for info in self._info.values(): - # Tracking all states - if info.all_states or info.all_states_lifecycle: - return True - - # Previous call had an exception - # so we do not know which states - # to track - if info.exception: - return True - - return False - - @property - def _all_templates_are_static(self) -> bool: - for info in self._info.values(): - if not info.is_static: - return False - - return True - - @callback - def _create_listeners(self) -> None: - if self._all_templates_are_static: - return - - if self._needs_all_listener: - self._setup_all_listener() - return - - self._last_entities, self._last_domains = _entities_domains_from_info( - self._info.values() - ) - self._setup_domains_listener(self._last_domains) - self._setup_entities_listener(self._last_domains, self._last_entities) - - @callback - def _cancel_listener(self, listener_name: str) -> None: - if listener_name not in self._listeners: - return - - self._listeners.pop(listener_name)() - - @callback - def _update_listeners(self) -> None: - had_all_listener = _TEMPLATE_ALL_LISTENER in self._listeners - - if self._needs_all_listener: - if had_all_listener: - return - self._last_domains = set() - self._last_entities = set() - self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER) - self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER) - self._setup_all_listener() - return - - if had_all_listener: - self._cancel_listener(_TEMPLATE_ALL_LISTENER) - - entities, domains = _entities_domains_from_info(self._info.values()) - domains_changed = domains != self._last_domains - - if had_all_listener or domains_changed: - domains_changed = True - self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER) - self._setup_domains_listener(domains) - - if had_all_listener or domains_changed or entities != self._last_entities: - self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER) - self._setup_entities_listener(domains, entities) - - self._last_domains = domains - self._last_entities = entities - - @callback - def _setup_entities_listener(self, domains: Set, entities: Set) -> None: - if domains: - entities = entities.copy() - entities.update(self.hass.states.async_entity_ids(domains)) - - # Entities has changed to none - if not entities: - return - - self._listeners[_TEMPLATE_ENTITIES_LISTENER] = async_track_state_change_event( - self.hass, entities, self._refresh - ) - - @callback - def _setup_domains_listener(self, domains: Set) -> None: - if not domains: - return - - self._listeners[_TEMPLATE_DOMAINS_LISTENER] = async_track_state_added_domain( - self.hass, domains, self._refresh - ) - - @callback - def _setup_all_listener(self) -> None: - self._listeners[_TEMPLATE_ALL_LISTENER] = self.hass.bus.async_listen( - EVENT_STATE_CHANGED, self._refresh - ) + assert self._track_state_changes + return self._track_state_changes.listeners @callback def async_remove(self) -> None: """Cancel the listener.""" - for key in list(self._listeners): - self._listeners.pop(key)() + assert self._track_state_changes + self._track_state_changes.async_remove() @callback def async_refresh(self) -> None: @@ -765,7 +821,10 @@ class _TrackTemplateResultInfo: updates.append(TrackTemplateResult(template, last_result, result)) if info_changed: - self._update_listeners() + assert self._track_state_changes + self._track_state_changes.async_update_listeners( + _render_infos_to_track_states(self._info.values()), + ) _LOGGER.debug( "Template group %s listens for %s", self._track_templates, @@ -1229,7 +1288,10 @@ def process_state_match( return lambda state: state in parameter_set -def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set, Set]: +@callback +def _entities_domains_from_render_infos( + render_infos: Iterable[RenderInfo], +) -> Tuple[Set, Set]: """Combine from multiple RenderInfo.""" entities = set() domains = set() @@ -1242,3 +1304,29 @@ def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set if render_info.domains_lifecycle: domains.update(render_info.domains_lifecycle) return entities, domains + + +@callback +def _render_infos_needs_all_listener(render_infos: Iterable[RenderInfo]) -> bool: + """Determine if an all listener is needed from RenderInfo.""" + for render_info in render_infos: + # Tracking all states + if render_info.all_states or render_info.all_states_lifecycle: + return True + + # Previous call had an exception + # so we do not know which states + # to track + if render_info.exception: + return True + + return False + + +@callback +def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackStates: + """Create a TrackStates dataclass from the latest RenderInfo.""" + if _render_infos_needs_all_listener(render_infos): + return TrackStates(True, set(), set()) + + return TrackStates(False, *_entities_domains_from_render_infos(render_infos)) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 8bdf9cb891c..887917fa74c 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -14,6 +14,7 @@ from homeassistant.core import callback from homeassistant.exceptions import TemplateError from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.event import ( + TrackStates, TrackTemplate, TrackTemplateResult, async_call_later, @@ -23,6 +24,7 @@ from homeassistant.helpers.event import ( async_track_state_added_domain, async_track_state_change, async_track_state_change_event, + async_track_state_change_filtered, async_track_state_removed_domain, async_track_sunrise, async_track_sunset, @@ -255,6 +257,142 @@ async def test_track_state_change(hass): assert len(wildercard_runs) == 6 +async def test_async_track_state_change_filtered(hass): + """Test async_track_state_change_filtered.""" + single_entity_id_tracker = [] + multiple_entity_id_tracker = [] + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + single_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def multiple_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + multiple_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def callback_that_throws(event): + raise ValueError + + track_single = async_track_state_change_filtered( + hass, TrackStates(False, {"light.bowl"}, None), single_run_callback + ) + assert track_single.listeners == { + "all": False, + "domains": None, + "entities": {"light.bowl"}, + } + + track_multi = async_track_state_change_filtered( + hass, TrackStates(False, {"light.bowl"}, {"switch"}), multiple_run_callback + ) + assert track_multi.listeners == { + "all": False, + "domains": {"switch"}, + "entities": {"light.bowl"}, + } + + track_throws = async_track_state_change_filtered( + hass, TrackStates(False, {"light.bowl"}, {"switch"}), callback_that_throws + ) + assert track_throws.listeners == { + "all": False, + "domains": {"switch"}, + "entities": {"light.bowl"}, + } + + # Adding state to state machine + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert single_entity_id_tracker[-1][0] is None + assert single_entity_id_tracker[-1][1] is not None + assert len(multiple_entity_id_tracker) == 1 + assert multiple_entity_id_tracker[-1][0] is None + assert multiple_entity_id_tracker[-1][1] is not None + + # Set same state should not trigger a state change/listener + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # State change off -> on + hass.states.async_set("light.Bowl", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 2 + assert len(multiple_entity_id_tracker) == 2 + + # State change off -> off + hass.states.async_set("light.Bowl", "off", {"some_attr": 1}) + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 3 + assert len(multiple_entity_id_tracker) == 3 + + # State change off -> on + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 4 + assert len(multiple_entity_id_tracker) == 4 + + hass.states.async_remove("light.bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert single_entity_id_tracker[-1][0] is not None + assert single_entity_id_tracker[-1][1] is None + assert len(multiple_entity_id_tracker) == 5 + assert multiple_entity_id_tracker[-1][0] is not None + assert multiple_entity_id_tracker[-1][1] is None + + # Set state for different entity id + hass.states.async_set("switch.kitchen", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert len(multiple_entity_id_tracker) == 6 + + track_single.async_remove() + # Ensure unsubing the listener works + hass.states.async_set("light.Bowl", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert len(multiple_entity_id_tracker) == 7 + + assert track_multi.listeners == { + "all": False, + "domains": {"switch"}, + "entities": {"light.bowl"}, + } + track_multi.async_update_listeners(TrackStates(False, {"light.bowl"}, None)) + assert track_multi.listeners == { + "all": False, + "domains": None, + "entities": {"light.bowl"}, + } + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(multiple_entity_id_tracker) == 8 + hass.states.async_set("switch.kitchen", "off") + await hass.async_block_till_done() + assert len(multiple_entity_id_tracker) == 8 + + track_multi.async_update_listeners(TrackStates(True, None, None)) + hass.states.async_set("switch.kitchen", "off") + await hass.async_block_till_done() + assert len(multiple_entity_id_tracker) == 8 + hass.states.async_set("switch.any", "off") + await hass.async_block_till_done() + assert len(multiple_entity_id_tracker) == 9 + + track_multi.async_remove() + track_throws.async_remove() + + async def test_async_track_state_change_event(hass): """Test async_track_state_change_event.""" single_entity_id_tracker = []