Do not allow coroutines to be passed to HassJob (#42073)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Paulus Schoutsen 2020-10-19 23:25:33 +02:00 committed by GitHub
parent ec7f329807
commit b8417a2ce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 17 deletions

View File

@ -311,9 +311,10 @@ async def async_setup_entry(hass, config_entry):
_LOGGER.error("Config entry failed: %s", err)
raise ConfigEntryNotReady from err
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, ambient.client.websocket.disconnect()
)
async def _async_disconnect_websocket(*_):
await ambient.client.websocket.disconnect()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_disconnect_websocket)
return True

View File

@ -60,7 +60,7 @@ async def async_setup(hass, config):
conf.get(CONF_PASSWORD),
)
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, kafka.shutdown)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, kafka.shutdown)
await kafka.start()

View File

@ -58,7 +58,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
if not await sensor.connection():
raise PlatformNotReady
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown())
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, sensor.shutdown)
async_add_entities([sensor], True)
@ -173,7 +173,7 @@ class ImapSensor(Entity):
_LOGGER.warning("Lost %s (will attempt to reconnect)", self._server)
self._connection = None
async def shutdown(self):
async def shutdown(self, *_):
"""Close resources."""
if self._connection:
if self._connection.has_pending_idle():

View File

@ -114,10 +114,11 @@ async def async_setup(hass, config):
if not result:
return False
async def _close():
@callback
def _close(*_):
controller.close()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close())
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _close)
_LOGGER.debug("Arm home config: %s, mode: %s ", conf, conf.get(CONF_ARM_HOME_MODE))

View File

@ -153,10 +153,9 @@ def is_callback(func: Callable[..., Any]) -> bool:
class HassJobType(enum.Enum):
"""Represent a job type."""
Coroutine = 1
Coroutinefunction = 2
Callback = 3
Executor = 4
Coroutinefunction = 1
Callback = 2
Executor = 3
class HassJob:
@ -171,6 +170,9 @@ class HassJob:
def __init__(self, target: Callable):
"""Create a job object."""
if asyncio.iscoroutine(target):
raise ValueError("Coroutine not allowed to be passed to HassJob")
self.target = target
self.job_type = _get_callable_job_type(target)
@ -186,8 +188,6 @@ def _get_callable_job_type(target: Callable) -> HassJobType:
while isinstance(check_target, functools.partial):
check_target = check_target.func
if asyncio.iscoroutine(check_target):
return HassJobType.Coroutine
if asyncio.iscoroutinefunction(check_target):
return HassJobType.Coroutinefunction
if is_callback(check_target):
@ -352,6 +352,9 @@ class HomeAssistant:
if target is None:
raise ValueError("Don't call async_add_job with None")
if asyncio.iscoroutine(target):
return self.async_create_task(cast(Coroutine, target))
return self.async_add_hass_job(HassJob(target), *args)
@callback
@ -364,9 +367,7 @@ class HomeAssistant:
hassjob: HassJob to call.
args: parameters for method to call.
"""
if hassjob.job_type == HassJobType.Coroutine:
task = self.loop.create_task(hassjob.target) # type: ignore
elif hassjob.job_type == HassJobType.Coroutinefunction:
if hassjob.job_type == HassJobType.Coroutinefunction:
task = self.loop.create_task(hassjob.target(*args))
elif hassjob.job_type == HassJobType.Callback:
self.loop.call_soon(hassjob.target, *args)
@ -445,6 +446,10 @@ class HomeAssistant:
target: target to call.
args: parameters for method to call.
"""
if asyncio.iscoroutine(target):
self.async_create_task(cast(Coroutine, target))
return
self.async_run_hass_job(HassJob(target), *args)
def block_till_done(self) -> None:

View File

@ -1522,3 +1522,18 @@ async def test_async_entity_ids_count(hass):
assert hass.states.async_entity_ids_count() == 5
assert hass.states.async_entity_ids_count("light") == 3
async def test_hassjob_forbid_coroutine():
"""Test hassjob forbids coroutines."""
async def bla():
pass
coro = bla()
with pytest.raises(ValueError):
ha.HassJob(coro)
# To avoid warning about unawaited coro
await coro