mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 22:37:11 +00:00
Seperate state change tracking from async_track_template_result into async_track_state_change_filtered (#40803)
This commit is contained in:
parent
1c534ea027
commit
b9931aabe7
@ -61,13 +61,27 @@ TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||
|
||||
_TEMPLATE_ALL_LISTENER = "all"
|
||||
_TEMPLATE_DOMAINS_LISTENER = "domains"
|
||||
_TEMPLATE_ENTITIES_LISTENER = "entities"
|
||||
_ALL_LISTENER = "all"
|
||||
_DOMAINS_LISTENER = "domains"
|
||||
_ENTITIES_LISTENER = "entities"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackStates:
|
||||
"""Class for keeping track of states being tracked.
|
||||
|
||||
all_states: All states on the system are being tracked
|
||||
entities: Entities to track
|
||||
domains: Domains to track
|
||||
"""
|
||||
|
||||
all_states: bool
|
||||
entities: Set
|
||||
domains: Set
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackTemplate:
|
||||
"""Class for keeping track of a template with variables.
|
||||
@ -452,6 +466,158 @@ def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]:
|
||||
return [mstr.lower() for mstr in instr]
|
||||
|
||||
|
||||
class _TrackStateChangeFiltered:
|
||||
"""Handle removal / refresh of tracker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
track_states: TrackStates,
|
||||
action: Callable[[Event], Any],
|
||||
):
|
||||
"""Handle removal / refresh of tracker init."""
|
||||
self.hass = hass
|
||||
self._action = action
|
||||
self._listeners: Dict[str, Callable] = {}
|
||||
self._last_track_states: TrackStates = track_states
|
||||
|
||||
@callback
|
||||
def async_setup(self) -> None:
|
||||
"""Create listeners to track states."""
|
||||
track_states = self._last_track_states
|
||||
|
||||
if (
|
||||
not track_states.all_states
|
||||
and not track_states.domains
|
||||
and not track_states.entities
|
||||
):
|
||||
return
|
||||
|
||||
if track_states.all_states:
|
||||
self._setup_all_listener()
|
||||
return
|
||||
|
||||
self._setup_domains_listener(track_states.domains)
|
||||
self._setup_entities_listener(track_states.domains, track_states.entities)
|
||||
|
||||
@property
|
||||
def listeners(self) -> Dict:
|
||||
"""State changes that will cause a re-render."""
|
||||
track_states = self._last_track_states
|
||||
return {
|
||||
_ALL_LISTENER: track_states.all_states,
|
||||
_ENTITIES_LISTENER: track_states.entities,
|
||||
_DOMAINS_LISTENER: track_states.domains,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_update_listeners(self, new_track_states: TrackStates) -> None:
|
||||
"""Update the listeners based on the new TrackStates."""
|
||||
last_track_states = self._last_track_states
|
||||
self._last_track_states = new_track_states
|
||||
|
||||
had_all_listener = last_track_states.all_states
|
||||
|
||||
if new_track_states.all_states:
|
||||
if had_all_listener:
|
||||
return
|
||||
self._cancel_listener(_DOMAINS_LISTENER)
|
||||
self._cancel_listener(_ENTITIES_LISTENER)
|
||||
self._setup_all_listener()
|
||||
return
|
||||
|
||||
if had_all_listener:
|
||||
self._cancel_listener(_ALL_LISTENER)
|
||||
|
||||
domains_changed = new_track_states.domains != last_track_states.domains
|
||||
|
||||
if had_all_listener or domains_changed:
|
||||
domains_changed = True
|
||||
self._cancel_listener(_DOMAINS_LISTENER)
|
||||
self._setup_domains_listener(new_track_states.domains)
|
||||
|
||||
if (
|
||||
had_all_listener
|
||||
or domains_changed
|
||||
or new_track_states.entities != last_track_states.entities
|
||||
):
|
||||
self._cancel_listener(_ENTITIES_LISTENER)
|
||||
self._setup_entities_listener(
|
||||
new_track_states.domains, new_track_states.entities
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_remove(self) -> None:
|
||||
"""Cancel the listeners."""
|
||||
for key in list(self._listeners):
|
||||
self._listeners.pop(key)()
|
||||
|
||||
@callback
|
||||
def _cancel_listener(self, listener_name: str) -> None:
|
||||
if listener_name not in self._listeners:
|
||||
return
|
||||
|
||||
self._listeners.pop(listener_name)()
|
||||
|
||||
@callback
|
||||
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
|
||||
if domains:
|
||||
entities = entities.copy()
|
||||
entities.update(self.hass.states.async_entity_ids(domains))
|
||||
|
||||
# Entities has changed to none
|
||||
if not entities:
|
||||
return
|
||||
|
||||
self._listeners[_ENTITIES_LISTENER] = async_track_state_change_event(
|
||||
self.hass, entities, self._action
|
||||
)
|
||||
|
||||
@callback
|
||||
def _setup_domains_listener(self, domains: Set) -> None:
|
||||
if not domains:
|
||||
return
|
||||
|
||||
self._listeners[_DOMAINS_LISTENER] = async_track_state_added_domain(
|
||||
self.hass, domains, self._action
|
||||
)
|
||||
|
||||
@callback
|
||||
def _setup_all_listener(self) -> None:
|
||||
self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED, self._action
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_state_change_filtered(
|
||||
hass: HomeAssistant,
|
||||
track_states: TrackStates,
|
||||
action: Callable[[Event], Any],
|
||||
) -> _TrackStateChangeFiltered:
|
||||
"""Track state changes with a TrackStates filter that can be updated.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hass
|
||||
Home assistant object.
|
||||
track_states
|
||||
A TrackStates data class.
|
||||
action
|
||||
Callable to call with results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object used to update the listeners (async_update_listeners) with a new TrackStates or
|
||||
cancel the tracking (async_remove).
|
||||
|
||||
"""
|
||||
tracker = _TrackStateChangeFiltered(hass, track_states, action)
|
||||
tracker.async_setup()
|
||||
return tracker
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_template(
|
||||
@ -557,12 +723,9 @@ class _TrackTemplateResultInfo:
|
||||
track_template_.template.hass = hass
|
||||
self._track_templates = track_templates
|
||||
|
||||
self._listeners: Dict[str, Callable] = {}
|
||||
|
||||
self._last_result: Dict[Template, Union[str, TemplateError]] = {}
|
||||
self._info: Dict[Template, RenderInfo] = {}
|
||||
self._last_domains: Set = set()
|
||||
self._last_entities: Set = set()
|
||||
self._track_state_changes: Optional[_TrackStateChangeFiltered] = None
|
||||
|
||||
def async_setup(self, raise_on_template_error: bool) -> None:
|
||||
"""Activation of template tracking."""
|
||||
@ -580,7 +743,9 @@ class _TrackTemplateResultInfo:
|
||||
exc_info=self._info[template].exception,
|
||||
)
|
||||
|
||||
self._create_listeners()
|
||||
self._track_state_changes = async_track_state_change_filtered(
|
||||
self.hass, _render_infos_to_track_states(self._info.values()), self._refresh
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Template group %s listens for %s",
|
||||
self._track_templates,
|
||||
@ -590,123 +755,14 @@ class _TrackTemplateResultInfo:
|
||||
@property
|
||||
def listeners(self) -> Dict:
|
||||
"""State changes that will cause a re-render."""
|
||||
return {
|
||||
"all": _TEMPLATE_ALL_LISTENER in self._listeners,
|
||||
"entities": self._last_entities,
|
||||
"domains": self._last_domains,
|
||||
}
|
||||
|
||||
@property
|
||||
def _needs_all_listener(self) -> bool:
|
||||
for info in self._info.values():
|
||||
# Tracking all states
|
||||
if info.all_states or info.all_states_lifecycle:
|
||||
return True
|
||||
|
||||
# Previous call had an exception
|
||||
# so we do not know which states
|
||||
# to track
|
||||
if info.exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def _all_templates_are_static(self) -> bool:
|
||||
for info in self._info.values():
|
||||
if not info.is_static:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@callback
|
||||
def _create_listeners(self) -> None:
|
||||
if self._all_templates_are_static:
|
||||
return
|
||||
|
||||
if self._needs_all_listener:
|
||||
self._setup_all_listener()
|
||||
return
|
||||
|
||||
self._last_entities, self._last_domains = _entities_domains_from_info(
|
||||
self._info.values()
|
||||
)
|
||||
self._setup_domains_listener(self._last_domains)
|
||||
self._setup_entities_listener(self._last_domains, self._last_entities)
|
||||
|
||||
@callback
|
||||
def _cancel_listener(self, listener_name: str) -> None:
|
||||
if listener_name not in self._listeners:
|
||||
return
|
||||
|
||||
self._listeners.pop(listener_name)()
|
||||
|
||||
@callback
|
||||
def _update_listeners(self) -> None:
|
||||
had_all_listener = _TEMPLATE_ALL_LISTENER in self._listeners
|
||||
|
||||
if self._needs_all_listener:
|
||||
if had_all_listener:
|
||||
return
|
||||
self._last_domains = set()
|
||||
self._last_entities = set()
|
||||
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
|
||||
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
|
||||
self._setup_all_listener()
|
||||
return
|
||||
|
||||
if had_all_listener:
|
||||
self._cancel_listener(_TEMPLATE_ALL_LISTENER)
|
||||
|
||||
entities, domains = _entities_domains_from_info(self._info.values())
|
||||
domains_changed = domains != self._last_domains
|
||||
|
||||
if had_all_listener or domains_changed:
|
||||
domains_changed = True
|
||||
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
|
||||
self._setup_domains_listener(domains)
|
||||
|
||||
if had_all_listener or domains_changed or entities != self._last_entities:
|
||||
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
|
||||
self._setup_entities_listener(domains, entities)
|
||||
|
||||
self._last_domains = domains
|
||||
self._last_entities = entities
|
||||
|
||||
@callback
|
||||
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
|
||||
if domains:
|
||||
entities = entities.copy()
|
||||
entities.update(self.hass.states.async_entity_ids(domains))
|
||||
|
||||
# Entities has changed to none
|
||||
if not entities:
|
||||
return
|
||||
|
||||
self._listeners[_TEMPLATE_ENTITIES_LISTENER] = async_track_state_change_event(
|
||||
self.hass, entities, self._refresh
|
||||
)
|
||||
|
||||
@callback
|
||||
def _setup_domains_listener(self, domains: Set) -> None:
|
||||
if not domains:
|
||||
return
|
||||
|
||||
self._listeners[_TEMPLATE_DOMAINS_LISTENER] = async_track_state_added_domain(
|
||||
self.hass, domains, self._refresh
|
||||
)
|
||||
|
||||
@callback
|
||||
def _setup_all_listener(self) -> None:
|
||||
self._listeners[_TEMPLATE_ALL_LISTENER] = self.hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED, self._refresh
|
||||
)
|
||||
assert self._track_state_changes
|
||||
return self._track_state_changes.listeners
|
||||
|
||||
@callback
|
||||
def async_remove(self) -> None:
|
||||
"""Cancel the listener."""
|
||||
for key in list(self._listeners):
|
||||
self._listeners.pop(key)()
|
||||
assert self._track_state_changes
|
||||
self._track_state_changes.async_remove()
|
||||
|
||||
@callback
|
||||
def async_refresh(self) -> None:
|
||||
@ -765,7 +821,10 @@ class _TrackTemplateResultInfo:
|
||||
updates.append(TrackTemplateResult(template, last_result, result))
|
||||
|
||||
if info_changed:
|
||||
self._update_listeners()
|
||||
assert self._track_state_changes
|
||||
self._track_state_changes.async_update_listeners(
|
||||
_render_infos_to_track_states(self._info.values()),
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Template group %s listens for %s",
|
||||
self._track_templates,
|
||||
@ -1229,7 +1288,10 @@ def process_state_match(
|
||||
return lambda state: state in parameter_set
|
||||
|
||||
|
||||
def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set, Set]:
|
||||
@callback
|
||||
def _entities_domains_from_render_infos(
|
||||
render_infos: Iterable[RenderInfo],
|
||||
) -> Tuple[Set, Set]:
|
||||
"""Combine from multiple RenderInfo."""
|
||||
entities = set()
|
||||
domains = set()
|
||||
@ -1242,3 +1304,29 @@ def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set
|
||||
if render_info.domains_lifecycle:
|
||||
domains.update(render_info.domains_lifecycle)
|
||||
return entities, domains
|
||||
|
||||
|
||||
@callback
|
||||
def _render_infos_needs_all_listener(render_infos: Iterable[RenderInfo]) -> bool:
|
||||
"""Determine if an all listener is needed from RenderInfo."""
|
||||
for render_info in render_infos:
|
||||
# Tracking all states
|
||||
if render_info.all_states or render_info.all_states_lifecycle:
|
||||
return True
|
||||
|
||||
# Previous call had an exception
|
||||
# so we do not know which states
|
||||
# to track
|
||||
if render_info.exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@callback
|
||||
def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackStates:
|
||||
"""Create a TrackStates dataclass from the latest RenderInfo."""
|
||||
if _render_infos_needs_all_listener(render_infos):
|
||||
return TrackStates(True, set(), set())
|
||||
|
||||
return TrackStates(False, *_entities_domains_from_render_infos(render_infos))
|
||||
|
@ -14,6 +14,7 @@ from homeassistant.core import callback
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.event import (
|
||||
TrackStates,
|
||||
TrackTemplate,
|
||||
TrackTemplateResult,
|
||||
async_call_later,
|
||||
@ -23,6 +24,7 @@ from homeassistant.helpers.event import (
|
||||
async_track_state_added_domain,
|
||||
async_track_state_change,
|
||||
async_track_state_change_event,
|
||||
async_track_state_change_filtered,
|
||||
async_track_state_removed_domain,
|
||||
async_track_sunrise,
|
||||
async_track_sunset,
|
||||
@ -255,6 +257,142 @@ async def test_track_state_change(hass):
|
||||
assert len(wildercard_runs) == 6
|
||||
|
||||
|
||||
async def test_async_track_state_change_filtered(hass):
|
||||
"""Test async_track_state_change_filtered."""
|
||||
single_entity_id_tracker = []
|
||||
multiple_entity_id_tracker = []
|
||||
|
||||
@ha.callback
|
||||
def single_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
single_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def multiple_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
multiple_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def callback_that_throws(event):
|
||||
raise ValueError
|
||||
|
||||
track_single = async_track_state_change_filtered(
|
||||
hass, TrackStates(False, {"light.bowl"}, None), single_run_callback
|
||||
)
|
||||
assert track_single.listeners == {
|
||||
"all": False,
|
||||
"domains": None,
|
||||
"entities": {"light.bowl"},
|
||||
}
|
||||
|
||||
track_multi = async_track_state_change_filtered(
|
||||
hass, TrackStates(False, {"light.bowl"}, {"switch"}), multiple_run_callback
|
||||
)
|
||||
assert track_multi.listeners == {
|
||||
"all": False,
|
||||
"domains": {"switch"},
|
||||
"entities": {"light.bowl"},
|
||||
}
|
||||
|
||||
track_throws = async_track_state_change_filtered(
|
||||
hass, TrackStates(False, {"light.bowl"}, {"switch"}), callback_that_throws
|
||||
)
|
||||
assert track_throws.listeners == {
|
||||
"all": False,
|
||||
"domains": {"switch"},
|
||||
"entities": {"light.bowl"},
|
||||
}
|
||||
|
||||
# Adding state to state machine
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert single_entity_id_tracker[-1][0] is None
|
||||
assert single_entity_id_tracker[-1][1] is not None
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
assert multiple_entity_id_tracker[-1][0] is None
|
||||
assert multiple_entity_id_tracker[-1][1] is not None
|
||||
|
||||
# Set same state should not trigger a state change/listener
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
|
||||
# State change off -> on
|
||||
hass.states.async_set("light.Bowl", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 2
|
||||
assert len(multiple_entity_id_tracker) == 2
|
||||
|
||||
# State change off -> off
|
||||
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 3
|
||||
assert len(multiple_entity_id_tracker) == 3
|
||||
|
||||
# State change off -> on
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 4
|
||||
assert len(multiple_entity_id_tracker) == 4
|
||||
|
||||
hass.states.async_remove("light.bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 5
|
||||
assert single_entity_id_tracker[-1][0] is not None
|
||||
assert single_entity_id_tracker[-1][1] is None
|
||||
assert len(multiple_entity_id_tracker) == 5
|
||||
assert multiple_entity_id_tracker[-1][0] is not None
|
||||
assert multiple_entity_id_tracker[-1][1] is None
|
||||
|
||||
# Set state for different entity id
|
||||
hass.states.async_set("switch.kitchen", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 5
|
||||
assert len(multiple_entity_id_tracker) == 6
|
||||
|
||||
track_single.async_remove()
|
||||
# Ensure unsubing the listener works
|
||||
hass.states.async_set("light.Bowl", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 5
|
||||
assert len(multiple_entity_id_tracker) == 7
|
||||
|
||||
assert track_multi.listeners == {
|
||||
"all": False,
|
||||
"domains": {"switch"},
|
||||
"entities": {"light.bowl"},
|
||||
}
|
||||
track_multi.async_update_listeners(TrackStates(False, {"light.bowl"}, None))
|
||||
assert track_multi.listeners == {
|
||||
"all": False,
|
||||
"domains": None,
|
||||
"entities": {"light.bowl"},
|
||||
}
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(multiple_entity_id_tracker) == 8
|
||||
hass.states.async_set("switch.kitchen", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(multiple_entity_id_tracker) == 8
|
||||
|
||||
track_multi.async_update_listeners(TrackStates(True, None, None))
|
||||
hass.states.async_set("switch.kitchen", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(multiple_entity_id_tracker) == 8
|
||||
hass.states.async_set("switch.any", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(multiple_entity_id_tracker) == 9
|
||||
|
||||
track_multi.async_remove()
|
||||
track_throws.async_remove()
|
||||
|
||||
|
||||
async def test_async_track_state_change_event(hass):
|
||||
"""Test async_track_state_change_event."""
|
||||
single_entity_id_tracker = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user