diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index a0d2f1c8702..ef53721efd2 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -83,6 +83,8 @@ from .models import ( Events, StateAttributes, States, + StatisticData, + StatisticMetaData, StatisticsRuns, process_timestamp, ) @@ -460,8 +462,8 @@ class StatisticsTask(RecorderTask): class ExternalStatisticsTask(RecorderTask): """An object to insert into the recorder queue to run an external statistics task.""" - metadata: dict - statistics: Iterable[dict] + metadata: StatisticMetaData + statistics: Iterable[StatisticData] def run(self, instance: Recorder) -> None: """Run statistics task.""" @@ -916,7 +918,9 @@ class Recorder(threading.Thread): self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement)) @callback - def async_external_statistics(self, metadata: dict, stats: Iterable[dict]) -> None: + def async_external_statistics( + self, metadata: StatisticMetaData, stats: Iterable[StatisticData] + ) -> None: """Schedule external statistics.""" self.queue.put(ExternalStatisticsTask(metadata, stats)) diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index b67f4c6d558..b63bbe740bc 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -8,7 +8,7 @@ import functools import logging import os import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from awesomeversion import ( AwesomeVersion, @@ -20,6 +20,7 @@ from sqlalchemy.engine.cursor import CursorFetchStrategy from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session +from typing_extensions import Concatenate, ParamSpec from homeassistant.core import HomeAssistant import homeassistant.util.dt as dt_util @@ -40,6 +41,9 @@ from .models import ( if TYPE_CHECKING: from . import Recorder +_RecorderT = TypeVar("_RecorderT", bound="Recorder") +_P = ParamSpec("_P") + _LOGGER = logging.getLogger(__name__) RETRIES = 3 @@ -430,15 +434,22 @@ def end_incomplete_runs(session: Session, start_time: datetime) -> None: session.add(run) -def retryable_database_job(description: str) -> Callable: +def retryable_database_job( + description: str, +) -> Callable[ + [Callable[Concatenate[_RecorderT, _P], bool]], + Callable[Concatenate[_RecorderT, _P], bool], +]: """Try to execute a database job. The job should return True if it finished, and False if it needs to be rescheduled. """ - def decorator(job: Callable[[Any], bool]) -> Callable: + def decorator( + job: Callable[Concatenate[_RecorderT, _P], bool] + ) -> Callable[Concatenate[_RecorderT, _P], bool]: @functools.wraps(job) - def wrapper(instance: Recorder, *args: Any, **kwargs: Any) -> bool: + def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> bool: try: return job(instance, *args, **kwargs) except OperationalError as err: