Improve typing [util.logging] (#70894)

This commit is contained in:
Marc Mueller 2022-04-27 22:26:56 +02:00 committed by GitHub
parent b4a0345b38
commit 9a3908d21d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 16 deletions

View File

@ -26,6 +26,7 @@ homeassistant.util.async_
homeassistant.util.color homeassistant.util.color
homeassistant.util.decorator homeassistant.util.decorator
homeassistant.util.location homeassistant.util.location
homeassistant.util.logging
homeassistant.util.process homeassistant.util.process
homeassistant.util.unit_system homeassistant.util.unit_system

View File

@ -2,18 +2,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine from collections.abc import Callable, Coroutine
from functools import partial, wraps from functools import partial, wraps
import inspect import inspect
import logging import logging
import logging.handlers import logging.handlers
import queue import queue
import traceback import traceback
from typing import Any, cast, overload from typing import Any, TypeVar, cast, overload
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.core import HomeAssistant, callback, is_callback from homeassistant.core import HomeAssistant, callback, is_callback
_T = TypeVar("_T")
class HideSensitiveDataFilter(logging.Filter): class HideSensitiveDataFilter(logging.Filter):
"""Filter API password calls.""" """Filter API password calls."""
@ -115,22 +117,24 @@ def log_exception(format_err: Callable[..., Any], *args: Any) -> None:
@overload @overload
def catch_log_exception( # type: ignore[misc] def catch_log_exception(
func: Callable[..., Awaitable[Any]], format_err: Callable[..., Any], *args: Any func: Callable[..., Coroutine[Any, Any, Any]],
) -> Callable[..., Awaitable[None]]: format_err: Callable[..., Any],
"""Overload for Callables that return an Awaitable.""" *args: Any,
) -> Callable[..., Coroutine[Any, Any, None]]:
"""Overload for Callables that return a Coroutine."""
@overload @overload
def catch_log_exception( def catch_log_exception(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> Callable[..., None]: ) -> Callable[..., None | Coroutine[Any, Any, None]]:
"""Overload for Callables that return Any.""" """Overload for Callables that return Any."""
def catch_log_exception( def catch_log_exception(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> Callable[..., None] | Callable[..., Awaitable[None]]: ) -> Callable[..., None | Coroutine[Any, Any, None]]:
"""Decorate a callback to catch and log exceptions.""" """Decorate a callback to catch and log exceptions."""
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
@ -138,9 +142,9 @@ def catch_log_exception(
while isinstance(check_func, partial): while isinstance(check_func, partial):
check_func = check_func.func check_func = check_func.func
wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]] wrapper_func: Callable[..., None | Coroutine[Any, Any, None]]
if asyncio.iscoroutinefunction(check_func): if asyncio.iscoroutinefunction(check_func):
async_func = cast(Callable[..., Awaitable[None]], func) async_func = cast(Callable[..., Coroutine[Any, Any, None]], func)
@wraps(async_func) @wraps(async_func)
async def async_wrapper(*args: Any) -> None: async def async_wrapper(*args: Any) -> None:
@ -170,11 +174,11 @@ def catch_log_exception(
def catch_log_coro_exception( def catch_log_coro_exception(
target: Coroutine[Any, Any, Any], format_err: Callable[..., Any], *args: Any target: Coroutine[Any, Any, _T], format_err: Callable[..., Any], *args: Any
) -> Coroutine[Any, Any, Any]: ) -> Coroutine[Any, Any, _T | None]:
"""Decorate a coroutine to catch and log exceptions.""" """Decorate a coroutine to catch and log exceptions."""
async def coro_wrapper(*args: Any) -> Any: async def coro_wrapper(*args: Any) -> _T | None:
"""Catch and log exception.""" """Catch and log exception."""
try: try:
return await target return await target
@ -182,10 +186,12 @@ def catch_log_coro_exception(
log_exception(format_err, *args) log_exception(format_err, *args)
return None return None
return coro_wrapper() return coro_wrapper(*args)
def async_create_catching_coro(target: Coroutine) -> Coroutine: def async_create_catching_coro(
target: Coroutine[Any, Any, _T]
) -> Coroutine[Any, Any, _T | None]:
"""Wrap a coroutine to catch and log exceptions. """Wrap a coroutine to catch and log exceptions.
The exception will be logged together with a stacktrace of where the The exception will be logged together with a stacktrace of where the
@ -196,7 +202,7 @@ def async_create_catching_coro(target: Coroutine) -> Coroutine:
trace = traceback.extract_stack() trace = traceback.extract_stack()
wrapped_target = catch_log_coro_exception( wrapped_target = catch_log_coro_exception(
target, target,
lambda *args: "Exception in {} called from\n {}".format( lambda: "Exception in {} called from\n {}".format(
target.__name__, target.__name__,
"".join(traceback.format_list(trace[:-1])), "".join(traceback.format_list(trace[:-1])),
), ),

View File

@ -89,6 +89,9 @@ disallow_any_generics = true
[mypy-homeassistant.util.location] [mypy-homeassistant.util.location]
disallow_any_generics = true disallow_any_generics = true
[mypy-homeassistant.util.logging]
disallow_any_generics = true
[mypy-homeassistant.util.process] [mypy-homeassistant.util.process]
disallow_any_generics = true disallow_any_generics = true