mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add async_track_state_removed_domain to allow tracking when a state is removed from a domain (#39859)
when a state is removed from a domain
This commit is contained in:
parent
101b5b3b35
commit
e208aac834
@ -55,6 +55,9 @@ TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"
|
||||
|
||||
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener"
|
||||
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||
|
||||
@ -235,10 +238,7 @@ def async_track_state_change_event(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
entity_ids = [entity_id.lower() for entity_id in entity_ids]
|
||||
entity_ids = _async_string_to_lower_list(entity_ids)
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||
@ -315,10 +315,7 @@ def async_track_entity_registry_updated_event(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
|
||||
)
|
||||
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
entity_ids = [entity_id.lower() for entity_id in entity_ids]
|
||||
entity_ids = _async_string_to_lower_list(entity_ids)
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||
@ -337,6 +334,26 @@ def async_track_entity_registry_updated_event(
|
||||
return remove_listener
|
||||
|
||||
|
||||
@callback
|
||||
def _async_dispatch_domain_event(
|
||||
hass: HomeAssistant, event: Event, callbacks: Dict[str, List]
|
||||
) -> None:
|
||||
domain = split_entity_id(event.data["entity_id"])[0]
|
||||
|
||||
if domain not in callbacks and MATCH_ALL not in callbacks:
|
||||
return
|
||||
|
||||
listeners = callbacks.get(domain, []) + callbacks.get(MATCH_ALL, [])
|
||||
|
||||
for action in listeners:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing event %s for domain %s", event, domain
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
def async_track_state_added_domain(
|
||||
hass: HomeAssistant,
|
||||
@ -355,27 +372,13 @@ def async_track_state_added_domain(
|
||||
if event.data.get("old_state") is not None:
|
||||
return
|
||||
|
||||
domain = split_entity_id(event.data["entity_id"])[0]
|
||||
|
||||
if domain not in domain_callbacks:
|
||||
return
|
||||
|
||||
for action in domain_callbacks[domain][:]:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing state added for %s", domain
|
||||
)
|
||||
_async_dispatch_domain_event(hass, event, domain_callbacks)
|
||||
|
||||
hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
if isinstance(domains, str):
|
||||
domains = [domains]
|
||||
|
||||
domains = [domains.lower() for domains in domains]
|
||||
domains = _async_string_to_lower_list(domains)
|
||||
|
||||
for domain in domains:
|
||||
domain_callbacks.setdefault(domain, []).append(action)
|
||||
@ -394,6 +397,57 @@ def async_track_state_added_domain(
|
||||
return remove_listener
|
||||
|
||||
|
||||
@bind_hass
|
||||
def async_track_state_removed_domain(
|
||||
hass: HomeAssistant,
|
||||
domains: Union[str, Iterable[str]],
|
||||
action: Callable[[Event], Any],
|
||||
) -> Callable[[], None]:
|
||||
"""Track state change events when an entity is removed from domains."""
|
||||
|
||||
domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {})
|
||||
|
||||
if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data:
|
||||
|
||||
@callback
|
||||
def _async_state_change_dispatcher(event: Event) -> None:
|
||||
"""Dispatch state changes by entity_id."""
|
||||
if event.data.get("new_state") is not None:
|
||||
return
|
||||
|
||||
_async_dispatch_domain_event(hass, event, domain_callbacks)
|
||||
|
||||
hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
domains = _async_string_to_lower_list(domains)
|
||||
|
||||
for domain in domains:
|
||||
domain_callbacks.setdefault(domain, []).append(action)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_indexed_listeners(
|
||||
hass,
|
||||
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER,
|
||||
domains,
|
||||
action,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
||||
@callback
|
||||
def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]:
|
||||
if isinstance(instr, str):
|
||||
return [instr.lower()]
|
||||
|
||||
return [mstr.lower() for mstr in instr]
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_template(
|
||||
|
@ -23,6 +23,7 @@ from homeassistant.helpers.event import (
|
||||
async_track_state_added_domain,
|
||||
async_track_state_change,
|
||||
async_track_state_change_event,
|
||||
async_track_state_removed_domain,
|
||||
async_track_sunrise,
|
||||
async_track_sunset,
|
||||
async_track_template,
|
||||
@ -429,6 +430,132 @@ async def test_async_track_state_added_domain(hass):
|
||||
unsub_throws()
|
||||
|
||||
|
||||
async def test_async_track_state_removed_domain(hass):
|
||||
"""Test async_track_state_removed_domain."""
|
||||
single_entity_id_tracker = []
|
||||
multiple_entity_id_tracker = []
|
||||
|
||||
@ha.callback
|
||||
def single_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
single_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def multiple_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
multiple_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def callback_that_throws(event):
|
||||
raise ValueError
|
||||
|
||||
unsub_single = async_track_state_removed_domain(hass, "light", single_run_callback)
|
||||
unsub_multi = async_track_state_removed_domain(
|
||||
hass, ["light", "switch"], multiple_run_callback
|
||||
)
|
||||
unsub_throws = async_track_state_removed_domain(
|
||||
hass, ["light", "switch"], callback_that_throws
|
||||
)
|
||||
|
||||
# Adding state to state machine
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
hass.states.async_remove("light.Bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert single_entity_id_tracker[-1][1] is None
|
||||
assert single_entity_id_tracker[-1][0] is not None
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
assert multiple_entity_id_tracker[-1][1] is None
|
||||
assert multiple_entity_id_tracker[-1][0] is not None
|
||||
|
||||
# Added and than removed (light)
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
hass.states.async_remove("light.Bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 2
|
||||
assert len(multiple_entity_id_tracker) == 2
|
||||
|
||||
# Added and than removed (light)
|
||||
hass.states.async_set("light.Bowl", "off")
|
||||
hass.states.async_remove("light.Bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 3
|
||||
assert len(multiple_entity_id_tracker) == 3
|
||||
|
||||
# Added and than removed (light)
|
||||
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
|
||||
hass.states.async_remove("light.Bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 4
|
||||
assert len(multiple_entity_id_tracker) == 4
|
||||
|
||||
# Added and than removed (switch)
|
||||
hass.states.async_set("switch.kitchen", "on")
|
||||
hass.states.async_remove("switch.kitchen")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 4
|
||||
assert len(multiple_entity_id_tracker) == 5
|
||||
|
||||
unsub_single()
|
||||
# Ensure unsubing the listener works
|
||||
hass.states.async_set("light.new", "off")
|
||||
hass.states.async_remove("light.new")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 4
|
||||
assert len(multiple_entity_id_tracker) == 6
|
||||
|
||||
unsub_multi()
|
||||
unsub_throws()
|
||||
|
||||
|
||||
async def test_async_track_state_removed_domain_match_all(hass):
|
||||
"""Test async_track_state_removed_domain with a match_all."""
|
||||
single_entity_id_tracker = []
|
||||
match_all_entity_id_tracker = []
|
||||
|
||||
@ha.callback
|
||||
def single_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
single_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def match_all_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
match_all_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
unsub_single = async_track_state_removed_domain(hass, "light", single_run_callback)
|
||||
unsub_match_all = async_track_state_removed_domain(
|
||||
hass, MATCH_ALL, match_all_run_callback
|
||||
)
|
||||
hass.states.async_set("light.new", "off")
|
||||
hass.states.async_remove("light.new")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(match_all_entity_id_tracker) == 1
|
||||
|
||||
hass.states.async_set("switch.new", "off")
|
||||
hass.states.async_remove("switch.new")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(match_all_entity_id_tracker) == 2
|
||||
|
||||
unsub_match_all()
|
||||
unsub_single()
|
||||
hass.states.async_set("switch.new", "off")
|
||||
hass.states.async_remove("switch.new")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(match_all_entity_id_tracker) == 2
|
||||
|
||||
|
||||
async def test_track_template(hass):
|
||||
"""Test tracking template."""
|
||||
specific_runs = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user