Refactor event time trackers to avoid using nonlocal (#107997)

This commit is contained in:
J. Nick Koston 2024-01-13 17:17:55 -10:00 committed by GitHub
parent e7c25d1c36
commit 659ee51914
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,7 @@ import functools as ft
import logging import logging
from random import randint from random import randint
import time import time
from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypedDict, TypeVar
import attr import attr
@ -1389,6 +1389,45 @@ def async_track_point_in_time(
track_point_in_time = threaded_listener_factory(async_track_point_in_time) track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@dataclass(slots=True)
class _TrackPointUTCTime:
hass: HomeAssistant
job: HassJob[[datetime], Coroutine[Any, Any, None] | None]
utc_point_in_time: datetime
expected_fire_timestamp: float
_cancel_callback: asyncio.TimerHandle | None = None
def async_attach(self) -> None:
"""Initialize track job."""
loop = self.hass.loop
self._cancel_callback = loop.call_at(
loop.time() + self.expected_fire_timestamp - time.time(), self._run_action
)
@callback
def _run_action(self) -> None:
"""Call the action."""
# Depending on the available clock support (including timer hardware
# and the OS kernel) it can happen that we fire a little bit too early
# as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining
# time.
if (delta := (self.expected_fire_timestamp - time_tracker_timestamp())) > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta)
loop = self.hass.loop
self._cancel_callback = loop.call_at(loop.time() + delta, self._run_action)
return
self.hass.async_run_hass_job(self.job, self.utc_point_in_time)
@callback
def async_cancel(self) -> None:
"""Cancel the call_at."""
if TYPE_CHECKING:
assert self._cancel_callback is not None
self._cancel_callback.cancel()
@callback @callback
@bind_hass @bind_hass
def async_track_point_in_utc_time( def async_track_point_in_utc_time(
@ -1404,44 +1443,14 @@ def async_track_point_in_utc_time(
# Ensure point_in_time is UTC # Ensure point_in_time is UTC
utc_point_in_time = dt_util.as_utc(point_in_time) utc_point_in_time = dt_util.as_utc(point_in_time)
expected_fire_timestamp = dt_util.utc_to_timestamp(utc_point_in_time) expected_fire_timestamp = dt_util.utc_to_timestamp(utc_point_in_time)
# Since this is called once, we accept a HassJob so we can avoid
# having to figure out how to call the action every time its called.
cancel_callback: asyncio.TimerHandle | None = None
loop = hass.loop
@callback
def run_action(job: HassJob[[datetime], Coroutine[Any, Any, None] | None]) -> None:
"""Call the action."""
nonlocal cancel_callback
# Depending on the available clock support (including timer hardware
# and the OS kernel) it can happen that we fire a little bit too early
# as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining
# time.
if (delta := (expected_fire_timestamp - time_tracker_timestamp())) > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta)
cancel_callback = loop.call_at(loop.time() + delta, run_action, job)
return
hass.async_run_hass_job(job, utc_point_in_time)
job = ( job = (
action action
if isinstance(action, HassJob) if isinstance(action, HassJob)
else HassJob(action, f"track point in utc time {utc_point_in_time}") else HassJob(action, f"track point in utc time {utc_point_in_time}")
) )
delta = expected_fire_timestamp - time.time() track = _TrackPointUTCTime(hass, job, utc_point_in_time, expected_fire_timestamp)
cancel_callback = loop.call_at(loop.time() + delta, run_action, job) track.async_attach()
return track.async_cancel
@callback
def unsub_point_in_time_listener() -> None:
"""Cancel the call_at."""
assert cancel_callback is not None
cancel_callback.cancel()
return unsub_point_in_time_listener
track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_time) track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_time)
@ -1500,6 +1509,61 @@ def async_call_later(
call_later = threaded_listener_factory(async_call_later) call_later = threaded_listener_factory(async_call_later)
@dataclass(slots=True)
class _TrackTimeInterval:
"""Helper class to help listen to time interval events."""
hass: HomeAssistant
seconds: float
job_name: str
action: Callable[[datetime], Coroutine[Any, Any, None] | None]
cancel_on_shutdown: bool | None
_track_job: HassJob[[datetime], Coroutine[Any, Any, None] | None] | None = None
_run_job: HassJob[[datetime], Coroutine[Any, Any, None] | None] | None = None
_cancel_callback: CALLBACK_TYPE | None = None
def async_attach(self) -> None:
"""Initialize track job."""
hass = self.hass
self._track_job = HassJob(
self._interval_listener,
self.job_name,
job_type=HassJobType.Callback,
cancel_on_shutdown=self.cancel_on_shutdown,
)
self._run_job = HassJob(
self.action,
f"track time interval {self.seconds}",
cancel_on_shutdown=self.cancel_on_shutdown,
)
self._cancel_callback = async_call_at(
hass,
self._track_job,
hass.loop.time() + self.seconds,
)
@callback
def _interval_listener(self, now: datetime) -> None:
"""Handle elapsed intervals."""
if TYPE_CHECKING:
assert self._run_job is not None
assert self._track_job is not None
hass = self.hass
self._cancel_callback = async_call_at(
hass,
self._track_job,
hass.loop.time() + self.seconds,
)
hass.async_run_hass_job(self._run_job, now)
@callback
def async_cancel(self) -> None:
"""Cancel the call_at."""
if TYPE_CHECKING:
assert self._cancel_callback is not None
self._cancel_callback()
@callback @callback
@bind_hass @bind_hass
def async_track_time_interval( def async_track_time_interval(
@ -1514,41 +1578,13 @@ def async_track_time_interval(
The listener is passed the time it fires in UTC time. The listener is passed the time it fires in UTC time.
""" """
remove: CALLBACK_TYPE seconds = interval.total_seconds()
interval_listener_job: HassJob[[datetime], None] job_name = f"track time interval {seconds} {action}"
interval_seconds = interval.total_seconds()
job = HassJob(
action, f"track time interval {interval}", cancel_on_shutdown=cancel_on_shutdown
)
@callback
def interval_listener(now: datetime) -> None:
"""Handle elapsed intervals."""
nonlocal remove
nonlocal interval_listener_job
remove = async_call_later(hass, interval_seconds, interval_listener_job)
hass.async_run_hass_job(job, now)
if name: if name:
job_name = f"{name}: track time interval {interval} {action}" job_name = f"{name}: {job_name}"
else: track = _TrackTimeInterval(hass, seconds, job_name, action, cancel_on_shutdown)
job_name = f"track time interval {interval} {action}" track.async_attach()
return track.async_cancel
interval_listener_job = HassJob(
interval_listener,
job_name,
cancel_on_shutdown=cancel_on_shutdown,
job_type=HassJobType.Callback,
)
remove = async_call_later(hass, interval_seconds, interval_listener_job)
def remove_listener() -> None:
"""Remove interval listener."""
remove()
return remove_listener
track_time_interval = threaded_listener_factory(async_track_time_interval) track_time_interval = threaded_listener_factory(async_track_time_interval)