mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Determine how to run listeners at setup time instead of execution time (#41304)
This commit is contained in:
parent
8d94dff75c
commit
9e1461da62
@ -149,6 +149,52 @@ def is_callback(func: Callable[..., Any]) -> bool:
|
||||
return getattr(func, "_hass_callback", False) is True
|
||||
|
||||
|
||||
@enum.unique
|
||||
class HassJobType(enum.Enum):
|
||||
"""Represent a job type."""
|
||||
|
||||
Coroutine = 1
|
||||
Coroutinefunction = 2
|
||||
Callback = 3
|
||||
Executor = 4
|
||||
|
||||
|
||||
class HassJob:
|
||||
"""Represent a job to be run later.
|
||||
|
||||
We check the callable type in advance
|
||||
so we can avoid checking it every time
|
||||
we run the job.
|
||||
"""
|
||||
|
||||
__slots__ = ("job_type", "target")
|
||||
|
||||
def __init__(self, target: Callable):
|
||||
"""Create a job object."""
|
||||
self.target = target
|
||||
self.job_type = _get_callable_job_type(target)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the job."""
|
||||
return f"<Job {self.job_type} {self.target}>"
|
||||
|
||||
|
||||
def _get_callable_job_type(target: Callable) -> HassJobType:
|
||||
"""Determine the job type from the callable."""
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_target = target
|
||||
while isinstance(check_target, functools.partial):
|
||||
check_target = check_target.func
|
||||
|
||||
if asyncio.iscoroutine(check_target):
|
||||
return HassJobType.Coroutine
|
||||
if asyncio.iscoroutinefunction(check_target):
|
||||
return HassJobType.Coroutinefunction
|
||||
if is_callback(check_target):
|
||||
return HassJobType.Callback
|
||||
return HassJobType.Executor
|
||||
|
||||
|
||||
class CoreState(enum.Enum):
|
||||
"""Represent the current state of Home Assistant."""
|
||||
|
||||
@ -306,24 +352,32 @@ class HomeAssistant:
|
||||
if target is None:
|
||||
raise ValueError("Don't call async_add_job with None")
|
||||
|
||||
task = None
|
||||
return self.async_add_hass_job(HassJob(target), *args)
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_target = target
|
||||
while isinstance(check_target, functools.partial):
|
||||
check_target = check_target.func
|
||||
@callback
|
||||
def async_add_hass_job(
|
||||
self, hassjob: HassJob, *args: Any
|
||||
) -> Optional[asyncio.Future]:
|
||||
"""Add a HassJob from within the event loop.
|
||||
|
||||
if asyncio.iscoroutine(check_target):
|
||||
task = self.loop.create_task(target) # type: ignore
|
||||
elif asyncio.iscoroutinefunction(check_target):
|
||||
task = self.loop.create_task(target(*args))
|
||||
elif is_callback(check_target):
|
||||
self.loop.call_soon(target, *args)
|
||||
This method must be run in the event loop.
|
||||
hassjob: HassJob to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if hassjob.job_type == HassJobType.Coroutine:
|
||||
task = self.loop.create_task(hassjob.target) # type: ignore
|
||||
elif hassjob.job_type == HassJobType.Coroutinefunction:
|
||||
task = self.loop.create_task(hassjob.target(*args))
|
||||
elif hassjob.job_type == HassJobType.Callback:
|
||||
self.loop.call_soon(hassjob.target, *args)
|
||||
return None
|
||||
else:
|
||||
task = self.loop.run_in_executor(None, target, *args) # type: ignore
|
||||
task = self.loop.run_in_executor( # type: ignore
|
||||
None, hassjob.target, *args
|
||||
)
|
||||
|
||||
# If a task is scheduled
|
||||
if self._track_task and task is not None:
|
||||
if self._track_task:
|
||||
self._pending_tasks.append(task)
|
||||
|
||||
return task
|
||||
@ -366,6 +420,20 @@ class HomeAssistant:
|
||||
"""Stop track tasks so you can't wait for all tasks to be done."""
|
||||
self._track_task = False
|
||||
|
||||
@callback
|
||||
def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> None:
|
||||
"""Run a HassJob from within the event loop.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
||||
hassjob: HassJob
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if hassjob.job_type == HassJobType.Callback:
|
||||
hassjob.target(*args)
|
||||
else:
|
||||
self.async_add_hass_job(hassjob, *args)
|
||||
|
||||
@callback
|
||||
def async_run_job(
|
||||
self, target: Callable[..., Union[None, Awaitable]], *args: Any
|
||||
@ -377,14 +445,7 @@ class HomeAssistant:
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if (
|
||||
not asyncio.iscoroutine(target)
|
||||
and not asyncio.iscoroutinefunction(target)
|
||||
and is_callback(target)
|
||||
):
|
||||
target(*args)
|
||||
else:
|
||||
self.async_add_job(target, *args)
|
||||
self.async_run_hass_job(HassJob(target), *args)
|
||||
|
||||
def block_till_done(self) -> None:
|
||||
"""Block until all pending work is done."""
|
||||
@ -592,7 +653,7 @@ class EventBus:
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize a new event bus."""
|
||||
self._listeners: Dict[str, List[Callable]] = {}
|
||||
self._listeners: Dict[str, List[HassJob]] = {}
|
||||
self._hass = hass
|
||||
|
||||
@callback
|
||||
@ -648,8 +709,8 @@ class EventBus:
|
||||
if not listeners:
|
||||
return
|
||||
|
||||
for func in listeners:
|
||||
self._hass.async_add_job(func, event)
|
||||
for job in listeners:
|
||||
self._hass.async_add_hass_job(job, event)
|
||||
|
||||
def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||
"""Listen for all events or events of a specific type.
|
||||
@ -676,14 +737,15 @@ class EventBus:
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if event_type in self._listeners:
|
||||
self._listeners[event_type].append(listener)
|
||||
else:
|
||||
self._listeners[event_type] = [listener]
|
||||
return self._async_listen_job(event_type, HassJob(listener))
|
||||
|
||||
@callback
|
||||
def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE:
|
||||
self._listeners.setdefault(event_type, []).append(hassjob)
|
||||
|
||||
def remove_listener() -> None:
|
||||
"""Remove the listener."""
|
||||
self._async_remove_listener(event_type, listener)
|
||||
self._async_remove_listener(event_type, hassjob)
|
||||
|
||||
return remove_listener
|
||||
|
||||
@ -716,31 +778,36 @@ class EventBus:
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
job: Optional[HassJob] = None
|
||||
|
||||
@callback
|
||||
def onetime_listener(event: Event) -> None:
|
||||
def _onetime_listener(event: Event) -> None:
|
||||
"""Remove listener from event bus and then fire listener."""
|
||||
if hasattr(onetime_listener, "run"):
|
||||
nonlocal job
|
||||
if hasattr(_onetime_listener, "run"):
|
||||
return
|
||||
# Set variable so that we will never run twice.
|
||||
# Because the event bus loop might have async_fire queued multiple
|
||||
# times, its possible this listener may already be lined up
|
||||
# multiple times as well.
|
||||
# This will make sure the second time it does nothing.
|
||||
setattr(onetime_listener, "run", True)
|
||||
self._async_remove_listener(event_type, onetime_listener)
|
||||
setattr(_onetime_listener, "run", True)
|
||||
assert job is not None
|
||||
self._async_remove_listener(event_type, job)
|
||||
self._hass.async_run_job(listener, event)
|
||||
|
||||
return self.async_listen(event_type, onetime_listener)
|
||||
job = HassJob(_onetime_listener)
|
||||
|
||||
return self._async_listen_job(event_type, job)
|
||||
|
||||
@callback
|
||||
def _async_remove_listener(self, event_type: str, listener: Callable) -> None:
|
||||
def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None:
|
||||
"""Remove a listener of a specific event_type.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
try:
|
||||
self._listeners[event_type].remove(listener)
|
||||
self._listeners[event_type].remove(hassjob)
|
||||
|
||||
# delete event_type list if empty
|
||||
if not self._listeners[event_type]:
|
||||
@ -748,7 +815,7 @@ class EventBus:
|
||||
except (KeyError, ValueError):
|
||||
# KeyError is key event_type listener did not exist
|
||||
# ValueError if listener did not exist within event_type
|
||||
_LOGGER.warning("Unable to remove unknown listener %s", listener)
|
||||
_LOGGER.warning("Unable to remove unknown job listener %s", hassjob)
|
||||
|
||||
|
||||
class State:
|
||||
@ -1094,7 +1161,7 @@ class StateMachine:
|
||||
class Service:
|
||||
"""Representation of a callable service."""
|
||||
|
||||
__slots__ = ["func", "schema", "is_callback", "is_coroutinefunction"]
|
||||
__slots__ = ["job", "schema"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1103,13 +1170,8 @@ class Service:
|
||||
context: Optional[Context] = None,
|
||||
) -> None:
|
||||
"""Initialize a service."""
|
||||
self.func = func
|
||||
self.job = HassJob(func)
|
||||
self.schema = schema
|
||||
# Properly detect wrapped functions
|
||||
while isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
self.is_callback = is_callback(func)
|
||||
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
class ServiceCall:
|
||||
@ -1377,12 +1439,12 @@ class ServiceRegistry:
|
||||
self, handler: Service, service_call: ServiceCall
|
||||
) -> None:
|
||||
"""Execute a service."""
|
||||
if handler.is_coroutinefunction:
|
||||
await handler.func(service_call)
|
||||
elif handler.is_callback:
|
||||
handler.func(service_call)
|
||||
if handler.job.job_type == HassJobType.Coroutinefunction:
|
||||
await handler.job.target(service_call)
|
||||
elif handler.job.job_type == HassJobType.Callback:
|
||||
handler.job.target(service_call)
|
||||
else:
|
||||
await self._hass.async_add_executor_job(handler.func, service_call)
|
||||
await self._hass.async_add_executor_job(handler.job.target, service_call)
|
||||
|
||||
|
||||
class Config:
|
||||
|
@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HassJob, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.logging import catch_log_exception
|
||||
@ -41,26 +41,25 @@ def async_dispatcher_connect(
|
||||
if DATA_DISPATCHER not in hass.data:
|
||||
hass.data[DATA_DISPATCHER] = {}
|
||||
|
||||
if signal not in hass.data[DATA_DISPATCHER]:
|
||||
hass.data[DATA_DISPATCHER][signal] = []
|
||||
|
||||
wrapped_target = catch_log_exception(
|
||||
target,
|
||||
lambda *args: "Exception in {} when dispatching '{}': {}".format(
|
||||
# Functions wrapped in partial do not have a __name__
|
||||
getattr(target, "__name__", None) or str(target),
|
||||
signal,
|
||||
args,
|
||||
),
|
||||
job = HassJob(
|
||||
catch_log_exception(
|
||||
target,
|
||||
lambda *args: "Exception in {} when dispatching '{}': {}".format(
|
||||
# Functions wrapped in partial do not have a __name__
|
||||
getattr(target, "__name__", None) or str(target),
|
||||
signal,
|
||||
args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
hass.data[DATA_DISPATCHER][signal].append(wrapped_target)
|
||||
hass.data[DATA_DISPATCHER].setdefault(signal, []).append(job)
|
||||
|
||||
@callback
|
||||
def async_remove_dispatcher() -> None:
|
||||
"""Remove signal listener."""
|
||||
try:
|
||||
hass.data[DATA_DISPATCHER][signal].remove(wrapped_target)
|
||||
hass.data[DATA_DISPATCHER][signal].remove(job)
|
||||
except (KeyError, ValueError):
|
||||
# KeyError is key target listener did not exist
|
||||
# ValueError if listener did not exist within signal
|
||||
@ -84,5 +83,5 @@ def async_dispatcher_send(hass: HomeAssistantType, signal: str, *args: Any) -> N
|
||||
"""
|
||||
target_list = hass.data.get(DATA_DISPATCHER, {}).get(signal, [])
|
||||
|
||||
for target in target_list:
|
||||
hass.async_add_job(target, *args)
|
||||
for job in target_list:
|
||||
hass.async_add_hass_job(job, *args)
|
||||
|
@ -34,6 +34,7 @@ from homeassistant.const import (
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
State,
|
||||
callback,
|
||||
@ -174,6 +175,8 @@ def async_track_state_change(
|
||||
else:
|
||||
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
def state_change_listener(event: Event) -> None:
|
||||
"""Handle specific state changes."""
|
||||
@ -192,8 +195,8 @@ def async_track_state_change(
|
||||
if not match_to_state(new_state):
|
||||
return
|
||||
|
||||
hass.async_run_job(
|
||||
action,
|
||||
hass.async_run_hass_job(
|
||||
job,
|
||||
event.data.get("entity_id"),
|
||||
event.data.get("old_state"),
|
||||
event.data.get("new_state"),
|
||||
@ -246,9 +249,9 @@ def async_track_state_change_event(
|
||||
if entity_id not in entity_callbacks:
|
||||
return
|
||||
|
||||
for action in entity_callbacks[entity_id][:]:
|
||||
for job in entity_callbacks[entity_id][:]:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
hass.async_run_hass_job(job, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing state changed for %s", entity_id
|
||||
@ -258,10 +261,12 @@ def async_track_state_change_event(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
entity_ids = _async_string_to_lower_list(entity_ids)
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||
entity_callbacks.setdefault(entity_id, []).append(job)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
@ -271,7 +276,7 @@ def async_track_state_change_event(
|
||||
TRACK_STATE_CHANGE_CALLBACKS,
|
||||
TRACK_STATE_CHANGE_LISTENER,
|
||||
entity_ids,
|
||||
action,
|
||||
job,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
@ -283,14 +288,14 @@ def _async_remove_indexed_listeners(
|
||||
data_key: str,
|
||||
listener_key: str,
|
||||
storage_keys: Iterable[str],
|
||||
action: Callable[[Event], Any],
|
||||
job: HassJob,
|
||||
) -> None:
|
||||
"""Remove a listener."""
|
||||
|
||||
callbacks = hass.data[data_key]
|
||||
|
||||
for storage_key in storage_keys:
|
||||
callbacks[storage_key].remove(action)
|
||||
callbacks[storage_key].remove(job)
|
||||
if len(callbacks[storage_key]) == 0:
|
||||
del callbacks[storage_key]
|
||||
|
||||
@ -322,9 +327,9 @@ def async_track_entity_registry_updated_event(
|
||||
if entity_id not in entity_callbacks:
|
||||
return
|
||||
|
||||
for action in entity_callbacks[entity_id][:]:
|
||||
for job in entity_callbacks[entity_id][:]:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
hass.async_run_hass_job(job, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing entity registry update for %s",
|
||||
@ -335,10 +340,12 @@ def async_track_entity_registry_updated_event(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
|
||||
)
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
entity_ids = _async_string_to_lower_list(entity_ids)
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||
entity_callbacks.setdefault(entity_id, []).append(job)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
@ -348,7 +355,7 @@ def async_track_entity_registry_updated_event(
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
|
||||
entity_ids,
|
||||
action,
|
||||
job,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
@ -365,9 +372,9 @@ def _async_dispatch_domain_event(
|
||||
|
||||
listeners = callbacks.get(domain, []) + callbacks.get(MATCH_ALL, [])
|
||||
|
||||
for action in listeners:
|
||||
for job in listeners:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
hass.async_run_hass_job(job, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing event %s for domain %s", event, domain
|
||||
@ -398,10 +405,12 @@ def async_track_state_added_domain(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
domains = _async_string_to_lower_list(domains)
|
||||
|
||||
for domain in domains:
|
||||
domain_callbacks.setdefault(domain, []).append(action)
|
||||
domain_callbacks.setdefault(domain, []).append(job)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
@ -411,7 +420,7 @@ def async_track_state_added_domain(
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER,
|
||||
domains,
|
||||
action,
|
||||
job,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
@ -441,10 +450,12 @@ def async_track_state_removed_domain(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
domains = _async_string_to_lower_list(domains)
|
||||
|
||||
for domain in domains:
|
||||
domain_callbacks.setdefault(domain, []).append(action)
|
||||
domain_callbacks.setdefault(domain, []).append(job)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
@ -454,7 +465,7 @@ def async_track_state_removed_domain(
|
||||
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER,
|
||||
domains,
|
||||
action,
|
||||
job,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
@ -665,6 +676,8 @@ def async_track_template(
|
||||
|
||||
"""
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
def _template_changed_listener(
|
||||
event: Event, updates: List[TrackTemplateResult]
|
||||
@ -691,8 +704,8 @@ def async_track_template(
|
||||
):
|
||||
return
|
||||
|
||||
hass.async_run_job(
|
||||
action,
|
||||
hass.async_run_hass_job(
|
||||
job,
|
||||
event.data.get("entity_id"),
|
||||
event.data.get("old_state"),
|
||||
event.data.get("new_state"),
|
||||
@ -719,7 +732,7 @@ class _TrackTemplateResultInfo:
|
||||
):
|
||||
"""Handle removal / refresh of tracker init."""
|
||||
self.hass = hass
|
||||
self._action = action
|
||||
self._job = HassJob(action)
|
||||
|
||||
for track_template_ in track_templates:
|
||||
track_template_.template.hass = hass
|
||||
@ -866,7 +879,7 @@ class _TrackTemplateResultInfo:
|
||||
for track_result in updates:
|
||||
self._last_result[track_result.template] = track_result.result
|
||||
|
||||
self.hass.async_run_job(self._action, event, updates)
|
||||
self.hass.async_run_hass_job(self._job, event, updates)
|
||||
|
||||
|
||||
TrackTemplateResultListener = Callable[
|
||||
@ -951,6 +964,8 @@ def async_track_same_state(
|
||||
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
|
||||
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
def clear_listener() -> None:
|
||||
"""Clear all unsub listener."""
|
||||
@ -969,7 +984,7 @@ def async_track_same_state(
|
||||
nonlocal async_remove_state_for_listener
|
||||
async_remove_state_for_listener = None
|
||||
clear_listener()
|
||||
hass.async_run_job(action)
|
||||
hass.async_run_hass_job(job)
|
||||
|
||||
@callback
|
||||
def state_for_cancel_listener(event: Event) -> None:
|
||||
@ -1005,14 +1020,18 @@ track_same_state = threaded_listener_factory(async_track_same_state)
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_point_in_time(
|
||||
hass: HomeAssistant, action: Callable[..., None], point_in_time: datetime
|
||||
hass: HomeAssistant,
|
||||
action: Union[HassJob, Callable[..., None]],
|
||||
point_in_time: datetime,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that fires once after a specific point in time."""
|
||||
|
||||
job = action if isinstance(action, HassJob) else HassJob(action)
|
||||
|
||||
@callback
|
||||
def utc_converter(utc_now: datetime) -> None:
|
||||
"""Convert passed in UTC now to local now."""
|
||||
hass.async_run_job(action, dt_util.as_local(utc_now))
|
||||
hass.async_run_hass_job(job, dt_util.as_local(utc_now))
|
||||
|
||||
return async_track_point_in_utc_time(hass, utc_converter, point_in_time)
|
||||
|
||||
@ -1023,16 +1042,22 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_point_in_utc_time(
|
||||
hass: HomeAssistant, action: Callable[..., Any], point_in_time: datetime
|
||||
hass: HomeAssistant,
|
||||
action: Union[HassJob, Callable[..., None]],
|
||||
point_in_time: datetime,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that fires once after a specific point in UTC time."""
|
||||
# Ensure point_in_time is UTC
|
||||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
# Since this is called once, we accept a HassJob so we can avoid
|
||||
# having to figure out how to call the action every time its called.
|
||||
job = action if isinstance(action, HassJob) else HassJob(action)
|
||||
|
||||
cancel_callback = hass.loop.call_at(
|
||||
hass.loop.time() + point_in_time.timestamp() - time.time(),
|
||||
hass.async_run_job,
|
||||
action,
|
||||
hass.async_run_hass_job,
|
||||
job,
|
||||
utc_point_in_time,
|
||||
)
|
||||
|
||||
@ -1050,7 +1075,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_call_later(
|
||||
hass: HomeAssistant, delay: float, action: Callable[..., None]
|
||||
hass: HomeAssistant, delay: float, action: Union[HassJob, Callable[..., None]]
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that is called in <delay>."""
|
||||
return async_track_point_in_utc_time(
|
||||
@ -1071,6 +1096,8 @@ def async_track_time_interval(
|
||||
"""Add a listener that fires repetitively at every timedelta interval."""
|
||||
remove = None
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
def next_interval() -> datetime:
|
||||
"""Return the next interval."""
|
||||
return dt_util.utcnow() + interval
|
||||
@ -1080,7 +1107,7 @@ def async_track_time_interval(
|
||||
"""Handle elapsed intervals."""
|
||||
nonlocal remove
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
hass.async_run_job(action, now)
|
||||
hass.async_run_hass_job(job, now)
|
||||
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
|
||||
@ -1196,6 +1223,8 @@ def async_track_utc_time_change(
|
||||
local: bool = False,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that will fire if time matches a pattern."""
|
||||
|
||||
job = HassJob(action)
|
||||
# We do not have to wrap the function with time pattern matching logic
|
||||
# if no pattern given
|
||||
if all(val is None for val in (hour, minute, second)):
|
||||
@ -1203,7 +1232,7 @@ def async_track_utc_time_change(
|
||||
@callback
|
||||
def time_change_listener(event: Event) -> None:
|
||||
"""Fire every time event that comes in."""
|
||||
hass.async_run_job(action, event.data[ATTR_NOW])
|
||||
hass.async_run_hass_job(job, event.data[ATTR_NOW])
|
||||
|
||||
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
|
||||
|
||||
@ -1233,7 +1262,7 @@ def async_track_utc_time_change(
|
||||
nonlocal next_time, cancel_callback
|
||||
|
||||
now = pattern_utc_now()
|
||||
hass.async_run_job(action, dt_util.as_local(now) if local else now)
|
||||
hass.async_run_hass_job(job, dt_util.as_local(now) if local else now)
|
||||
|
||||
calculate_next(now + timedelta(seconds=1))
|
||||
|
||||
|
@ -9,7 +9,7 @@ import urllib.error
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
from homeassistant.helpers import entity, event
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
@ -48,6 +48,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||
self.data: Optional[T] = None
|
||||
|
||||
self._listeners: List[CALLBACK_TYPE] = []
|
||||
self._job = HassJob(self._handle_refresh_interval)
|
||||
self._unsub_refresh: Optional[CALLBACK_TYPE] = None
|
||||
self._request_refresh_task: Optional[asyncio.TimerHandle] = None
|
||||
self.last_update_success = True
|
||||
@ -108,7 +109,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||
# as long as the update process takes less than a second
|
||||
self._unsub_refresh = event.async_track_point_in_utc_time(
|
||||
self.hass,
|
||||
self._handle_refresh_interval,
|
||||
self._job,
|
||||
utcnow().replace(microsecond=0) + self.update_interval,
|
||||
)
|
||||
|
||||
|
@ -269,7 +269,7 @@ async def test_turn_on_to_not_block_for_domains_without_service(hass):
|
||||
"homeassistant.core.ServiceRegistry.async_call",
|
||||
return_value=None,
|
||||
) as mock_call:
|
||||
await service.func(service_call)
|
||||
await service.job.target(service_call)
|
||||
|
||||
assert mock_call.call_count == 2
|
||||
assert mock_call.call_args_list[0][0] == (
|
||||
|
@ -415,6 +415,10 @@ def legacy_patchable_time():
|
||||
# Ensure point_in_time is UTC
|
||||
point_in_time = event.dt_util.as_utc(point_in_time)
|
||||
|
||||
# Since this is called once, we accept a HassJob so we can avoid
|
||||
# having to figure out how to call the action every time its called.
|
||||
job = action if isinstance(action, ha.HassJob) else ha.HassJob(action)
|
||||
|
||||
@ha.callback
|
||||
def point_in_time_listener(event):
|
||||
"""Listen for matching time_changed events."""
|
||||
@ -431,7 +435,7 @@ def legacy_patchable_time():
|
||||
setattr(point_in_time_listener, "run", True)
|
||||
async_unsub()
|
||||
|
||||
hass.async_run_job(action, now)
|
||||
hass.async_run_hass_job(job, now)
|
||||
|
||||
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, point_in_time_listener)
|
||||
|
||||
@ -443,6 +447,8 @@ def legacy_patchable_time():
|
||||
hass, action, hour=None, minute=None, second=None, local=False
|
||||
):
|
||||
"""Add a listener that will fire if time matches a pattern."""
|
||||
|
||||
job = ha.HassJob(action)
|
||||
# We do not have to wrap the function with time pattern matching logic
|
||||
# if no pattern given
|
||||
if all(val is None for val in (hour, minute, second)):
|
||||
@ -450,7 +456,7 @@ def legacy_patchable_time():
|
||||
@ha.callback
|
||||
def time_change_listener(ev) -> None:
|
||||
"""Fire every time event that comes in."""
|
||||
hass.async_run_job(action, ev.data[ATTR_NOW])
|
||||
hass.async_run_hass_job(job, ev.data[ATTR_NOW])
|
||||
|
||||
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
|
||||
|
||||
@ -487,8 +493,8 @@ def legacy_patchable_time():
|
||||
last_now = now
|
||||
|
||||
if next_time <= now:
|
||||
hass.async_run_job(
|
||||
action, event.dt_util.as_local(now) if local else now
|
||||
hass.async_run_hass_job(
|
||||
job, event.dt_util.as_local(now) if local else now
|
||||
)
|
||||
calculate_next(now + datetime.timedelta(seconds=1))
|
||||
|
||||
|
@ -48,43 +48,43 @@ def test_split_entity_id():
|
||||
assert ha.split_entity_id("domain.object_id") == ["domain", "object_id"]
|
||||
|
||||
|
||||
def test_async_add_job_schedule_callback():
|
||||
def test_async_add_hass_job_schedule_callback():
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
job = MagicMock()
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, ha.callback(job))
|
||||
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(ha.callback(job)))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 1
|
||||
assert len(hass.loop.create_task.mock_calls) == 0
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_add_job_schedule_partial_callback():
|
||||
def test_async_add_hass_job_schedule_partial_callback():
|
||||
"""Test that we schedule partial coros and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
job = MagicMock()
|
||||
partial = functools.partial(ha.callback(job))
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, partial)
|
||||
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 1
|
||||
assert len(hass.loop.create_task.mock_calls) == 0
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_add_job_schedule_coroutinefunction(loop):
|
||||
def test_async_add_hass_job_schedule_coroutinefunction(loop):
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock(loop=MagicMock(wraps=loop))
|
||||
|
||||
async def job():
|
||||
pass
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, job)
|
||||
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 0
|
||||
assert len(hass.loop.create_task.mock_calls) == 1
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_add_job_schedule_partial_coroutinefunction(loop):
|
||||
def test_async_add_hass_job_schedule_partial_coroutinefunction(loop):
|
||||
"""Test that we schedule partial coros and add jobs to the job pool."""
|
||||
hass = MagicMock(loop=MagicMock(wraps=loop))
|
||||
|
||||
@ -93,20 +93,20 @@ def test_async_add_job_schedule_partial_coroutinefunction(loop):
|
||||
|
||||
partial = functools.partial(job)
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, partial)
|
||||
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 0
|
||||
assert len(hass.loop.create_task.mock_calls) == 1
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_add_job_add_threaded_job_to_pool():
|
||||
def test_async_add_job_add_hass_threaded_job_to_pool():
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
|
||||
def job():
|
||||
pass
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, job)
|
||||
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 0
|
||||
assert len(hass.loop.create_task.mock_calls) == 0
|
||||
assert len(hass.loop.run_in_executor.mock_calls) == 1
|
||||
@ -125,7 +125,7 @@ def test_async_create_task_schedule_coroutine(loop):
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_run_job_calls_callback():
|
||||
def test_async_run_hass_job_calls_callback():
|
||||
"""Test that the callback annotation is respected."""
|
||||
hass = MagicMock()
|
||||
calls = []
|
||||
@ -133,12 +133,12 @@ def test_async_run_job_calls_callback():
|
||||
def job():
|
||||
calls.append(1)
|
||||
|
||||
ha.HomeAssistant.async_run_job(hass, ha.callback(job))
|
||||
ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(ha.callback(job)))
|
||||
assert len(calls) == 1
|
||||
assert len(hass.async_add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_run_job_delegates_non_async():
|
||||
def test_async_run_hass_job_delegates_non_async():
|
||||
"""Test that the callback annotation is respected."""
|
||||
hass = MagicMock()
|
||||
calls = []
|
||||
@ -146,9 +146,9 @@ def test_async_run_job_delegates_non_async():
|
||||
def job():
|
||||
calls.append(1)
|
||||
|
||||
ha.HomeAssistant.async_run_job(hass, job)
|
||||
ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(job))
|
||||
assert len(calls) == 0
|
||||
assert len(hass.async_add_job.mock_calls) == 1
|
||||
assert len(hass.async_add_hass_job.mock_calls) == 1
|
||||
|
||||
|
||||
def test_stage_shutdown():
|
||||
|
Loading…
x
Reference in New Issue
Block a user