From b8417a2ce22370755820f05ca3e20b21a4ddc4d7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 19 Oct 2020 23:25:33 +0200 Subject: [PATCH] Do not allow coroutines to be passed to HassJob (#42073) Co-authored-by: J. Nick Koston --- .../components/ambient_station/__init__.py | 7 +++--- .../components/apache_kafka/__init__.py | 2 +- homeassistant/components/imap/sensor.py | 4 ++-- .../components/satel_integra/__init__.py | 5 ++-- homeassistant/core.py | 23 +++++++++++-------- tests/test_core.py | 15 ++++++++++++ 6 files changed, 39 insertions(+), 17 deletions(-) diff --git a/homeassistant/components/ambient_station/__init__.py b/homeassistant/components/ambient_station/__init__.py index 68bfb85cf62..23c0ad6e3dd 100644 --- a/homeassistant/components/ambient_station/__init__.py +++ b/homeassistant/components/ambient_station/__init__.py @@ -311,9 +311,10 @@ async def async_setup_entry(hass, config_entry): _LOGGER.error("Config entry failed: %s", err) raise ConfigEntryNotReady from err - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, ambient.client.websocket.disconnect() - ) + async def _async_disconnect_websocket(*_): + await ambient.client.websocket.disconnect() + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_disconnect_websocket) return True diff --git a/homeassistant/components/apache_kafka/__init__.py b/homeassistant/components/apache_kafka/__init__.py index a0bb5d20abe..5be3732757f 100644 --- a/homeassistant/components/apache_kafka/__init__.py +++ b/homeassistant/components/apache_kafka/__init__.py @@ -60,7 +60,7 @@ async def async_setup(hass, config): conf.get(CONF_PASSWORD), ) - hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, kafka.shutdown) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, kafka.shutdown) await kafka.start() diff --git a/homeassistant/components/imap/sensor.py b/homeassistant/components/imap/sensor.py index a824d5f8ee9..4917abc6028 100644 --- a/homeassistant/components/imap/sensor.py +++ b/homeassistant/components/imap/sensor.py @@ -58,7 +58,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= if not await sensor.connection(): raise PlatformNotReady - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown()) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown) async_add_entities([sensor], True) @@ -173,7 +173,7 @@ class ImapSensor(Entity): _LOGGER.warning("Lost %s (will attempt to reconnect)", self._server) self._connection = None - async def shutdown(self): + async def shutdown(self, *_): """Close resources.""" if self._connection: if self._connection.has_pending_idle(): diff --git a/homeassistant/components/satel_integra/__init__.py b/homeassistant/components/satel_integra/__init__.py index 0b007f63e01..34b5511c394 100644 --- a/homeassistant/components/satel_integra/__init__.py +++ b/homeassistant/components/satel_integra/__init__.py @@ -114,10 +114,11 @@ async def async_setup(hass, config): if not result: return False - async def _close(): + @callback + def _close(*_): controller.close() - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close()) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close) _LOGGER.debug("Arm home config: %s, mode: %s ", conf, conf.get(CONF_ARM_HOME_MODE)) diff --git a/homeassistant/core.py b/homeassistant/core.py index 795a67f3f19..e30f05de842 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -153,10 +153,9 @@ def is_callback(func: Callable[..., Any]) -> bool: class HassJobType(enum.Enum): """Represent a job type.""" - Coroutine = 1 - Coroutinefunction = 2 - Callback = 3 - Executor = 4 + Coroutinefunction = 1 + Callback = 2 + Executor = 3 class HassJob: @@ -171,6 +170,9 @@ class HassJob: def __init__(self, target: Callable): """Create a job object.""" + if asyncio.iscoroutine(target): + raise ValueError("Coroutine not allowed to be passed to HassJob") + self.target = target self.job_type = _get_callable_job_type(target) @@ -186,8 +188,6 @@ def _get_callable_job_type(target: Callable) -> HassJobType: while isinstance(check_target, functools.partial): check_target = check_target.func - if asyncio.iscoroutine(check_target): - return HassJobType.Coroutine if asyncio.iscoroutinefunction(check_target): return HassJobType.Coroutinefunction if is_callback(check_target): @@ -352,6 +352,9 @@ class HomeAssistant: if target is None: raise ValueError("Don't call async_add_job with None") + if asyncio.iscoroutine(target): + return self.async_create_task(cast(Coroutine, target)) + return self.async_add_hass_job(HassJob(target), *args) @callback @@ -364,9 +367,7 @@ class HomeAssistant: hassjob: HassJob to call. args: parameters for method to call. """ - if hassjob.job_type == HassJobType.Coroutine: - task = self.loop.create_task(hassjob.target) # type: ignore - elif hassjob.job_type == HassJobType.Coroutinefunction: + if hassjob.job_type == HassJobType.Coroutinefunction: task = self.loop.create_task(hassjob.target(*args)) elif hassjob.job_type == HassJobType.Callback: self.loop.call_soon(hassjob.target, *args) @@ -445,6 +446,10 @@ class HomeAssistant: target: target to call. args: parameters for method to call. """ + if asyncio.iscoroutine(target): + self.async_create_task(cast(Coroutine, target)) + return + self.async_run_hass_job(HassJob(target), *args) def block_till_done(self) -> None: diff --git a/tests/test_core.py b/tests/test_core.py index 402b43b7d11..22ef5727dfc 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1522,3 +1522,18 @@ async def test_async_entity_ids_count(hass): assert hass.states.async_entity_ids_count() == 5 assert hass.states.async_entity_ids_count("light") == 3 + + +async def test_hassjob_forbid_coroutine(): + """Test hassjob forbids coroutines.""" + + async def bla(): + pass + + coro = bla() + + with pytest.raises(ValueError): + ha.HassJob(coro) + + # To avoid warning about unawaited coro + await coro