diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index c1194c7da01..52d57e9cf08 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -7,7 +7,12 @@ from functools import partial import logging from typing import Any, TypeVarTuple, overload -from homeassistant.core import HassJob, HomeAssistant, callback +from homeassistant.core import ( + HassJob, + HomeAssistant, + callback, + get_hassjob_callable_job_type, +) from homeassistant.loader import bind_hass from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.logging import catch_log_exception @@ -161,9 +166,13 @@ def _generate_job( signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any] ) -> HassJob[..., None | Coroutine[Any, Any, None]]: """Generate a HassJob for a signal and target.""" + job_type = get_hassjob_callable_job_type(target) return HassJob( - catch_log_exception(target, partial(_format_err, signal, target)), + catch_log_exception( + target, partial(_format_err, signal, target), job_type=job_type + ), f"dispatcher {signal}", + job_type=job_type, ) diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index 8709186face..ab163578846 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio from collections.abc import Callable, Coroutine from functools import partial, wraps import inspect @@ -12,7 +11,12 @@ import queue import traceback from typing import Any, TypeVar, TypeVarTuple, cast, overload -from homeassistant.core import HomeAssistant, callback, is_callback +from homeassistant.core import ( + HassJobType, + HomeAssistant, + callback, + get_hassjob_callable_job_type, +) _T = TypeVar("_T") _Ts = TypeVarTuple("_Ts") @@ -129,34 +133,38 @@ def _callback_wrapper( @overload def catch_log_exception( - func: Callable[[*_Ts], Coroutine[Any, Any, Any]], format_err: Callable[[*_Ts], Any] + func: Callable[[*_Ts], Coroutine[Any, Any, Any]], + format_err: Callable[[*_Ts], Any], + job_type: HassJobType | None = None, ) -> Callable[[*_Ts], Coroutine[Any, Any, None]]: ... @overload def catch_log_exception( - func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any] + func: Callable[[*_Ts], Any], + format_err: Callable[[*_Ts], Any], + job_type: HassJobType | None = None, ) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: ... def catch_log_exception( - func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any] + func: Callable[[*_Ts], Any], + format_err: Callable[[*_Ts], Any], + job_type: HassJobType | None = None, ) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: """Decorate a function func to catch and log exceptions. If func is a coroutine function, a coroutine function will be returned. If func is a callback, a callback will be returned. """ - # Check for partials to properly determine if coroutine function - check_func = func - while isinstance(check_func, partial): - check_func = check_func.func # type: ignore[unreachable] # false positive + if job_type is None: + job_type = get_hassjob_callable_job_type(func) - if asyncio.iscoroutinefunction(check_func): + if job_type is HassJobType.Coroutinefunction: async_func = cast(Callable[[*_Ts], Coroutine[Any, Any, None]], func) return wraps(async_func)(partial(_async_wrapper, async_func, format_err)) # type: ignore[return-value] - if is_callback(check_func): + if job_type is HassJobType.Callback: return wraps(func)(partial(_callback_wrapper, func, format_err)) # type: ignore[return-value] return wraps(func)(partial(_sync_wrapper, func, format_err)) # type: ignore[return-value]