Added recursive detection of functools.partial. (#20284)

This commit is contained in:
Andrew Sayre 2019-01-21 00:27:32 -06:00 committed by Paulus Schoutsen
parent 9482a6303d
commit 5c208da82e
2 changed files with 10 additions and 4 deletions

View File

@ -259,9 +259,10 @@ class HomeAssistant:
""" """
task = None task = None
# Check for partials to properly determine if coroutine function
check_target = target check_target = target
if isinstance(target, functools.partial): while isinstance(check_target, functools.partial):
check_target = target.func check_target = check_target.func
if asyncio.iscoroutine(check_target): if asyncio.iscoroutine(check_target):
task = self.loop.create_task(target) # type: ignore task = self.loop.create_task(target) # type: ignore

View File

@ -1,7 +1,7 @@
"""Logging utilities.""" """Logging utilities."""
import asyncio import asyncio
from asyncio.events import AbstractEventLoop from asyncio.events import AbstractEventLoop
from functools import wraps from functools import partial, wraps
import inspect import inspect
import logging import logging
import threading import threading
@ -139,8 +139,13 @@ def catch_log_exception(
friendly_msg = format_err(*args) friendly_msg = format_err(*args)
logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg) logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg)
# Check for partials to properly determine if coroutine function
check_func = func
while isinstance(check_func, partial):
check_func = check_func.func
wrapper_func = None wrapper_func = None
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(check_func):
@wraps(func) @wraps(func)
async def async_wrapper(*args: Any) -> None: async def async_wrapper(*args: Any) -> None:
"""Catch and log exception.""" """Catch and log exception."""