Lazy initialise the worker pool (#4110)

* Lazy initialise the worker pool

* Minimize pool initialization in core tests

* Fix tests on Python 3.4

* Remove passing in thread count to mock HASS

* Tests: Allow pool by default for threaded, disable for async

* Remove JobPriority for thread pool

* Fix wrong block_till_done

* EmulatedHue: Remove unused test code

* Zigbee: do not touch hass.pool

* Init loop in add_job

* Fix core test

* Fix random sensor test
This commit is contained in:
Paulus Schoutsen 2016-10-31 08:47:29 -07:00 committed by GitHub
parent a1e910f1cf
commit 7f699b4261
26 changed files with 140 additions and 185 deletions

View File

@ -131,9 +131,10 @@ def _async_setup_component(hass: core.HomeAssistant,
return False return False
component = loader.get_component(domain) component = loader.get_component(domain)
async_comp = hasattr(component, 'async_setup')
try: try:
if hasattr(component, 'async_setup'): if async_comp:
result = yield from component.async_setup(hass, config) result = yield from component.async_setup(hass, config)
else: else:
result = yield from hass.loop.run_in_executor( result = yield from hass.loop.run_in_executor(
@ -155,8 +156,11 @@ def _async_setup_component(hass: core.HomeAssistant,
# Assumption: if a component does not depend on groups # Assumption: if a component does not depend on groups
# it communicates with devices # it communicates with devices
if 'group' not in getattr(component, 'DEPENDENCIES', []) and \ if (not async_comp and
hass.pool.worker_count <= 10: 'group' not in getattr(component, 'DEPENDENCIES', [])):
if hass.pool is None:
hass.async_init_pool()
if hass.pool.worker_count <= 10:
hass.pool.add_worker() hass.pool.add_worker()
hass.bus.async_fire( hass.bus.async_fire(

View File

@ -12,7 +12,6 @@ import voluptuous as vol
from homeassistant.components import zigbee from homeassistant.components import zigbee
from homeassistant.components.zigbee import PLATFORM_SCHEMA from homeassistant.components.zigbee import PLATFORM_SCHEMA
from homeassistant.const import TEMP_CELSIUS from homeassistant.const import TEMP_CELSIUS
from homeassistant.core import JobPriority
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -56,8 +55,7 @@ class ZigBeeTemperatureSensor(Entity):
self._config = config self._config = config
self._temp = None self._temp = None
# Get initial state # Get initial state
hass.pool.add_job( hass.add_job(self.update_ha_state, True)
JobPriority.EVENT_STATE, (self.update_ha_state, True))
@property @property
def name(self): def name(self):

View File

@ -13,7 +13,6 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, CONF_DEVICE, CONF_NAME, CONF_PIN) EVENT_HOMEASSISTANT_STOP, CONF_DEVICE, CONF_NAME, CONF_PIN)
from homeassistant.core import JobPriority
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
@ -308,8 +307,7 @@ class ZigBeeDigitalIn(Entity):
subscribe(hass, handle_frame) subscribe(hass, handle_frame)
# Get initial state # Get initial state
hass.pool.add_job( hass.add_job(self.update_ha_state, True)
JobPriority.EVENT_STATE, (self.update_ha_state, True))
@property @property
def name(self): def name(self):
@ -435,8 +433,7 @@ class ZigBeeAnalogIn(Entity):
subscribe(hass, handle_frame) subscribe(hass, handle_frame)
# Get initial state # Get initial state
hass.pool.add_job( hass.add_job(self.update_ha_state, True)
JobPriority.EVENT_STATE, (self.update_ha_state, True))
@property @property
def name(self): def name(self):

View File

@ -102,29 +102,6 @@ class CoreState(enum.Enum):
return self.value return self.value
class JobPriority(util.OrderedEnum):
"""Provides job priorities for event bus jobs."""
EVENT_CALLBACK = 0
EVENT_SERVICE = 1
EVENT_STATE = 2
EVENT_TIME = 3
EVENT_DEFAULT = 4
@staticmethod
def from_event_type(event_type):
"""Return a priority based on event type."""
if event_type == EVENT_TIME_CHANGED:
return JobPriority.EVENT_TIME
elif event_type == EVENT_STATE_CHANGED:
return JobPriority.EVENT_STATE
elif event_type == EVENT_CALL_SERVICE:
return JobPriority.EVENT_SERVICE
elif event_type == EVENT_SERVICE_EXECUTED:
return JobPriority.EVENT_CALLBACK
return JobPriority.EVENT_DEFAULT
class HomeAssistant(object): class HomeAssistant(object):
"""Root object of the Home Assistant home automation.""" """Root object of the Home Assistant home automation."""
@ -134,9 +111,10 @@ class HomeAssistant(object):
self.executor = ThreadPoolExecutor(max_workers=5) self.executor = ThreadPoolExecutor(max_workers=5)
self.loop.set_default_executor(self.executor) self.loop.set_default_executor(self.executor)
self.loop.set_exception_handler(self._async_exception_handler) self.loop.set_exception_handler(self._async_exception_handler)
self.pool = create_worker_pool() self.pool = None
self.bus = EventBus(self) self.bus = EventBus(self)
self.services = ServiceRegistry(self.bus, self.add_job, self.loop) self.services = ServiceRegistry(self.bus, self.async_add_job,
self.loop)
self.states = StateMachine(self.bus, self.loop) self.states = StateMachine(self.bus, self.loop)
self.config = Config() # type: Config self.config = Config() # type: Config
# This is a dictionary that any component can store any data on. # This is a dictionary that any component can store any data on.
@ -180,8 +158,7 @@ class HomeAssistant(object):
This method is a coroutine. This method is a coroutine.
""" """
_LOGGER.info( _LOGGER.info("Starting Home Assistant")
"Starting Home Assistant (%d threads)", self.pool.worker_count)
self.state = CoreState.starting self.state = CoreState.starting
@ -208,24 +185,24 @@ class HomeAssistant(object):
# pylint: disable=protected-access # pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident() self.loop._thread_ident = threading.get_ident()
_async_create_timer(self) _async_create_timer(self)
_async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from self.loop.run_in_executor(None, self.pool.block_till_done) if self.pool is not None:
yield from self.loop.run_in_executor(
None, self.pool.block_till_done)
self.state = CoreState.running self.state = CoreState.running
def add_job(self, def add_job(self, target: Callable[..., None], *args: Any) -> None:
target: Callable[..., None],
*args: Any,
priority: JobPriority=JobPriority.EVENT_DEFAULT) -> None:
"""Add job to the worker pool. """Add job to the worker pool.
target: target to call. target: target to call.
args: parameters for method to call. args: parameters for method to call.
""" """
self.pool.add_job(priority, (target,) + args) if self.pool is None:
run_callback_threadsafe(self.pool, self.async_init_pool).result()
self.pool.add_job((target,) + args)
@callback @callback
def async_add_job(self, target: Callable[..., None], *args: Any): def async_add_job(self, target: Callable[..., None], *args: Any) -> None:
"""Add a job from within the eventloop. """Add a job from within the eventloop.
This method must be run in the event loop. This method must be run in the event loop.
@ -238,10 +215,12 @@ class HomeAssistant(object):
elif asyncio.iscoroutinefunction(target): elif asyncio.iscoroutinefunction(target):
self.loop.create_task(target(*args)) self.loop.create_task(target(*args))
else: else:
self.add_job(target, *args) if self.pool is None:
self.async_init_pool()
self.pool.add_job((target,) + args)
@callback @callback
def async_run_job(self, target: Callable[..., None], *args: Any): def async_run_job(self, target: Callable[..., None], *args: Any) -> None:
"""Run a job from within the event loop. """Run a job from within the event loop.
This method must be run in the event loop. This method must be run in the event loop.
@ -254,7 +233,7 @@ class HomeAssistant(object):
else: else:
self.async_add_job(target, *args) self.async_add_job(target, *args)
def _loop_empty(self): def _loop_empty(self) -> bool:
"""Python 3.4.2 empty loop compatibility function.""" """Python 3.4.2 empty loop compatibility function."""
# pylint: disable=protected-access # pylint: disable=protected-access
if sys.version_info < (3, 4, 3): if sys.version_info < (3, 4, 3):
@ -264,7 +243,7 @@ class HomeAssistant(object):
return self.loop._current_handle is None and \ return self.loop._current_handle is None and \
len(self.loop._ready) == 0 len(self.loop._ready) == 0
def block_till_done(self): def block_till_done(self) -> None:
"""Block till all pending work is done.""" """Block till all pending work is done."""
complete = threading.Event() complete = threading.Event()
@ -278,6 +257,7 @@ class HomeAssistant(object):
count = 0 count = 0
while True: while True:
# Wait for the work queue to empty # Wait for the work queue to empty
if self.pool is not None:
self.pool.block_till_done() self.pool.block_till_done()
# Verify the loop is empty # Verify the loop is empty
@ -309,7 +289,9 @@ class HomeAssistant(object):
""" """
self.state = CoreState.stopping self.state = CoreState.stopping
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP) self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
yield from self.loop.run_in_executor(None, self.pool.block_till_done) if self.pool is not None:
yield from self.loop.run_in_executor(
None, self.pool.block_till_done)
yield from self.loop.run_in_executor(None, self.pool.stop) yield from self.loop.run_in_executor(None, self.pool.stop)
self.executor.shutdown() self.executor.shutdown()
if self._websession is not None: if self._websession is not None:
@ -337,6 +319,12 @@ class HomeAssistant(object):
exc_info=exc_info exc_info=exc_info
) )
@callback
def async_init_pool(self):
"""Initialize the worker pool."""
self.pool = create_worker_pool()
_async_monitor_worker_pool(self)
@callback @callback
def _async_stop_handler(self, *args): def _async_stop_handler(self, *args):
"""Stop Home Assistant.""" """Stop Home Assistant."""
@ -867,10 +855,10 @@ class ServiceCall(object):
class ServiceRegistry(object): class ServiceRegistry(object):
"""Offers services over the eventbus.""" """Offers services over the eventbus."""
def __init__(self, bus, add_job, loop): def __init__(self, bus, async_add_job, loop):
"""Initialize a service registry.""" """Initialize a service registry."""
self._services = {} self._services = {}
self._add_job = add_job self._async_add_job = async_add_job
self._bus = bus self._bus = bus
self._loop = loop self._loop = loop
self._cur_id = 0 self._cur_id = 0
@ -1073,7 +1061,7 @@ class ServiceRegistry(object):
service_handler.func(service_call) service_handler.func(service_call)
fire_service_executed() fire_service_executed()
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE) self._async_add_job(execute_service)
def _generate_unique_id(self): def _generate_unique_id(self):
"""Generate a unique service call id.""" """Generate a unique service call id."""

View File

@ -319,7 +319,7 @@ class ThreadPool(object):
self._job_handler = job_handler self._job_handler = job_handler
self.worker_count = 0 self.worker_count = 0
self._work_queue = queue.PriorityQueue() self._work_queue = queue.Queue()
self.current_jobs = [] self.current_jobs = []
self._quit_task = object() self._quit_task = object()
@ -349,24 +349,24 @@ class ThreadPool(object):
if not self.running: if not self.running:
raise RuntimeError("ThreadPool not running") raise RuntimeError("ThreadPool not running")
self._work_queue.put(PriorityQueueItem(0, self._quit_task)) self._work_queue.put(self._quit_task)
self.worker_count -= 1 self.worker_count -= 1
def add_job(self, priority, job): def add_job(self, job):
"""Add a job to the queue.""" """Add a job to the queue."""
if not self.running: if not self.running:
raise RuntimeError("ThreadPool not running") raise RuntimeError("ThreadPool not running")
self._work_queue.put(PriorityQueueItem(priority, job)) self._work_queue.put(job)
def add_many_jobs(self, jobs): def add_many_jobs(self, jobs):
"""Add a list of jobs to the queue.""" """Add a list of jobs to the queue."""
if not self.running: if not self.running:
raise RuntimeError("ThreadPool not running") raise RuntimeError("ThreadPool not running")
for priority, job in jobs: for job in jobs:
self._work_queue.put(PriorityQueueItem(priority, job)) self._work_queue.put(job)
def block_till_done(self): def block_till_done(self):
"""Block till current work is done.""" """Block till current work is done."""
@ -392,7 +392,7 @@ class ThreadPool(object):
"""Handle jobs for the thread pool.""" """Handle jobs for the thread pool."""
while True: while True:
# Get new item from work_queue # Get new item from work_queue
job = self._work_queue.get().item job = self._work_queue.get()
if job is self._quit_task: if job is self._quit_task:
self._work_queue.task_done() self._work_queue.task_done()
@ -410,16 +410,3 @@ class ThreadPool(object):
# Tell work_queue the task is done # Tell work_queue the task is done
self._work_queue.task_done() self._work_queue.task_done()
class PriorityQueueItem(object):
"""Holds a priority and a value. Used within PriorityQueue."""
def __init__(self, priority, item):
"""Initialize the queue."""
self.priority = priority
self.item = item
def __lt__(self, other):
"""Return the ordering."""
return self.priority < other.priority

View File

@ -31,18 +31,12 @@ def get_test_config_dir(*add_path):
return os.path.join(os.path.dirname(__file__), "testing_config", *add_path) return os.path.join(os.path.dirname(__file__), "testing_config", *add_path)
def get_test_home_assistant(num_threads=None): def get_test_home_assistant():
"""Return a Home Assistant object pointing at test config dir.""" """Return a Home Assistant object pointing at test config dir."""
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
if num_threads:
orig_num_threads = ha.MIN_WORKER_THREAD
ha.MIN_WORKER_THREAD = num_threads
hass = loop.run_until_complete(async_test_home_assistant(loop)) hass = loop.run_until_complete(async_test_home_assistant(loop))
hass.allow_pool = True
if num_threads:
ha.MIN_WORKER_THREAD = orig_num_threads
# FIXME should not be a daemon. Means hass.stop() not called in teardown # FIXME should not be a daemon. Means hass.stop() not called in teardown
stop_event = threading.Event() stop_event = threading.Event()
@ -60,17 +54,10 @@ def get_test_home_assistant(num_threads=None):
orig_start = hass.start orig_start = hass.start
orig_stop = hass.stop orig_stop = hass.stop
@asyncio.coroutine
def fake_stop():
"""Fake stop."""
yield None
@patch.object(ha, '_async_create_timer')
@patch.object(ha, '_async_monitor_worker_pool')
@patch.object(hass.loop, 'add_signal_handler') @patch.object(hass.loop, 'add_signal_handler')
@patch.object(ha, '_async_create_timer')
@patch.object(hass.loop, 'run_forever') @patch.object(hass.loop, 'run_forever')
@patch.object(hass.loop, 'close') @patch.object(hass.loop, 'close')
@patch.object(hass, 'async_stop', return_value=fake_stop())
def start_hass(*mocks): def start_hass(*mocks):
"""Helper to start hass.""" """Helper to start hass."""
orig_start() orig_start()
@ -108,6 +95,20 @@ def async_test_home_assistant(loop):
hass.state = ha.CoreState.running hass.state = ha.CoreState.running
hass.allow_pool = False
orig_init = hass.async_init_pool
@ha.callback
def mock_async_init_pool():
"""Prevent worker pool from being initialized."""
if hass.allow_pool:
with patch('homeassistant.core._async_monitor_worker_pool'):
orig_init()
else:
assert False, 'Thread pool not allowed. Set hass.allow_pool = True'
hass.async_init_pool = mock_async_init_pool
return hass return hass
@ -225,7 +226,8 @@ class MockModule(object):
# pylint: disable=invalid-name # pylint: disable=invalid-name
def __init__(self, domain=None, dependencies=None, setup=None, def __init__(self, domain=None, dependencies=None, setup=None,
requirements=None, config_schema=None, platform_schema=None): requirements=None, config_schema=None, platform_schema=None,
async_setup=None):
"""Initialize the mock module.""" """Initialize the mock module."""
self.DOMAIN = domain self.DOMAIN = domain
self.DEPENDENCIES = dependencies or [] self.DEPENDENCIES = dependencies or []
@ -238,8 +240,15 @@ class MockModule(object):
if platform_schema is not None: if platform_schema is not None:
self.PLATFORM_SCHEMA = platform_schema self.PLATFORM_SCHEMA = platform_schema
if async_setup is not None:
self.async_setup = async_setup
def setup(self, hass, config): def setup(self, hass, config):
"""Setup the component.""" """Setup the component.
We always define this mock because MagicMock setups will be seen by the
executor as a coroutine, raising an exception.
"""
if self._setup is not None: if self._setup is not None:
return self._setup(hass, config) return self._setup(hass, config)
return True return True

View File

@ -8,6 +8,7 @@ from homeassistant.bootstrap import setup_component
@asyncio.coroutine @asyncio.coroutine
def test_fetching_url(aioclient_mock, hass, test_client): def test_fetching_url(aioclient_mock, hass, test_client):
"""Test that it fetches the given url.""" """Test that it fetches the given url."""
hass.allow_pool = True
aioclient_mock.get('http://example.com', text='hello world') aioclient_mock.get('http://example.com', text='hello world')
def setup_platform(): def setup_platform():
@ -39,6 +40,7 @@ def test_fetching_url(aioclient_mock, hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_limit_refetch(aioclient_mock, hass, test_client): def test_limit_refetch(aioclient_mock, hass, test_client):
"""Test that it fetches the given url.""" """Test that it fetches the given url."""
hass.allow_pool = True
aioclient_mock.get('http://example.com/5a', text='hello world') aioclient_mock.get('http://example.com/5a', text='hello world')
aioclient_mock.get('http://example.com/10a', text='hello world') aioclient_mock.get('http://example.com/10a', text='hello world')
aioclient_mock.get('http://example.com/15a', text='hello planet') aioclient_mock.get('http://example.com/15a', text='hello planet')

View File

@ -14,6 +14,8 @@ from tests.common import assert_setup_component, mock_http_component
@asyncio.coroutine @asyncio.coroutine
def test_loading_file(hass, test_client): def test_loading_file(hass, test_client):
"""Test that it loads image from disk.""" """Test that it loads image from disk."""
hass.allow_pool = True
@mock.patch('os.path.isfile', mock.Mock(return_value=True)) @mock.patch('os.path.isfile', mock.Mock(return_value=True))
@mock.patch('os.access', mock.Mock(return_value=True)) @mock.patch('os.access', mock.Mock(return_value=True))
def setup_platform(): def setup_platform():

View File

@ -86,7 +86,7 @@ class TestDemoClimate(unittest.TestCase):
self.assertEqual(24.0, state.attributes.get('target_temp_high')) self.assertEqual(24.0, state.attributes.get('target_temp_high'))
climate.set_temperature(self.hass, target_temp_high=25, climate.set_temperature(self.hass, target_temp_high=25,
target_temp_low=20, entity_id=ENTITY_ECOBEE) target_temp_low=20, entity_id=ENTITY_ECOBEE)
self.hass.pool.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_ECOBEE) state = self.hass.states.get(ENTITY_ECOBEE)
self.assertEqual(None, state.attributes.get('temperature')) self.assertEqual(None, state.attributes.get('temperature'))
self.assertEqual(20.0, state.attributes.get('target_temp_low')) self.assertEqual(20.0, state.attributes.get('target_temp_low'))
@ -102,7 +102,7 @@ class TestDemoClimate(unittest.TestCase):
climate.set_temperature(self.hass, temperature=None, climate.set_temperature(self.hass, temperature=None,
entity_id=ENTITY_ECOBEE, target_temp_low=None, entity_id=ENTITY_ECOBEE, target_temp_low=None,
target_temp_high=None) target_temp_high=None)
self.hass.pool.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_ECOBEE) state = self.hass.states.get(ENTITY_ECOBEE)
self.assertEqual(None, state.attributes.get('temperature')) self.assertEqual(None, state.attributes.get('temperature'))
self.assertEqual(21.0, state.attributes.get('target_temp_low')) self.assertEqual(21.0, state.attributes.get('target_temp_low'))

View File

@ -15,7 +15,7 @@ class TestCoverRfxtrx(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
self.hass.config.components = ['rfxtrx'] self.hass.config.components = ['rfxtrx']
def tearDown(self): def tearDown(self):

View File

@ -30,7 +30,7 @@ class TestDemoClimate(unittest.TestCase):
"""Test light state attributes.""" """Test light state attributes."""
light.turn_on( light.turn_on(
self.hass, ENTITY_LIGHT, xy_color=(.4, .6), brightness=25) self.hass, ENTITY_LIGHT, xy_color=(.4, .6), brightness=25)
self.hass.pool.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_LIGHT) state = self.hass.states.get(ENTITY_LIGHT)
self.assertTrue(light.is_on(self.hass, ENTITY_LIGHT)) self.assertTrue(light.is_on(self.hass, ENTITY_LIGHT))
self.assertEqual((.4, .6), state.attributes.get(light.ATTR_XY_COLOR)) self.assertEqual((.4, .6), state.attributes.get(light.ATTR_XY_COLOR))
@ -40,21 +40,21 @@ class TestDemoClimate(unittest.TestCase):
light.turn_on( light.turn_on(
self.hass, ENTITY_LIGHT, rgb_color=(251, 252, 253), self.hass, ENTITY_LIGHT, rgb_color=(251, 252, 253),
white_value=254) white_value=254)
self.hass.pool.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_LIGHT) state = self.hass.states.get(ENTITY_LIGHT)
self.assertEqual(254, state.attributes.get(light.ATTR_WHITE_VALUE)) self.assertEqual(254, state.attributes.get(light.ATTR_WHITE_VALUE))
self.assertEqual( self.assertEqual(
(251, 252, 253), state.attributes.get(light.ATTR_RGB_COLOR)) (251, 252, 253), state.attributes.get(light.ATTR_RGB_COLOR))
light.turn_on(self.hass, ENTITY_LIGHT, color_temp=400) light.turn_on(self.hass, ENTITY_LIGHT, color_temp=400)
self.hass.pool.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_LIGHT) state = self.hass.states.get(ENTITY_LIGHT)
self.assertEqual(400, state.attributes.get(light.ATTR_COLOR_TEMP)) self.assertEqual(400, state.attributes.get(light.ATTR_COLOR_TEMP))
def test_turn_off(self): def test_turn_off(self):
"""Test light turn off method.""" """Test light turn off method."""
light.turn_on(self.hass, ENTITY_LIGHT) light.turn_on(self.hass, ENTITY_LIGHT)
self.hass.pool.block_till_done() self.hass.block_till_done()
self.assertTrue(light.is_on(self.hass, ENTITY_LIGHT)) self.assertTrue(light.is_on(self.hass, ENTITY_LIGHT))
light.turn_off(self.hass, ENTITY_LIGHT) light.turn_off(self.hass, ENTITY_LIGHT)
self.hass.pool.block_till_done() self.hass.block_till_done()
self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT)) self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT))

View File

@ -15,7 +15,7 @@ class TestLightRfxtrx(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
self.hass.config.components = ['rfxtrx'] self.hass.config.components = ['rfxtrx']
def tearDown(self): def tearDown(self):

View File

@ -21,7 +21,7 @@ class TestMQTT(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(1) self.hass = get_test_home_assistant()
mock_mqtt_component(self.hass) mock_mqtt_component(self.hass)
self.calls = [] self.calls = []
@ -217,7 +217,7 @@ class TestMQTTCallbacks(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(1) self.hass = get_test_home_assistant()
# mock_mqtt_component(self.hass) # mock_mqtt_component(self.hass)
with mock.patch('paho.mqtt.client.Client'): with mock.patch('paho.mqtt.client.Client'):

View File

@ -111,7 +111,7 @@ class TestNotifyDemo(unittest.TestCase):
} }
script.call_from_config(self.hass, conf) script.call_from_config(self.hass, conf)
self.hass.pool.block_till_done() self.hass.block_till_done()
self.assertTrue(len(self.events) == 1) self.assertTrue(len(self.events) == 1)
assert { assert {
'message': 'Test 123 4', 'message': 'Test 123 4',

View File

@ -178,7 +178,7 @@ class EmailContentSensor(unittest.TestCase):
sensor.entity_id = "sensor.emailtest" sensor.entity_id = "sensor.emailtest"
sensor.update() sensor.update()
self.hass.pool.block_till_done() self.hass.block_till_done()
states_received.wait(5) states_received.wait(5)
self.assertEqual("Test Message", states[0].state) self.assertEqual("Test Message", states[0].state)

View File

@ -33,4 +33,4 @@ class TestRandomSensor(unittest.TestCase):
state = self.hass.states.get('sensor.test') state = self.hass.states.get('sensor.test')
self.assertLessEqual(int(state.state), config['sensor']['maximum']) self.assertLessEqual(int(state.state), config['sensor']['maximum'])
self.assertGreater(int(state.state), config['sensor']['minimum']) self.assertGreaterEqual(int(state.state), config['sensor']['minimum'])

View File

@ -16,7 +16,7 @@ class TestSensorRfxtrx(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
self.hass.config.components = ['rfxtrx'] self.hass.config.components = ['rfxtrx']
def tearDown(self): def tearDown(self):

View File

@ -15,7 +15,7 @@ class TestSwitchRfxtrx(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
self.hass.config.components = ['rfxtrx'] self.hass.config.components = ['rfxtrx']
def tearDown(self): def tearDown(self):

View File

@ -19,7 +19,7 @@ class TestConversation(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.ent_id = 'light.kitchen_lights' self.ent_id = 'light.kitchen_lights'
self.hass = get_test_home_assistant(3) self.hass = get_test_home_assistant()
self.hass.states.set(self.ent_id, 'on') self.hass.states.set(self.ent_id, 'on')
self.assertTrue(run_coroutine_threadsafe( self.assertTrue(run_coroutine_threadsafe(
core_components.async_setup(self.hass, {}), self.hass.loop core_components.async_setup(self.hass, {}), self.hass.loop

View File

@ -1,8 +1,6 @@
"""The tests for the emulated Hue component.""" """The tests for the emulated Hue component."""
import time import time
import json import json
import threading
import asyncio
import unittest import unittest
import requests import requests
@ -372,58 +370,3 @@ class TestEmulatedHueExposedByDefault(unittest.TestCase):
url, data=json.dumps(data), timeout=5, headers=req_headers) url, data=json.dumps(data), timeout=5, headers=req_headers)
return result return result
class MQTTBroker(object):
"""Encapsulates an embedded MQTT broker."""
def __init__(self, host, port):
"""Initialize a new instance."""
from hbmqtt.broker import Broker
self._loop = asyncio.new_event_loop()
hbmqtt_config = {
'listeners': {
'default': {
'max-connections': 50000,
'type': 'tcp',
'bind': '{}:{}'.format(host, port)
}
},
'auth': {
'plugins': ['auth.anonymous'],
'allow-anonymous': True
}
}
self._broker = Broker(config=hbmqtt_config, loop=self._loop)
self._thread = threading.Thread(target=self._run_loop)
self._started_ev = threading.Event()
def start(self):
"""Start the broker."""
self._thread.start()
self._started_ev.wait()
def stop(self):
"""Stop the broker."""
self._loop.call_soon_threadsafe(asyncio.async, self._broker.shutdown())
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def _run_loop(self):
"""Run the loop."""
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._broker_coroutine())
self._started_ev.set()
self._loop.run_forever()
self._loop.close()
@asyncio.coroutine
def _broker_coroutine(self):
"""The Broker coroutine."""
yield from self._broker.start()

View File

@ -17,7 +17,7 @@ class TestInfluxDB(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(2) self.hass = get_test_home_assistant()
self.handler_method = None self.handler_method = None
self.hass.bus.listen = mock.Mock() self.hass.bus.listen = mock.Mock()

View File

@ -15,7 +15,7 @@ class TestLogentries(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(2) self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""

View File

@ -16,7 +16,7 @@ class TestUpdater(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(2) self.hass = get_test_home_assistant()
self.log_config = {'logger': self.log_config = {'logger':
{'default': 'warning', 'logs': {'test': 'info'}}} {'default': 'warning', 'logs': {'test': 'info'}}}

View File

@ -14,7 +14,7 @@ class TestSplunk(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(2) self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""

View File

@ -17,7 +17,7 @@ class TestStatsd(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(2) self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""

View File

@ -56,7 +56,7 @@ def test_async_add_job_add_threaded_job_to_pool(mock_iscoro):
ha.HomeAssistant.async_add_job(hass, job) ha.HomeAssistant.async_add_job(hass, job)
assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 1 assert len(hass.pool.add_job.mock_calls) == 1
def test_async_run_job_calls_callback(): def test_async_run_job_calls_callback():
@ -91,7 +91,7 @@ class TestHomeAssistant(unittest.TestCase):
# pylint: disable=invalid-name # pylint: disable=invalid-name
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
# pylint: disable=invalid-name # pylint: disable=invalid-name
def tearDown(self): def tearDown(self):
@ -169,7 +169,6 @@ class TestEventBus(unittest.TestCase):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
self.bus = self.hass.bus self.bus = self.hass.bus
self.bus.listen('test_event', lambda x: len)
# pylint: disable=invalid-name # pylint: disable=invalid-name
def tearDown(self): def tearDown(self):
@ -178,6 +177,7 @@ class TestEventBus(unittest.TestCase):
def test_add_remove_listener(self): def test_add_remove_listener(self):
"""Test remove_listener method.""" """Test remove_listener method."""
self.hass.allow_pool = False
old_count = len(self.bus.listeners) old_count = len(self.bus.listeners)
def listener(_): pass def listener(_): pass
@ -195,8 +195,10 @@ class TestEventBus(unittest.TestCase):
def test_unsubscribe_listener(self): def test_unsubscribe_listener(self):
"""Test unsubscribe listener from returned function.""" """Test unsubscribe listener from returned function."""
self.hass.allow_pool = False
calls = [] calls = []
@ha.callback
def listener(event): def listener(event):
"""Mock listener.""" """Mock listener."""
calls.append(event) calls.append(event)
@ -217,6 +219,7 @@ class TestEventBus(unittest.TestCase):
def test_listen_once_event_with_callback(self): def test_listen_once_event_with_callback(self):
"""Test listen_once_event method.""" """Test listen_once_event method."""
self.hass.allow_pool = False
runs = [] runs = []
@ha.callback @ha.callback
@ -234,6 +237,7 @@ class TestEventBus(unittest.TestCase):
def test_listen_once_event_with_coroutine(self): def test_listen_once_event_with_coroutine(self):
"""Test listen_once_event method.""" """Test listen_once_event method."""
self.hass.allow_pool = False
runs = [] runs = []
@asyncio.coroutine @asyncio.coroutine
@ -279,6 +283,7 @@ class TestEventBus(unittest.TestCase):
def test_callback_event_listener(self): def test_callback_event_listener(self):
"""Test a event listener listeners.""" """Test a event listener listeners."""
self.hass.allow_pool = False
callback_calls = [] callback_calls = []
@ha.callback @ha.callback
@ -292,6 +297,7 @@ class TestEventBus(unittest.TestCase):
def test_coroutine_event_listener(self): def test_coroutine_event_listener(self):
"""Test a event listener listeners.""" """Test a event listener listeners."""
self.hass.allow_pool = False
coroutine_calls = [] coroutine_calls = []
@asyncio.coroutine @asyncio.coroutine
@ -366,10 +372,11 @@ class TestStateMachine(unittest.TestCase):
# pylint: disable=invalid-name # pylint: disable=invalid-name
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant(0) self.hass = get_test_home_assistant()
self.states = self.hass.states self.states = self.hass.states
self.states.set("light.Bowl", "on") self.states.set("light.Bowl", "on")
self.states.set("switch.AC", "off") self.states.set("switch.AC", "off")
self.hass.allow_pool = False
# pylint: disable=invalid-name # pylint: disable=invalid-name
def tearDown(self): def tearDown(self):
@ -413,8 +420,12 @@ class TestStateMachine(unittest.TestCase):
def test_remove(self): def test_remove(self):
"""Test remove method.""" """Test remove method."""
events = [] events = []
self.hass.bus.listen(EVENT_STATE_CHANGED,
lambda event: events.append(event)) @ha.callback
def callback(event):
events.append(event)
self.hass.bus.listen(EVENT_STATE_CHANGED, callback)
self.assertIn('light.bowl', self.states.entity_ids()) self.assertIn('light.bowl', self.states.entity_ids())
self.assertTrue(self.states.remove('light.bowl')) self.assertTrue(self.states.remove('light.bowl'))
@ -436,8 +447,11 @@ class TestStateMachine(unittest.TestCase):
"""Test insensitivty.""" """Test insensitivty."""
runs = [] runs = []
self.hass.bus.listen(EVENT_STATE_CHANGED, @ha.callback
lambda event: runs.append(event)) def callback(event):
runs.append(event)
self.hass.bus.listen(EVENT_STATE_CHANGED, callback)
self.states.set('light.BOWL', 'off') self.states.set('light.BOWL', 'off')
self.hass.block_till_done() self.hass.block_till_done()
@ -462,7 +476,12 @@ class TestStateMachine(unittest.TestCase):
def test_force_update(self): def test_force_update(self):
"""Test force update option.""" """Test force update option."""
events = [] events = []
self.hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
@ha.callback
def callback(event):
events.append(event)
self.hass.bus.listen(EVENT_STATE_CHANGED, callback)
self.states.set('light.bowl', 'on') self.states.set('light.bowl', 'on')
self.hass.block_till_done() self.hass.block_till_done()
@ -504,6 +523,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_has_service(self): def test_has_service(self):
"""Test has_service method.""" """Test has_service method."""
self.hass.allow_pool = False
self.assertTrue( self.assertTrue(
self.services.has_service("tesT_domaiN", "tesT_servicE")) self.services.has_service("tesT_domaiN", "tesT_servicE"))
self.assertFalse( self.assertFalse(
@ -513,6 +533,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_services(self): def test_services(self):
"""Test services.""" """Test services."""
self.hass.allow_pool = False
expected = { expected = {
'test_domain': {'test_service': {'description': '', 'fields': {}}} 'test_domain': {'test_service': {'description': '', 'fields': {}}}
} }
@ -535,6 +556,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_call_non_existing_with_blocking(self): def test_call_non_existing_with_blocking(self):
"""Test non-existing with blocking.""" """Test non-existing with blocking."""
self.hass.allow_pool = False
prior = ha.SERVICE_CALL_LIMIT prior = ha.SERVICE_CALL_LIMIT
try: try:
ha.SERVICE_CALL_LIMIT = 0.01 ha.SERVICE_CALL_LIMIT = 0.01
@ -545,6 +567,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_async_service(self): def test_async_service(self):
"""Test registering and calling an async service.""" """Test registering and calling an async service."""
self.hass.allow_pool = False
calls = [] calls = []
@asyncio.coroutine @asyncio.coroutine
@ -561,6 +584,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_callback_service(self): def test_callback_service(self):
"""Test registering and calling an async service.""" """Test registering and calling an async service."""
self.hass.allow_pool = False
calls = [] calls = []
@ha.callback @ha.callback
@ -629,8 +653,9 @@ class TestWorkerPool(unittest.TestCase):
def register_call(_): def register_call(_):
calls.append(1) calls.append(1)
pool.add_job(ha.JobPriority.EVENT_DEFAULT, (malicious_job, None)) pool.add_job((malicious_job, None))
pool.add_job(ha.JobPriority.EVENT_DEFAULT, (register_call, None)) pool.block_till_done()
pool.add_job((register_call, None))
pool.block_till_done() pool.block_till_done()
self.assertEqual(1, len(calls)) self.assertEqual(1, len(calls))