Add thread safety checks to async_create_task (#116339)

* Add thread safety checks to async_create_task

Calling async_create_task from a thread almost always results in an
fast crash. Since most internals are using async_create_background_task
or other task APIs, and this is the one integrations seem to get wrong
the most, add a thread safety check here

* Add thread safety checks to async_create_task

Calling async_create_task from a thread almost always results in an
fast crash. Since most internals are using async_create_background_task
or other task APIs, and this is the one integrations seem to get wrong
the most, add a thread safety check here

* missed one

* Update homeassistant/core.py

* fix mocks

* one more internal

* more places where internal can be used

* more places where internal can be used

* more places where internal can be used

* internal one more place since this is high volume and was already eager_start
This commit is contained in:
J. Nick Koston 2024-04-28 17:29:00 -05:00 committed by Paulus Schoutsen
parent 6786479a81
commit 66538ba34e
14 changed files with 70 additions and 23 deletions

View File

@ -731,7 +731,7 @@ async def async_setup_multi_components(
# to wait to be imported, and the sooner we can get the base platforms # to wait to be imported, and the sooner we can get the base platforms
# loaded the sooner we can start loading the rest of the integrations. # loaded the sooner we can start loading the rest of the integrations.
futures = { futures = {
domain: hass.async_create_task( domain: hass.async_create_task_internal(
async_setup_component(hass, domain, config), async_setup_component(hass, domain, config),
f"setup component {domain}", f"setup component {domain}",
eager_start=True, eager_start=True,

View File

@ -1087,7 +1087,7 @@ class ConfigEntry:
target: target to call. target: target to call.
""" """
task = hass.async_create_task( task = hass.async_create_task_internal(
target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start
) )
if eager_start and task.done(): if eager_start and task.done():
@ -1643,7 +1643,7 @@ class ConfigEntries:
# starting a new flow with the 'unignore' step. If the integration doesn't # starting a new flow with the 'unignore' step. If the integration doesn't
# implement async_step_unignore then this will be a no-op. # implement async_step_unignore then this will be a no-op.
if entry.source == SOURCE_IGNORE: if entry.source == SOURCE_IGNORE:
self.hass.async_create_task( self.hass.async_create_task_internal(
self.hass.config_entries.flow.async_init( self.hass.config_entries.flow.async_init(
entry.domain, entry.domain,
context={"source": SOURCE_UNIGNORE}, context={"source": SOURCE_UNIGNORE},

View File

@ -785,7 +785,9 @@ class HomeAssistant:
target: target to call. target: target to call.
""" """
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
functools.partial(self.async_create_task, target, name, eager_start=True) functools.partial(
self.async_create_task_internal, target, name, eager_start=True
)
) )
@callback @callback
@ -800,6 +802,37 @@ class HomeAssistant:
This method must be run in the event loop. If you are using this in your This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead. integration, use the create task methods on the config entry instead.
target: target to call.
"""
# We turned on asyncio debug in April 2024 in the dev containers
# in the hope of catching some of the issues that have been
# reported. It will take a while to get all the issues fixed in
# custom components.
#
# In 2025.5 we should guard the `verify_event_loop_thread`
# check with a check for the `hass.config.debug` flag being set as
# long term we don't want to be checking this in production
# environments since it is a performance hit.
self.verify_event_loop_thread("async_create_task")
return self.async_create_task_internal(target, name, eager_start)
@callback
def async_create_task_internal(
self,
target: Coroutine[Any, Any, _R],
name: str | None = None,
eager_start: bool = True,
) -> asyncio.Task[_R]:
"""Create a task from within the event loop, internal use only.
This method is intended to only be used by core internally
and should not be considered a stable API. We will make
breaking change to this function in the future and it
should not be used in integrations.
This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.
target: target to call. target: target to call.
""" """
if eager_start: if eager_start:
@ -2695,7 +2728,7 @@ class ServiceRegistry:
coro = self._execute_service(handler, service_call) coro = self._execute_service(handler, service_call)
if not blocking: if not blocking:
self._hass.async_create_task( self._hass.async_create_task_internal(
self._run_service_call_catch_exceptions(coro, service_call), self._run_service_call_catch_exceptions(coro, service_call),
f"service call background {service_call.domain}.{service_call.service}", f"service call background {service_call.domain}.{service_call.service}",
eager_start=True, eager_start=True,

View File

@ -1490,7 +1490,7 @@ class Entity(
is_remove = action == "remove" is_remove = action == "remove"
self._removed_from_registry = is_remove self._removed_from_registry = is_remove
if action == "update" or is_remove: if action == "update" or is_remove:
self.hass.async_create_task( self.hass.async_create_task_internal(
self._async_process_registry_update_or_remove(event), eager_start=True self._async_process_registry_update_or_remove(event), eager_start=True
) )

View File

@ -146,7 +146,7 @@ class EntityComponent(Generic[_EntityT]):
# Look in config for Domain, Domain 2, Domain 3 etc and load them # Look in config for Domain, Domain 2, Domain 3 etc and load them
for p_type, p_config in conf_util.config_per_platform(config, self.domain): for p_type, p_config in conf_util.config_per_platform(config, self.domain):
if p_type is not None: if p_type is not None:
self.hass.async_create_task( self.hass.async_create_task_internal(
self.async_setup_platform(p_type, p_config), self.async_setup_platform(p_type, p_config),
f"EntityComponent setup platform {p_type} {self.domain}", f"EntityComponent setup platform {p_type} {self.domain}",
eager_start=True, eager_start=True,

View File

@ -477,7 +477,7 @@ class EntityPlatform:
self, new_entities: Iterable[Entity], update_before_add: bool = False self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None: ) -> None:
"""Schedule adding entities for a single platform async.""" """Schedule adding entities for a single platform async."""
task = self.hass.async_create_task( task = self.hass.async_create_task_internal(
self.async_add_entities(new_entities, update_before_add=update_before_add), self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}", f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}",
eager_start=True, eager_start=True,

View File

@ -85,7 +85,7 @@ def _async_integration_platform_component_loaded(
# At least one of the platforms is not loaded, we need to load them # At least one of the platforms is not loaded, we need to load them
# so we have to fall back to creating a task. # so we have to fall back to creating a task.
hass.async_create_task( hass.async_create_task_internal(
_async_process_integration_platforms_for_component( _async_process_integration_platforms_for_component(
hass, integration, platforms_that_exist, integration_platforms_by_name hass, integration, platforms_that_exist, integration_platforms_by_name
), ),
@ -206,7 +206,7 @@ async def async_process_integration_platforms(
# We use hass.async_create_task instead of asyncio.create_task because # We use hass.async_create_task instead of asyncio.create_task because
# we want to make sure that startup waits for the task to complete. # we want to make sure that startup waits for the task to complete.
# #
future = hass.async_create_task( future = hass.async_create_task_internal(
_async_process_integration_platforms( _async_process_integration_platforms(
hass, platform_name, top_level_components.copy(), process_job hass, platform_name, top_level_components.copy(), process_job
), ),

View File

@ -659,7 +659,7 @@ class DynamicServiceIntentHandler(IntentHandler):
) )
await self._run_then_background( await self._run_then_background(
hass.async_create_task( hass.async_create_task_internal(
hass.services.async_call( hass.services.async_call(
domain, domain,
service, service,

View File

@ -236,7 +236,9 @@ class RestoreStateData:
# Dump the initial states now. This helps minimize the risk of having # Dump the initial states now. This helps minimize the risk of having
# old states loaded by overwriting the last states once Home Assistant # old states loaded by overwriting the last states once Home Assistant
# has started and the old states have been read. # has started and the old states have been read.
self.hass.async_create_task(_async_dump_states(), "RestoreStateData dump") self.hass.async_create_task_internal(
_async_dump_states(), "RestoreStateData dump"
)
# Dump states periodically # Dump states periodically
cancel_interval = async_track_time_interval( cancel_interval = async_track_time_interval(

View File

@ -734,7 +734,7 @@ class _ScriptRun:
) )
trace_set_result(params=params, running_script=running_script) trace_set_result(params=params, running_script=running_script)
response_data = await self._async_run_long_action( response_data = await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task_internal(
self._hass.services.async_call( self._hass.services.async_call(
**params, **params,
blocking=True, blocking=True,
@ -1208,7 +1208,7 @@ class _ScriptRun:
async def _async_run_script(self, script: Script) -> None: async def _async_run_script(self, script: Script) -> None:
"""Execute a script.""" """Execute a script."""
result = await self._async_run_long_action( result = await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task_internal(
script.async_run(self._variables, self._context), eager_start=True script.async_run(self._variables, self._context), eager_start=True
) )
) )

View File

@ -468,7 +468,7 @@ class Store(Generic[_T]):
# wrote. Reschedule the timer to the next write time. # wrote. Reschedule the timer to the next write time.
self._async_reschedule_delayed_write(self._next_write_time) self._async_reschedule_delayed_write(self._next_write_time)
return return
self.hass.async_create_task( self.hass.async_create_task_internal(
self._async_callback_delayed_write(), eager_start=True self._async_callback_delayed_write(), eager_start=True
) )

View File

@ -600,7 +600,7 @@ def _async_when_setup(
_LOGGER.exception("Error handling when_setup callback for %s", component) _LOGGER.exception("Error handling when_setup callback for %s", component)
if component in hass.config.components: if component in hass.config.components:
hass.async_create_task( hass.async_create_task_internal(
when_setup(), f"when setup {component}", eager_start=True when_setup(), f"when setup {component}", eager_start=True
) )
return return

View File

@ -234,7 +234,7 @@ async def async_test_home_assistant(
orig_async_add_job = hass.async_add_job orig_async_add_job = hass.async_add_job
orig_async_add_executor_job = hass.async_add_executor_job orig_async_add_executor_job = hass.async_add_executor_job
orig_async_create_task = hass.async_create_task orig_async_create_task_internal = hass.async_create_task_internal
orig_tz = dt_util.DEFAULT_TIME_ZONE orig_tz = dt_util.DEFAULT_TIME_ZONE
def async_add_job(target, *args, eager_start: bool = False): def async_add_job(target, *args, eager_start: bool = False):
@ -263,18 +263,18 @@ async def async_test_home_assistant(
return orig_async_add_executor_job(target, *args) return orig_async_add_executor_job(target, *args)
def async_create_task(coroutine, name=None, eager_start=True): def async_create_task_internal(coroutine, name=None, eager_start=True):
"""Create task.""" """Create task."""
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock): if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
fut = asyncio.Future() fut = asyncio.Future()
fut.set_result(None) fut.set_result(None)
return fut return fut
return orig_async_create_task(coroutine, name, eager_start) return orig_async_create_task_internal(coroutine, name, eager_start)
hass.async_add_job = async_add_job hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task hass.async_create_task_internal = async_create_task_internal
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {} hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

View File

@ -329,7 +329,7 @@ async def test_async_create_task_schedule_coroutine() -> None:
async def job(): async def job():
pass pass
ha.HomeAssistant.async_create_task(hass, job(), eager_start=False) ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=False)
assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1 assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0 assert len(hass.add_job.mock_calls) == 0
@ -342,7 +342,7 @@ async def test_async_create_task_eager_start_schedule_coroutine() -> None:
async def job(): async def job():
pass pass
ha.HomeAssistant.async_create_task(hass, job(), eager_start=True) ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=True)
# Should create the task directly since 3.12 supports eager_start # Should create the task directly since 3.12 supports eager_start
assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0 assert len(hass.add_job.mock_calls) == 0
@ -355,7 +355,7 @@ async def test_async_create_task_schedule_coroutine_with_name() -> None:
async def job(): async def job():
pass pass
task = ha.HomeAssistant.async_create_task( task = ha.HomeAssistant.async_create_task_internal(
hass, job(), "named task", eager_start=False hass, job(), "named task", eager_start=False
) )
assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.call_soon.mock_calls) == 0
@ -3480,3 +3480,15 @@ async def test_async_remove_thread_safety(hass: HomeAssistant) -> None:
await hass.async_add_executor_job( await hass.async_add_executor_job(
hass.services.async_remove, "test_domain", "test_service" hass.services.async_remove, "test_domain", "test_service"
) )
async def test_async_create_task_thread_safety(hass: HomeAssistant) -> None:
"""Test async_create_task thread safety."""
async def _any_coro():
pass
with pytest.raises(
RuntimeError, match="Detected code that calls async_create_task from a thread."
):
await hass.async_add_executor_job(hass.async_create_task, _any_coro)