Only work out job type once when setting up dispatcher (#116030)

This commit is contained in:
J. Nick Koston 2024-04-23 22:24:36 +02:00 committed by GitHub
parent f1fa33483e
commit 8f1761343e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 13 deletions

View File

@ -7,7 +7,12 @@ from functools import partial
import logging import logging
from typing import Any, TypeVarTuple, overload from typing import Any, TypeVarTuple, overload
from homeassistant.core import HassJob, HomeAssistant, callback from homeassistant.core import (
HassJob,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
)
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception from homeassistant.util.logging import catch_log_exception
@ -161,9 +166,13 @@ def _generate_job(
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any] signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]: ) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target.""" """Generate a HassJob for a signal and target."""
job_type = get_hassjob_callable_job_type(target)
return HassJob( return HassJob(
catch_log_exception(target, partial(_format_err, signal, target)), catch_log_exception(
target, partial(_format_err, signal, target), job_type=job_type
),
f"dispatcher {signal}", f"dispatcher {signal}",
job_type=job_type,
) )

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from functools import partial, wraps from functools import partial, wraps
import inspect import inspect
@ -12,7 +11,12 @@ import queue
import traceback import traceback
from typing import Any, TypeVar, TypeVarTuple, cast, overload from typing import Any, TypeVar, TypeVarTuple, cast, overload
from homeassistant.core import HomeAssistant, callback, is_callback from homeassistant.core import (
HassJobType,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
)
_T = TypeVar("_T") _T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts") _Ts = TypeVarTuple("_Ts")
@ -129,34 +133,38 @@ def _callback_wrapper(
@overload @overload
def catch_log_exception( def catch_log_exception(
func: Callable[[*_Ts], Coroutine[Any, Any, Any]], format_err: Callable[[*_Ts], Any] func: Callable[[*_Ts], Coroutine[Any, Any, Any]],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]: ... ) -> Callable[[*_Ts], Coroutine[Any, Any, None]]: ...
@overload @overload
def catch_log_exception( def catch_log_exception(
func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any] func: Callable[[*_Ts], Any],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: ... ) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: ...
def catch_log_exception( def catch_log_exception(
func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any] func: Callable[[*_Ts], Any],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: ) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]:
"""Decorate a function func to catch and log exceptions. """Decorate a function func to catch and log exceptions.
If func is a coroutine function, a coroutine function will be returned. If func is a coroutine function, a coroutine function will be returned.
If func is a callback, a callback will be returned. If func is a callback, a callback will be returned.
""" """
# Check for partials to properly determine if coroutine function if job_type is None:
check_func = func job_type = get_hassjob_callable_job_type(func)
while isinstance(check_func, partial):
check_func = check_func.func # type: ignore[unreachable] # false positive
if asyncio.iscoroutinefunction(check_func): if job_type is HassJobType.Coroutinefunction:
async_func = cast(Callable[[*_Ts], Coroutine[Any, Any, None]], func) async_func = cast(Callable[[*_Ts], Coroutine[Any, Any, None]], func)
return wraps(async_func)(partial(_async_wrapper, async_func, format_err)) # type: ignore[return-value] return wraps(async_func)(partial(_async_wrapper, async_func, format_err)) # type: ignore[return-value]
if is_callback(check_func): if job_type is HassJobType.Callback:
return wraps(func)(partial(_callback_wrapper, func, format_err)) # type: ignore[return-value] return wraps(func)(partial(_callback_wrapper, func, format_err)) # type: ignore[return-value]
return wraps(func)(partial(_sync_wrapper, func, format_err)) # type: ignore[return-value] return wraps(func)(partial(_sync_wrapper, func, format_err)) # type: ignore[return-value]