Clean up some async stuff (#3915)

* Clean up some async stuff

* Adjust comments

* Pass hass instance to eventbus
This commit is contained in:
Paulus Schoutsen 2016-10-17 19:38:41 -07:00 committed by GitHub
parent daea93d9f9
commit 4c8d1d9d2f
7 changed files with 139 additions and 111 deletions

View File

@ -79,8 +79,7 @@ class NuimoThread(threading.Thread):
self._name = name self._name = name
self._hass_is_running = True self._hass_is_running = True
self._nuimo = None self._nuimo = None
self._listener = hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.stop)
self.stop)
def run(self): def run(self):
"""Setup connection or be idle.""" """Setup connection or be idle."""
@ -99,8 +98,6 @@ class NuimoThread(threading.Thread):
"""Terminate Thread by unsetting flag.""" """Terminate Thread by unsetting flag."""
_LOGGER.debug('Stopping thread for Nuimo %s', self._mac) _LOGGER.debug('Stopping thread for Nuimo %s', self._mac)
self._hass_is_running = False self._hass_is_running = False
self._hass.bus.remove_listener(EVENT_HOMEASSISTANT_STOP,
self._listener)
def _attach(self): def _attach(self):
"""Create a nuimo object from mac address or discovery.""" """Create a nuimo object from mac address or discovery."""

View File

@ -8,7 +8,6 @@ of entities and react to changes.
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import enum import enum
import functools as ft
import logging import logging
import os import os
import re import re
@ -137,8 +136,8 @@ class HomeAssistant(object):
self.executor = ThreadPoolExecutor(max_workers=5) self.executor = ThreadPoolExecutor(max_workers=5)
self.loop.set_default_executor(self.executor) self.loop.set_default_executor(self.executor)
self.loop.set_exception_handler(self._async_exception_handler) self.loop.set_exception_handler(self._async_exception_handler)
self.pool = pool = create_worker_pool() self.pool = create_worker_pool()
self.bus = EventBus(pool, self.loop) self.bus = EventBus(self)
self.services = ServiceRegistry(self.bus, self.add_job, self.loop) self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
self.states = StateMachine(self.bus, self.loop) self.states = StateMachine(self.bus, self.loop)
self.config = Config() # type: Config self.config = Config() # type: Config
@ -218,8 +217,8 @@ class HomeAssistant(object):
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident() self.loop._thread_ident = threading.get_ident()
async_create_timer(self) _async_create_timer(self)
async_monitor_worker_pool(self) _async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from self.loop.run_in_executor(None, self.pool.block_till_done) yield from self.loop.run_in_executor(None, self.pool.block_till_done)
self.state = CoreState.running self.state = CoreState.running
@ -235,9 +234,12 @@ class HomeAssistant(object):
""" """
self.pool.add_job(priority, (target,) + args) self.pool.add_job(priority, (target,) + args)
@callback
def async_add_job(self, target: Callable[..., None], *args: Any): def async_add_job(self, target: Callable[..., None], *args: Any):
"""Add a job from within the eventloop. """Add a job from within the eventloop.
This method must be run in the event loop.
target: target to call. target: target to call.
args: parameters for method to call. args: parameters for method to call.
""" """
@ -248,9 +250,12 @@ class HomeAssistant(object):
else: else:
self.add_job(target, *args) self.add_job(target, *args)
@callback
def async_run_job(self, target: Callable[..., None], *args: Any): def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop. """Run a job from within the event loop.
This method must be run in the event loop.
target: target to call. target: target to call.
args: parameters for method to call. args: parameters for method to call.
""" """
@ -369,7 +374,10 @@ class Event(object):
self.time_fired = time_fired or dt_util.utcnow() self.time_fired = time_fired or dt_util.utcnow()
def as_dict(self): def as_dict(self):
"""Create a dict representation of this Event.""" """Create a dict representation of this Event.
Async friendly.
"""
return { return {
'event_type': self.event_type, 'event_type': self.event_type,
'data': dict(self.data), 'data': dict(self.data),
@ -400,13 +408,12 @@ class Event(object):
class EventBus(object): class EventBus(object):
"""Allows firing of and listening for events.""" """Allows firing of and listening for events."""
def __init__(self, pool: util.ThreadPool, def __init__(self, hass: HomeAssistant) -> None:
loop: asyncio.AbstractEventLoop) -> None:
"""Initialize a new event bus.""" """Initialize a new event bus."""
self._listeners = {} self._listeners = {}
self._pool = pool self._hass = hass
self._loop = loop
@callback
def async_listeners(self): def async_listeners(self):
"""Dict with events and the number of listeners. """Dict with events and the number of listeners.
@ -419,23 +426,25 @@ class EventBus(object):
def listeners(self): def listeners(self):
"""Dict with events and the number of listeners.""" """Dict with events and the number of listeners."""
return run_callback_threadsafe( return run_callback_threadsafe(
self._loop, self.async_listeners self._hass.loop, self.async_listeners
).result() ).result()
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local): def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
"""Fire an event.""" """Fire an event."""
if not self._pool.running: self._hass.loop.call_soon_threadsafe(self.async_fire, event_type,
raise HomeAssistantError('Home Assistant has shut down.') event_data, origin)
self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
@callback
def async_fire(self, event_type: str, event_data=None, def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False): origin=EventOrigin.local, wait=False):
"""Fire an event. """Fire an event.
This method must be run in the event loop. This method must be run in the event loop.
""" """
if event_type != EVENT_HOMEASSISTANT_STOP and \
self._hass.state == CoreState.stopping:
raise HomeAssistantError('Home Assistant is shutting down.')
# Copy the list of the current listeners because some listeners # Copy the list of the current listeners because some listeners
# remove themselves as a listener while being executed which # remove themselves as a listener while being executed which
# causes the iterator to be confused. # causes the iterator to be confused.
@ -450,20 +459,8 @@ class EventBus(object):
if not listeners: if not listeners:
return return
job_priority = JobPriority.from_event_type(event_type)
sync_jobs = []
for func in listeners: for func in listeners:
if asyncio.iscoroutinefunction(func): self._hass.async_add_job(func, event)
self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else:
sync_jobs.append((job_priority, (func, event)))
# Send all the sync jobs at once
if sync_jobs:
self._pool.add_many_jobs(sync_jobs)
def listen(self, event_type, listener): def listen(self, event_type, listener):
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
@ -471,16 +468,17 @@ class EventBus(object):
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
as event_type. as event_type.
""" """
future = run_callback_threadsafe( async_remove_listener = run_callback_threadsafe(
self._loop, self.async_listen, event_type, listener) self._hass.loop, self.async_listen, event_type, listener).result()
future.result()
def remove_listener(): def remove_listener():
"""Remove the listener.""" """Remove the listener."""
self._remove_listener(event_type, listener) run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
return remove_listener return remove_listener
@callback
def async_listen(self, event_type, listener): def async_listen(self, event_type, listener):
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
@ -496,7 +494,7 @@ class EventBus(object):
def remove_listener(): def remove_listener():
"""Remove the listener.""" """Remove the listener."""
self.async_remove_listener(event_type, listener) self._async_remove_listener(event_type, listener)
return remove_listener return remove_listener
@ -508,26 +506,18 @@ class EventBus(object):
Returns function to unsubscribe the listener. Returns function to unsubscribe the listener.
""" """
@ft.wraps(listener) async_remove_listener = run_callback_threadsafe(
def onetime_listener(event): self._hass.loop, self.async_listen_once, event_type, listener,
"""Remove listener from eventbus and then fire listener.""" ).result()
if hasattr(onetime_listener, 'run'):
return
# Set variable so that we will never run twice.
# Because the event bus might have to wait till a thread comes
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed.
# This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True)
remove_listener() def remove_listener():
"""Remove the listener."""
listener(event) run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
remove_listener = self.listen(event_type, onetime_listener)
return remove_listener return remove_listener
@callback
def async_listen_once(self, event_type, listener): def async_listen_once(self, event_type, listener):
"""Listen once for event of a specific type. """Listen once for event of a specific type.
@ -538,8 +528,7 @@ class EventBus(object):
This method must be run in the event loop. This method must be run in the event loop.
""" """
@ft.wraps(listener) @callback
@asyncio.coroutine
def onetime_listener(event): def onetime_listener(event):
"""Remove listener from eventbus and then fire listener.""" """Remove listener from eventbus and then fire listener."""
if hasattr(onetime_listener, 'run'): if hasattr(onetime_listener, 'run'):
@ -550,34 +539,14 @@ class EventBus(object):
# multiple times as well. # multiple times as well.
# This will make sure the second time it does nothing. # This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True) setattr(onetime_listener, 'run', True)
self._async_remove_listener(event_type, onetime_listener)
self.async_remove_listener(event_type, onetime_listener) self._hass.async_run_job(listener, event)
if asyncio.iscoroutinefunction(listener): return self.async_listen(event_type, onetime_listener)
yield from listener(event)
else:
job_priority = JobPriority.from_event_type(event.event_type)
self._pool.add_job(job_priority, (listener, event))
self.async_listen(event_type, onetime_listener) @callback
def _async_remove_listener(self, event_type, listener):
return onetime_listener
def remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type. (DEPRECATED 0.28)."""
_LOGGER.warning('bus.remove_listener has been deprecated. Please use '
'the function returned from calling listen.')
self._remove_listener(event_type, listener)
def _remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type."""
future = run_callback_threadsafe(
self._loop,
self.async_remove_listener, event_type, listener
)
future.result()
def async_remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type. """Remove a listener of a specific event_type.
This method must be run in the event loop. This method must be run in the event loop.
@ -644,6 +613,8 @@ class State(object):
def as_dict(self): def as_dict(self):
"""Return a dict representation of the State. """Return a dict representation of the State.
Async friendly.
To be used for JSON serialization. To be used for JSON serialization.
Ensures: state == State.from_dict(state.as_dict()) Ensures: state == State.from_dict(state.as_dict())
""" """
@ -657,6 +628,8 @@ class State(object):
def from_dict(cls, json_dict): def from_dict(cls, json_dict):
"""Initialize a state from a dict. """Initialize a state from a dict.
Async friendly.
Ensures: state == State.from_json_dict(state.to_json_dict()) Ensures: state == State.from_json_dict(state.to_json_dict())
""" """
if not (json_dict and 'entity_id' in json_dict and if not (json_dict and 'entity_id' in json_dict and
@ -709,8 +682,12 @@ class StateMachine(object):
) )
return future.result() return future.result()
@callback
def async_entity_ids(self, domain_filter=None): def async_entity_ids(self, domain_filter=None):
"""List of entity ids that are being tracked.""" """List of entity ids that are being tracked.
This method must be run in the event loop.
"""
if domain_filter is None: if domain_filter is None:
return list(self._states.keys()) return list(self._states.keys())
@ -723,6 +700,7 @@ class StateMachine(object):
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result() return run_callback_threadsafe(self._loop, self.async_all).result()
@callback
def async_all(self): def async_all(self):
"""Create a list of all states. """Create a list of all states.
@ -763,6 +741,7 @@ class StateMachine(object):
return run_callback_threadsafe( return run_callback_threadsafe(
self._loop, self.async_remove, entity_id).result() self._loop, self.async_remove, entity_id).result()
@callback
def async_remove(self, entity_id): def async_remove(self, entity_id):
"""Remove the state of an entity. """Remove the state of an entity.
@ -800,6 +779,7 @@ class StateMachine(object):
self.async_set, entity_id, new_state, attributes, force_update, self.async_set, entity_id, new_state, attributes, force_update,
).result() ).result()
@callback
def async_set(self, entity_id, new_state, attributes=None, def async_set(self, entity_id, new_state, attributes=None,
force_update=False): force_update=False):
"""Set the state of an entity, add entity if it does not exist. """Set the state of an entity, add entity if it does not exist.
@ -908,14 +888,21 @@ class ServiceRegistry(object):
self._loop, self.async_services, self._loop, self.async_services,
).result() ).result()
@callback
def async_services(self): def async_services(self):
"""Dict with per domain a list of available services.""" """Dict with per domain a list of available services.
This method must be run in the event loop.
"""
return {domain: {key: value.as_dict() for key, value return {domain: {key: value.as_dict() for key, value
in self._services[domain].items()} in self._services[domain].items()}
for domain in self._services} for domain in self._services}
def has_service(self, domain, service): def has_service(self, domain, service):
"""Test if specified service exists.""" """Test if specified service exists.
Async friendly.
"""
return service.lower() in self._services.get(domain.lower(), []) return service.lower() in self._services.get(domain.lower(), [])
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
@ -935,6 +922,7 @@ class ServiceRegistry(object):
schema schema
).result() ).result()
@callback
def async_register(self, domain, service, service_func, description=None, def async_register(self, domain, service, service_func, description=None,
schema=None): schema=None):
""" """
@ -985,7 +973,7 @@ class ServiceRegistry(object):
self._loop self._loop
).result() ).result()
@callback @asyncio.coroutine
def async_call(self, domain, service, service_data=None, blocking=False): def async_call(self, domain, service, service_data=None, blocking=False):
""" """
Call a service. Call a service.
@ -1121,18 +1109,27 @@ class Config(object):
self.config_dir = None self.config_dir = None
def distance(self: object, lat: float, lon: float) -> float: def distance(self: object, lat: float, lon: float) -> float:
"""Calculate distance from Home Assistant.""" """Calculate distance from Home Assistant.
Async friendly.
"""
return self.units.length( return self.units.length(
location.distance(self.latitude, self.longitude, lat, lon), 'm') location.distance(self.latitude, self.longitude, lat, lon), 'm')
def path(self, *path): def path(self, *path):
"""Generate path to the file within the config dir.""" """Generate path to the file within the config dir.
Async friendly.
"""
if self.config_dir is None: if self.config_dir is None:
raise HomeAssistantError("config_dir is not set") raise HomeAssistantError("config_dir is not set")
return os.path.join(self.config_dir, *path) return os.path.join(self.config_dir, *path)
def as_dict(self): def as_dict(self):
"""Create a dict representation of this dict.""" """Create a dict representation of this dict.
Async friendly.
"""
time_zone = self.time_zone or dt_util.UTC time_zone = self.time_zone or dt_util.UTC
return { return {
@ -1147,7 +1144,7 @@ class Config(object):
} }
def async_create_timer(hass, interval=TIMER_INTERVAL): def _async_create_timer(hass, interval=TIMER_INTERVAL):
"""Create a timer that will start on HOMEASSISTANT_START.""" """Create a timer that will start on HOMEASSISTANT_START."""
stop_event = asyncio.Event(loop=hass.loop) stop_event = asyncio.Event(loop=hass.loop)
@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None):
return util.ThreadPool(job_handler, worker_count) return util.ThreadPool(job_handler, worker_count)
def async_monitor_worker_pool(hass): def _async_monitor_worker_pool(hass):
"""Create a monitor for the thread pool to check if pool is misbehaving.""" """Create a monitor for the thread pool to check if pool is misbehaving."""
busy_threshold = hass.pool.worker_count * 3 busy_threshold = hass.pool.worker_count * 3

View File

@ -124,9 +124,9 @@ class HomeAssistant(ha.HomeAssistant):
self.remote_api = remote_api self.remote_api = remote_api
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.pool = pool = ha.create_worker_pool() self.pool = ha.create_worker_pool()
self.bus = EventBus(remote_api, pool, self.loop) self.bus = EventBus(remote_api, self)
self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop) self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop)
self.states = StateMachine(self.bus, self.loop, self.remote_api) self.states = StateMachine(self.bus, self.loop, self.remote_api)
self.config = ha.Config() self.config = ha.Config()
@ -143,7 +143,7 @@ class HomeAssistant(ha.HomeAssistant):
'Unable to setup local API to receive events') 'Unable to setup local API to receive events')
self.state = ha.CoreState.starting self.state = ha.CoreState.starting
ha.async_create_timer(self) ha._async_create_timer(self) # pylint: disable=protected-access
self.bus.fire(ha.EVENT_HOMEASSISTANT_START, self.bus.fire(ha.EVENT_HOMEASSISTANT_START,
origin=ha.EventOrigin.remote) origin=ha.EventOrigin.remote)
@ -180,9 +180,9 @@ class EventBus(ha.EventBus):
"""EventBus implementation that forwards fire_event to remote API.""" """EventBus implementation that forwards fire_event to remote API."""
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
def __init__(self, api, pool, loop): def __init__(self, api, hass):
"""Initalize the eventbus.""" """Initalize the eventbus."""
super().__init__(pool, loop) super().__init__(hass)
self._api = api self._api = api
def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local): def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local):

View File

@ -76,8 +76,8 @@ def get_test_home_assistant(num_threads=None):
"""Fake stop.""" """Fake stop."""
yield None yield None
@patch.object(ha, 'async_create_timer') @patch.object(ha, '_async_create_timer')
@patch.object(ha, 'async_monitor_worker_pool') @patch.object(ha, '_async_monitor_worker_pool')
@patch.object(hass.loop, 'add_signal_handler') @patch.object(hass.loop, 'add_signal_handler')
@patch.object(hass.loop, 'run_forever') @patch.object(hass.loop, 'run_forever')
@patch.object(hass.loop, 'close') @patch.object(hass.loop, 'close')

View File

@ -145,14 +145,14 @@ class TestAPI(unittest.TestCase):
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set"}), data=json.dumps({"state": "not_to_be_set"}),
headers=HA_HEADERS) headers=HA_HEADERS)
hass.bus._pool.block_till_done() hass.block_till_done()
self.assertEqual(0, len(events)) self.assertEqual(0, len(events))
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set", data=json.dumps({"state": "not_to_be_set",
"force_update": True}), "force_update": True}),
headers=HA_HEADERS) headers=HA_HEADERS)
hass.bus._pool.block_till_done() hass.block_till_done()
self.assertEqual(1, len(events)) self.assertEqual(1, len(events))
# pylint: disable=invalid-name # pylint: disable=invalid-name

View File

@ -179,19 +179,16 @@ class TestEventBus(unittest.TestCase):
def listener(_): pass def listener(_): pass
self.bus.listen('test', listener) unsub = self.bus.listen('test', listener)
self.assertEqual(old_count + 1, len(self.bus.listeners)) self.assertEqual(old_count + 1, len(self.bus.listeners))
# Try deleting a non registered listener, nothing should happen
self.bus._remove_listener('test', lambda x: len)
# Remove listener # Remove listener
self.bus._remove_listener('test', listener) unsub()
self.assertEqual(old_count, len(self.bus.listeners)) self.assertEqual(old_count, len(self.bus.listeners))
# Try deleting listener while category doesn't exist either # Should do nothing now
self.bus._remove_listener('test', listener) unsub()
def test_unsubscribe_listener(self): def test_unsubscribe_listener(self):
"""Test unsubscribe listener from returned function.""" """Test unsubscribe listener from returned function."""
@ -215,11 +212,48 @@ class TestEventBus(unittest.TestCase):
assert len(calls) == 1 assert len(calls) == 1
def test_listen_once_event(self): def test_listen_once_event_with_callback(self):
"""Test listen_once_event method.""" """Test listen_once_event method."""
runs = [] runs = []
self.bus.listen_once('test_event', lambda x: runs.append(1)) @ha.callback
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event')
# Second time it should not increase runs
self.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(runs))
def test_listen_once_event_with_coroutine(self):
"""Test listen_once_event method."""
runs = []
@asyncio.coroutine
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event')
# Second time it should not increase runs
self.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(runs))
def test_listen_once_event_with_thread(self):
"""Test listen_once_event method."""
runs = []
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event') self.bus.fire('test_event')
# Second time it should not increase runs # Second time it should not increase runs
@ -604,7 +638,7 @@ class TestWorkerPoolMonitor(object):
schedule_handle = MagicMock() schedule_handle = MagicMock()
hass.loop.call_later.return_value = schedule_handle hass.loop.call_later.return_value = schedule_handle
ha.async_monitor_worker_pool(hass) ha._async_monitor_worker_pool(hass)
assert hass.loop.call_later.called assert hass.loop.call_later.called
assert hass.bus.async_listen_once.called assert hass.bus.async_listen_once.called
assert not schedule_handle.called assert not schedule_handle.called
@ -650,7 +684,7 @@ class TestAsyncCreateTimer(object):
now.second = 1 now.second = 1
mock_utcnow.reset_mock() mock_utcnow.reset_mock()
ha.async_create_timer(hass) ha._async_create_timer(hass)
assert len(hass.bus.async_listen_once.mock_calls) == 2 assert len(hass.bus.async_listen_once.mock_calls) == 2
start_timer = hass.bus.async_listen_once.mock_calls[1][1][1] start_timer = hass.bus.async_listen_once.mock_calls[1][1][1]

View File

@ -69,7 +69,7 @@ def setUpModule(): # pylint: disable=invalid-name
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD, {http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SLAVE_PORT}}) http.CONF_SERVER_PORT: SLAVE_PORT}})
with patch.object(ha, 'async_create_timer', return_value=None): with patch.object(ha, '_async_create_timer', return_value=None):
slave.start() slave.start()