diff --git a/homeassistant/core.py b/homeassistant/core.py index 37d1134ef29..2834730408e 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -8,6 +8,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor import datetime import enum +import functools import logging import os import pathlib @@ -258,11 +259,15 @@ class HomeAssistant: """ task = None - if asyncio.iscoroutine(target): + check_target = target + if isinstance(target, functools.partial): + check_target = target.func + + if asyncio.iscoroutine(check_target): task = self.loop.create_task(target) # type: ignore - elif is_callback(target): + elif is_callback(check_target): self.loop.call_soon(target, *args) - elif asyncio.iscoroutinefunction(target): + elif asyncio.iscoroutinefunction(check_target): task = self.loop.create_task(target(*args)) else: task = self.loop.run_in_executor( # type: ignore diff --git a/tests/test_core.py b/tests/test_core.py index 5ee9f5cdb05..f1900979bec 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,7 @@ """Test to verify that Home Assistant core works.""" # pylint: disable=protected-access import asyncio +import functools import logging import os import unittest @@ -45,11 +46,24 @@ def test_async_add_job_schedule_callback(): assert len(hass.add_job.mock_calls) == 0 -@patch('asyncio.iscoroutinefunction', return_value=True) -def test_async_add_job_schedule_coroutinefunction(mock_iscoro): - """Test that we schedule coroutines and add jobs to the job pool.""" +def test_async_add_job_schedule_partial_callback(): + """Test that we schedule partial coros and add jobs to the job pool.""" hass = MagicMock() job = MagicMock() + partial = functools.partial(ha.callback(job)) + + ha.HomeAssistant.async_add_job(hass, partial) + assert len(hass.loop.call_soon.mock_calls) == 1 + assert len(hass.loop.create_task.mock_calls) == 0 + assert len(hass.add_job.mock_calls) == 0 + + +def test_async_add_job_schedule_coroutinefunction(): + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock() + + async def job(): + pass ha.HomeAssistant.async_add_job(hass, job) assert len(hass.loop.call_soon.mock_calls) == 0 @@ -57,11 +71,26 @@ def test_async_add_job_schedule_coroutinefunction(mock_iscoro): assert len(hass.add_job.mock_calls) == 0 -@patch('asyncio.iscoroutinefunction', return_value=False) -def test_async_add_job_add_threaded_job_to_pool(mock_iscoro): +def test_async_add_job_schedule_partial_coroutinefunction(): + """Test that we schedule partial coros and add jobs to the job pool.""" + hass = MagicMock() + + async def job(): + pass + partial = functools.partial(job) + + ha.HomeAssistant.async_add_job(hass, partial) + assert len(hass.loop.call_soon.mock_calls) == 0 + assert len(hass.loop.create_task.mock_calls) == 1 + assert len(hass.add_job.mock_calls) == 0 + + +def test_async_add_job_add_threaded_job_to_pool(): """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock() - job = MagicMock() + + def job(): + pass ha.HomeAssistant.async_add_job(hass, job) assert len(hass.loop.call_soon.mock_calls) == 0 @@ -69,13 +98,14 @@ def test_async_add_job_add_threaded_job_to_pool(mock_iscoro): assert len(hass.loop.run_in_executor.mock_calls) == 1 -@patch('asyncio.iscoroutine', return_value=True) -def test_async_create_task_schedule_coroutine(mock_iscoro): +def test_async_create_task_schedule_coroutine(): """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock() - job = MagicMock() - ha.HomeAssistant.async_create_task(hass, job) + async def job(): + pass + + ha.HomeAssistant.async_create_task(hass, job()) assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 1 assert len(hass.add_job.mock_calls) == 0