diff --git a/homeassistant/components/cast/media_player.py b/homeassistant/components/cast/media_player.py index 653f3b5aed2..ee10f06c985 100644 --- a/homeassistant/components/cast/media_player.py +++ b/homeassistant/components/cast/media_player.py @@ -24,6 +24,7 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, dispatcher_send) from homeassistant.helpers.typing import ConfigType, HomeAssistantType import homeassistant.util.dt as dt_util +from homeassistant.util.logging import async_create_catching_coro from . import DOMAIN as CAST_DOMAIN @@ -522,8 +523,8 @@ class CastDevice(MediaPlayerDevice): if _is_matching_dynamic_group(self._cast_info, discover): _LOGGER.debug("Discovered matching dynamic group: %s", discover) - self.hass.async_create_task( - self.async_set_dynamic_group(discover)) + self.hass.async_create_task(async_create_catching_coro( + self.async_set_dynamic_group(discover))) return if self._cast_info.uuid != discover.uuid: @@ -536,7 +537,8 @@ class CastDevice(MediaPlayerDevice): self._cast_info.host, self._cast_info.port) return _LOGGER.debug("Discovered chromecast with same UUID: %s", discover) - self.hass.async_create_task(self.async_set_cast_info(discover)) + self.hass.async_create_task(async_create_catching_coro( + self.async_set_cast_info(discover))) def async_cast_removed(discover: ChromecastInfo): """Handle removal of Chromecast.""" @@ -546,13 +548,15 @@ class CastDevice(MediaPlayerDevice): if (self._dynamic_group_cast_info is not None and self._dynamic_group_cast_info.uuid == discover.uuid): _LOGGER.debug("Removed matching dynamic group: %s", discover) - self.hass.async_create_task(self.async_del_dynamic_group()) + self.hass.async_create_task(async_create_catching_coro( + self.async_del_dynamic_group())) return if self._cast_info.uuid != discover.uuid: # Removed is not our device. return _LOGGER.debug("Removed chromecast with same UUID: %s", discover) - self.hass.async_create_task(self.async_del_cast_info(discover)) + self.hass.async_create_task(async_create_catching_coro( + self.async_del_cast_info(discover))) async def async_stop(event): """Disconnect socket on Home Assistant stop.""" @@ -565,14 +569,15 @@ class CastDevice(MediaPlayerDevice): self.hass, SIGNAL_CAST_REMOVED, async_cast_removed) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop) - self.hass.async_create_task(self.async_set_cast_info(self._cast_info)) + self.hass.async_create_task(async_create_catching_coro( + self.async_set_cast_info(self._cast_info))) for info in self.hass.data[KNOWN_CHROMECAST_INFO_KEY]: if _is_matching_dynamic_group(self._cast_info, info): _LOGGER.debug("[%s %s (%s:%s)] Found dynamic group: %s", self.entity_id, self._cast_info.friendly_name, self._cast_info.host, self._cast_info.port, info) - self.hass.async_create_task( - self.async_set_dynamic_group(info)) + self.hass.async_create_task(async_create_catching_coro( + self.async_set_dynamic_group(info))) break async def async_will_remove_from_hass(self) -> None: diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index 214d9417e2a..317a30d9d56 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -6,7 +6,7 @@ import inspect import logging import threading import traceback -from typing import Any, Callable, Optional +from typing import Any, Callable, Coroutine, Optional from .async_ import run_coroutine_threadsafe @@ -130,7 +130,7 @@ def catch_log_exception( func: Callable[..., Any], format_err: Callable[..., Any], *args: Any) -> Callable[[], None]: - """Decorate an callback to catch and log exceptions.""" + """Decorate a callback to catch and log exceptions.""" def log_exception(*args: Any) -> None: module_name = inspect.getmodule(inspect.trace()[1][0]).__name__ # Do not print the wrapper in the traceback @@ -164,3 +164,43 @@ def catch_log_exception( log_exception(*args) wrapper_func = wrapper return wrapper_func + + +def catch_log_coro_exception( + target: Coroutine[Any, Any, Any], + format_err: Callable[..., Any], + *args: Any) -> Coroutine[Any, Any, Any]: + """Decorate a coroutine to catch and log exceptions.""" + async def coro_wrapper(*args: Any) -> Any: + """Catch and log exception.""" + try: + return await target + except Exception: # pylint: disable=broad-except + module_name = inspect.getmodule(inspect.trace()[1][0]).__name__ + # Do not print the wrapper in the traceback + frames = len(inspect.trace()) - 1 + exc_msg = traceback.format_exc(-frames) + friendly_msg = format_err(*args) + logging.getLogger(module_name).error('%s\n%s', + friendly_msg, exc_msg) + return None + return coro_wrapper() + + +def async_create_catching_coro( + target: Coroutine) -> Coroutine: + """Wrap a coroutine to catch and log exceptions. + + The exception will be logged together with a stacktrace of where the + coroutine was wrapped. + + target: target coroutine. + """ + trace = traceback.extract_stack() + wrapped_target = catch_log_coro_exception( + target, lambda *args: + "Exception in {} called from\n {}".format( + target.__name__, # type: ignore + "".join(traceback.format_list(trace[:-1])))) + + return wrapped_target diff --git a/tests/util/test_logging.py b/tests/util/test_logging.py index c67b2aea448..92a06587fda 100644 --- a/tests/util/test_logging.py +++ b/tests/util/test_logging.py @@ -65,3 +65,16 @@ def test_async_handler_thread_log(loop): assert queue.get_nowait() == log_record assert queue.empty() + + +async def test_async_create_catching_coro(hass, caplog): + """Test exception logging of wrapped coroutine.""" + async def job(): + raise Exception('This is a bad coroutine') + pass + + hass.async_create_task(logging_util.async_create_catching_coro(job())) + await hass.async_block_till_done() + assert 'This is a bad coroutine' in caplog.text + assert ('hass.async_create_task(' + 'logging_util.async_create_catching_coro(job()))' in caplog.text)