Do not allow coroutines to be passed to HassJob (#42073)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Paulus Schoutsen 2020-10-19 23:25:33 +02:00 committed by GitHub
parent ec7f329807
commit b8417a2ce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 17 deletions

View File

@ -311,9 +311,10 @@ async def async_setup_entry(hass, config_entry):
_LOGGER.error("Config entry failed: %s", err) _LOGGER.error("Config entry failed: %s", err)
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
hass.bus.async_listen_once( async def _async_disconnect_websocket(*_):
EVENT_HOMEASSISTANT_STOP, ambient.client.websocket.disconnect() await ambient.client.websocket.disconnect()
)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_disconnect_websocket)
return True return True

View File

@ -60,7 +60,7 @@ async def async_setup(hass, config):
conf.get(CONF_PASSWORD), conf.get(CONF_PASSWORD),
) )
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, kafka.shutdown) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, kafka.shutdown)
await kafka.start() await kafka.start()

View File

@ -58,7 +58,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
if not await sensor.connection(): if not await sensor.connection():
raise PlatformNotReady raise PlatformNotReady
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown()) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown)
async_add_entities([sensor], True) async_add_entities([sensor], True)
@ -173,7 +173,7 @@ class ImapSensor(Entity):
_LOGGER.warning("Lost %s (will attempt to reconnect)", self._server) _LOGGER.warning("Lost %s (will attempt to reconnect)", self._server)
self._connection = None self._connection = None
async def shutdown(self): async def shutdown(self, *_):
"""Close resources.""" """Close resources."""
if self._connection: if self._connection:
if self._connection.has_pending_idle(): if self._connection.has_pending_idle():

View File

@ -114,10 +114,11 @@ async def async_setup(hass, config):
if not result: if not result:
return False return False
async def _close(): @callback
def _close(*_):
controller.close() controller.close()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close()) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close)
_LOGGER.debug("Arm home config: %s, mode: %s ", conf, conf.get(CONF_ARM_HOME_MODE)) _LOGGER.debug("Arm home config: %s, mode: %s ", conf, conf.get(CONF_ARM_HOME_MODE))

View File

@ -153,10 +153,9 @@ def is_callback(func: Callable[..., Any]) -> bool:
class HassJobType(enum.Enum): class HassJobType(enum.Enum):
"""Represent a job type.""" """Represent a job type."""
Coroutine = 1 Coroutinefunction = 1
Coroutinefunction = 2 Callback = 2
Callback = 3 Executor = 3
Executor = 4
class HassJob: class HassJob:
@ -171,6 +170,9 @@ class HassJob:
def __init__(self, target: Callable): def __init__(self, target: Callable):
"""Create a job object.""" """Create a job object."""
if asyncio.iscoroutine(target):
raise ValueError("Coroutine not allowed to be passed to HassJob")
self.target = target self.target = target
self.job_type = _get_callable_job_type(target) self.job_type = _get_callable_job_type(target)
@ -186,8 +188,6 @@ def _get_callable_job_type(target: Callable) -> HassJobType:
while isinstance(check_target, functools.partial): while isinstance(check_target, functools.partial):
check_target = check_target.func check_target = check_target.func
if asyncio.iscoroutine(check_target):
return HassJobType.Coroutine
if asyncio.iscoroutinefunction(check_target): if asyncio.iscoroutinefunction(check_target):
return HassJobType.Coroutinefunction return HassJobType.Coroutinefunction
if is_callback(check_target): if is_callback(check_target):
@ -352,6 +352,9 @@ class HomeAssistant:
if target is None: if target is None:
raise ValueError("Don't call async_add_job with None") raise ValueError("Don't call async_add_job with None")
if asyncio.iscoroutine(target):
return self.async_create_task(cast(Coroutine, target))
return self.async_add_hass_job(HassJob(target), *args) return self.async_add_hass_job(HassJob(target), *args)
@callback @callback
@ -364,9 +367,7 @@ class HomeAssistant:
hassjob: HassJob to call. hassjob: HassJob to call.
args: parameters for method to call. args: parameters for method to call.
""" """
if hassjob.job_type == HassJobType.Coroutine: if hassjob.job_type == HassJobType.Coroutinefunction:
task = self.loop.create_task(hassjob.target) # type: ignore
elif hassjob.job_type == HassJobType.Coroutinefunction:
task = self.loop.create_task(hassjob.target(*args)) task = self.loop.create_task(hassjob.target(*args))
elif hassjob.job_type == HassJobType.Callback: elif hassjob.job_type == HassJobType.Callback:
self.loop.call_soon(hassjob.target, *args) self.loop.call_soon(hassjob.target, *args)
@ -445,6 +446,10 @@ class HomeAssistant:
target: target to call. target: target to call.
args: parameters for method to call. args: parameters for method to call.
""" """
if asyncio.iscoroutine(target):
self.async_create_task(cast(Coroutine, target))
return
self.async_run_hass_job(HassJob(target), *args) self.async_run_hass_job(HassJob(target), *args)
def block_till_done(self) -> None: def block_till_done(self) -> None:

View File

@ -1522,3 +1522,18 @@ async def test_async_entity_ids_count(hass):
assert hass.states.async_entity_ids_count() == 5 assert hass.states.async_entity_ids_count() == 5
assert hass.states.async_entity_ids_count("light") == 3 assert hass.states.async_entity_ids_count("light") == 3
async def test_hassjob_forbid_coroutine():
"""Test hassjob forbids coroutines."""
async def bla():
pass
coro = bla()
with pytest.raises(ValueError):
ha.HassJob(coro)
# To avoid warning about unawaited coro
await coro