Await callbacks to keep cleaner stacktraces (#43693)

This commit is contained in:
Paulus Schoutsen 2020-11-27 17:48:43 +01:00 committed by GitHub
parent 20ed40d7ad
commit 5b6d9abe2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 18 deletions

View File

@ -1485,20 +1485,22 @@ def async_subscribe_connection_status(hass, connection_status_callback):
connection_status_callback_job = HassJob(connection_status_callback) connection_status_callback_job = HassJob(connection_status_callback)
@callback async def connected():
def connected(): task = hass.async_run_hass_job(connection_status_callback_job, True)
hass.async_add_hass_job(connection_status_callback_job, True) if task:
await task
@callback async def disconnected():
def disconnected(): task = hass.async_run_hass_job(connection_status_callback_job, False)
_LOGGER.error("Calling connection_status_callback, False") if task:
hass.async_add_hass_job(connection_status_callback_job, False) await task
subscriptions = { subscriptions = {
"connect": async_dispatcher_connect(hass, MQTT_CONNECTED, connected), "connect": async_dispatcher_connect(hass, MQTT_CONNECTED, connected),
"disconnect": async_dispatcher_connect(hass, MQTT_DISCONNECTED, disconnected), "disconnect": async_dispatcher_connect(hass, MQTT_DISCONNECTED, disconnected),
} }
@callback
def unsubscribe(): def unsubscribe():
subscriptions["connect"]() subscriptions["connect"]()
subscriptions["disconnect"]() subscriptions["disconnect"]()

View File

@ -48,7 +48,7 @@ class Debouncer:
async def async_call(self) -> None: async def async_call(self) -> None:
"""Call the function.""" """Call the function."""
assert self.function is not None assert self._job is not None
if self._timer_task: if self._timer_task:
if not self._execute_at_end_of_timer: if not self._execute_at_end_of_timer:
@ -70,13 +70,15 @@ class Debouncer:
if self._timer_task: if self._timer_task:
return return
await self.hass.async_add_hass_job(self._job) # type: ignore task = self.hass.async_run_hass_job(self._job)
if task:
await task
self._schedule_timer() self._schedule_timer()
async def _handle_timer_finish(self) -> None: async def _handle_timer_finish(self) -> None:
"""Handle a finished timer.""" """Handle a finished timer."""
assert self.function is not None assert self._job is not None
self._timer_task = None self._timer_task = None
@ -95,7 +97,9 @@ class Debouncer:
return # type: ignore return # type: ignore
try: try:
await self.hass.async_add_hass_job(self._job) # type: ignore task = self.hass.async_run_hass_job(self._job)
if task:
await task
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
self.logger.exception("Unexpected exception from %s", self.function) self.logger.exception("Unexpected exception from %s", self.function)

View File

@ -44,13 +44,14 @@ def async_listen(
job = core.HassJob(callback) job = core.HassJob(callback)
@core.callback async def discovery_event_listener(event: core.Event) -> None:
def discovery_event_listener(event: core.Event) -> None:
"""Listen for discovery events.""" """Listen for discovery events."""
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service: if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service:
hass.async_add_hass_job( task = hass.async_run_hass_job(
job, event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED) job, event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED)
) )
if task:
await task
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener) hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
@ -114,8 +115,7 @@ def async_listen_platform(
service = EVENT_LOAD_PLATFORM.format(component) service = EVENT_LOAD_PLATFORM.format(component)
job = core.HassJob(callback) job = core.HassJob(callback)
@core.callback async def discovery_platform_listener(event: core.Event) -> None:
def discovery_platform_listener(event: core.Event) -> None:
"""Listen for platform discovery events.""" """Listen for platform discovery events."""
if event.data.get(ATTR_SERVICE) != service: if event.data.get(ATTR_SERVICE) != service:
return return
@ -125,7 +125,9 @@ def async_listen_platform(
if not platform: if not platform:
return return
hass.async_run_hass_job(job, platform, event.data.get(ATTR_DISCOVERED)) task = hass.async_run_hass_job(job, platform, event.data.get(ATTR_DISCOVERED))
if task:
await task
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_platform_listener) hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_platform_listener)

View File

@ -1,9 +1,9 @@
"""Test to verify that we can load components.""" """Test to verify that we can load components."""
import pytest import pytest
from homeassistant import core, loader
from homeassistant.components import http, hue from homeassistant.components import http, hue
from homeassistant.components.hue import light as hue_light from homeassistant.components.hue import light as hue_light
import homeassistant.loader as loader
from tests.async_mock import ANY, patch from tests.async_mock import ANY, patch
from tests.common import MockModule, async_mock_service, mock_integration from tests.common import MockModule, async_mock_service, mock_integration
@ -83,6 +83,7 @@ async def test_helpers_wrapper(hass):
result = [] result = []
@core.callback
def discovery_callback(service, discovered): def discovery_callback(service, discovered):
"""Handle discovery callback.""" """Handle discovery callback."""
result.append(discovered) result.append(discovered)