Add decorator typing [zha] (#107599)

This commit is contained in:
Marc Mueller 2024-01-12 11:42:10 +01:00 committed by GitHub
parent 827a1b1f48
commit c1faafc6a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import asyncio import asyncio
import binascii import binascii
import collections import collections
from collections.abc import Callable, Iterator from collections.abc import Callable, Collection, Coroutine, Iterator
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
import enum import enum
@ -17,7 +17,7 @@ import itertools
import logging import logging
from random import uniform from random import uniform
import re import re
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
import voluptuous as vol import voluptuous as vol
import zigpy.exceptions import zigpy.exceptions
@ -37,10 +37,14 @@ from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, CUSTOM_CONFIGURATION, DATA
from .registries import BINDABLE_CLUSTERS from .registries import BINDABLE_CLUSTERS
if TYPE_CHECKING: if TYPE_CHECKING:
from .cluster_handlers import ClusterHandler
from .device import ZHADevice from .device import ZHADevice
from .gateway import ZHAGateway from .gateway import ZHAGateway
_ClusterHandlerT = TypeVar("_ClusterHandlerT", bound="ClusterHandler")
_T = TypeVar("_T") _T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -319,8 +323,12 @@ class LogMixin:
def retryable_req( def retryable_req(
delays=(1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800), raise_=False delays: Collection[float] = (1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800),
): raise_: bool = False,
) -> Callable[
[Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]]],
Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]],
]:
"""Make a method with ZCL requests retryable. """Make a method with ZCL requests retryable.
This adds delays keyword argument to function. This adds delays keyword argument to function.
@ -328,9 +336,13 @@ def retryable_req(
raise_ if the final attempt should raise the exception. raise_ if the final attempt should raise the exception.
""" """
def decorator(func): def decorator(
func: Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]],
) -> Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(cluster_handler, *args, **kwargs): async def wrapper(
cluster_handler: _ClusterHandlerT, *args: _P.args, **kwargs: _P.kwargs
) -> _R | None:
exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError)
try_count, errors = 1, [] try_count, errors = 1, []
for delay in itertools.chain(delays, [None]): for delay in itertools.chain(delays, [None]):
@ -355,6 +367,7 @@ def retryable_req(
) )
if raise_: if raise_:
raise raise
return None
return wrapper return wrapper