diff --git a/homeassistant/components/nuimo_controller.py b/homeassistant/components/nuimo_controller.py index b383b4f45fc..e3d8f0238cf 100644 --- a/homeassistant/components/nuimo_controller.py +++ b/homeassistant/components/nuimo_controller.py @@ -79,8 +79,7 @@ class NuimoThread(threading.Thread): self._name = name self._hass_is_running = True self._nuimo = None - self._listener = hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, - self.stop) + hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.stop) def run(self): """Setup connection or be idle.""" @@ -99,8 +98,6 @@ class NuimoThread(threading.Thread): """Terminate Thread by unsetting flag.""" _LOGGER.debug('Stopping thread for Nuimo %s', self._mac) self._hass_is_running = False - self._hass.bus.remove_listener(EVENT_HOMEASSISTANT_STOP, - self._listener) def _attach(self): """Create a nuimo object from mac address or discovery.""" diff --git a/homeassistant/core.py b/homeassistant/core.py index ebd24558a40..bd59db59f05 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -8,7 +8,6 @@ of entities and react to changes. import asyncio from concurrent.futures import ThreadPoolExecutor import enum -import functools as ft import logging import os import re @@ -137,8 +136,8 @@ class HomeAssistant(object): self.executor = ThreadPoolExecutor(max_workers=5) self.loop.set_default_executor(self.executor) self.loop.set_exception_handler(self._async_exception_handler) - self.pool = pool = create_worker_pool() - self.bus = EventBus(pool, self.loop) + self.pool = create_worker_pool() + self.bus = EventBus(self) self.services = ServiceRegistry(self.bus, self.add_job, self.loop) self.states = StateMachine(self.bus, self.loop) self.config = Config() # type: Config @@ -218,8 +217,8 @@ class HomeAssistant(object): """ # pylint: disable=protected-access self.loop._thread_ident = threading.get_ident() - async_create_timer(self) - async_monitor_worker_pool(self) + _async_create_timer(self) + _async_monitor_worker_pool(self) self.bus.async_fire(EVENT_HOMEASSISTANT_START) yield from self.loop.run_in_executor(None, self.pool.block_till_done) self.state = CoreState.running @@ -235,9 +234,12 @@ class HomeAssistant(object): """ self.pool.add_job(priority, (target,) + args) + @callback def async_add_job(self, target: Callable[..., None], *args: Any): """Add a job from within the eventloop. + This method must be run in the event loop. + target: target to call. args: parameters for method to call. """ @@ -248,9 +250,12 @@ class HomeAssistant(object): else: self.add_job(target, *args) + @callback def async_run_job(self, target: Callable[..., None], *args: Any): """Run a job from within the event loop. + This method must be run in the event loop. + target: target to call. args: parameters for method to call. """ @@ -369,7 +374,10 @@ class Event(object): self.time_fired = time_fired or dt_util.utcnow() def as_dict(self): - """Create a dict representation of this Event.""" + """Create a dict representation of this Event. + + Async friendly. + """ return { 'event_type': self.event_type, 'data': dict(self.data), @@ -400,13 +408,12 @@ class Event(object): class EventBus(object): """Allows firing of and listening for events.""" - def __init__(self, pool: util.ThreadPool, - loop: asyncio.AbstractEventLoop) -> None: + def __init__(self, hass: HomeAssistant) -> None: """Initialize a new event bus.""" self._listeners = {} - self._pool = pool - self._loop = loop + self._hass = hass + @callback def async_listeners(self): """Dict with events and the number of listeners. @@ -419,23 +426,25 @@ class EventBus(object): def listeners(self): """Dict with events and the number of listeners.""" return run_callback_threadsafe( - self._loop, self.async_listeners + self._hass.loop, self.async_listeners ).result() def fire(self, event_type: str, event_data=None, origin=EventOrigin.local): """Fire an event.""" - if not self._pool.running: - raise HomeAssistantError('Home Assistant has shut down.') - - self._loop.call_soon_threadsafe(self.async_fire, event_type, - event_data, origin) + self._hass.loop.call_soon_threadsafe(self.async_fire, event_type, + event_data, origin) + @callback def async_fire(self, event_type: str, event_data=None, origin=EventOrigin.local, wait=False): """Fire an event. This method must be run in the event loop. """ + if event_type != EVENT_HOMEASSISTANT_STOP and \ + self._hass.state == CoreState.stopping: + raise HomeAssistantError('Home Assistant is shutting down.') + # Copy the list of the current listeners because some listeners # remove themselves as a listener while being executed which # causes the iterator to be confused. @@ -450,20 +459,8 @@ class EventBus(object): if not listeners: return - job_priority = JobPriority.from_event_type(event_type) - - sync_jobs = [] for func in listeners: - if asyncio.iscoroutinefunction(func): - self._loop.create_task(func(event)) - elif is_callback(func): - self._loop.call_soon(func, event) - else: - sync_jobs.append((job_priority, (func, event))) - - # Send all the sync jobs at once - if sync_jobs: - self._pool.add_many_jobs(sync_jobs) + self._hass.async_add_job(func, event) def listen(self, event_type, listener): """Listen for all events or events of a specific type. @@ -471,16 +468,17 @@ class EventBus(object): To listen to all events specify the constant ``MATCH_ALL`` as event_type. """ - future = run_callback_threadsafe( - self._loop, self.async_listen, event_type, listener) - future.result() + async_remove_listener = run_callback_threadsafe( + self._hass.loop, self.async_listen, event_type, listener).result() def remove_listener(): """Remove the listener.""" - self._remove_listener(event_type, listener) + run_callback_threadsafe( + self._hass.loop, async_remove_listener).result() return remove_listener + @callback def async_listen(self, event_type, listener): """Listen for all events or events of a specific type. @@ -496,7 +494,7 @@ class EventBus(object): def remove_listener(): """Remove the listener.""" - self.async_remove_listener(event_type, listener) + self._async_remove_listener(event_type, listener) return remove_listener @@ -508,26 +506,18 @@ class EventBus(object): Returns function to unsubscribe the listener. """ - @ft.wraps(listener) - def onetime_listener(event): - """Remove listener from eventbus and then fire listener.""" - if hasattr(onetime_listener, 'run'): - return - # Set variable so that we will never run twice. - # Because the event bus might have to wait till a thread comes - # available to execute this listener it might occur that the - # listener gets lined up twice to be executed. - # This will make sure the second time it does nothing. - setattr(onetime_listener, 'run', True) + async_remove_listener = run_callback_threadsafe( + self._hass.loop, self.async_listen_once, event_type, listener, + ).result() - remove_listener() - - listener(event) - - remove_listener = self.listen(event_type, onetime_listener) + def remove_listener(): + """Remove the listener.""" + run_callback_threadsafe( + self._hass.loop, async_remove_listener).result() return remove_listener + @callback def async_listen_once(self, event_type, listener): """Listen once for event of a specific type. @@ -538,8 +528,7 @@ class EventBus(object): This method must be run in the event loop. """ - @ft.wraps(listener) - @asyncio.coroutine + @callback def onetime_listener(event): """Remove listener from eventbus and then fire listener.""" if hasattr(onetime_listener, 'run'): @@ -550,34 +539,14 @@ class EventBus(object): # multiple times as well. # This will make sure the second time it does nothing. setattr(onetime_listener, 'run', True) + self._async_remove_listener(event_type, onetime_listener) - self.async_remove_listener(event_type, onetime_listener) + self._hass.async_run_job(listener, event) - if asyncio.iscoroutinefunction(listener): - yield from listener(event) - else: - job_priority = JobPriority.from_event_type(event.event_type) - self._pool.add_job(job_priority, (listener, event)) + return self.async_listen(event_type, onetime_listener) - self.async_listen(event_type, onetime_listener) - - return onetime_listener - - def remove_listener(self, event_type, listener): - """Remove a listener of a specific event_type. (DEPRECATED 0.28).""" - _LOGGER.warning('bus.remove_listener has been deprecated. Please use ' - 'the function returned from calling listen.') - self._remove_listener(event_type, listener) - - def _remove_listener(self, event_type, listener): - """Remove a listener of a specific event_type.""" - future = run_callback_threadsafe( - self._loop, - self.async_remove_listener, event_type, listener - ) - future.result() - - def async_remove_listener(self, event_type, listener): + @callback + def _async_remove_listener(self, event_type, listener): """Remove a listener of a specific event_type. This method must be run in the event loop. @@ -644,6 +613,8 @@ class State(object): def as_dict(self): """Return a dict representation of the State. + Async friendly. + To be used for JSON serialization. Ensures: state == State.from_dict(state.as_dict()) """ @@ -657,6 +628,8 @@ class State(object): def from_dict(cls, json_dict): """Initialize a state from a dict. + Async friendly. + Ensures: state == State.from_json_dict(state.to_json_dict()) """ if not (json_dict and 'entity_id' in json_dict and @@ -709,8 +682,12 @@ class StateMachine(object): ) return future.result() + @callback def async_entity_ids(self, domain_filter=None): - """List of entity ids that are being tracked.""" + """List of entity ids that are being tracked. + + This method must be run in the event loop. + """ if domain_filter is None: return list(self._states.keys()) @@ -723,6 +700,7 @@ class StateMachine(object): """Create a list of all states.""" return run_callback_threadsafe(self._loop, self.async_all).result() + @callback def async_all(self): """Create a list of all states. @@ -763,6 +741,7 @@ class StateMachine(object): return run_callback_threadsafe( self._loop, self.async_remove, entity_id).result() + @callback def async_remove(self, entity_id): """Remove the state of an entity. @@ -800,6 +779,7 @@ class StateMachine(object): self.async_set, entity_id, new_state, attributes, force_update, ).result() + @callback def async_set(self, entity_id, new_state, attributes=None, force_update=False): """Set the state of an entity, add entity if it does not exist. @@ -908,14 +888,21 @@ class ServiceRegistry(object): self._loop, self.async_services, ).result() + @callback def async_services(self): - """Dict with per domain a list of available services.""" + """Dict with per domain a list of available services. + + This method must be run in the event loop. + """ return {domain: {key: value.as_dict() for key, value in self._services[domain].items()} for domain in self._services} def has_service(self, domain, service): - """Test if specified service exists.""" + """Test if specified service exists. + + Async friendly. + """ return service.lower() in self._services.get(domain.lower(), []) # pylint: disable=too-many-arguments @@ -935,6 +922,7 @@ class ServiceRegistry(object): schema ).result() + @callback def async_register(self, domain, service, service_func, description=None, schema=None): """ @@ -985,7 +973,7 @@ class ServiceRegistry(object): self._loop ).result() - @callback + @asyncio.coroutine def async_call(self, domain, service, service_data=None, blocking=False): """ Call a service. @@ -1121,18 +1109,27 @@ class Config(object): self.config_dir = None def distance(self: object, lat: float, lon: float) -> float: - """Calculate distance from Home Assistant.""" + """Calculate distance from Home Assistant. + + Async friendly. + """ return self.units.length( location.distance(self.latitude, self.longitude, lat, lon), 'm') def path(self, *path): - """Generate path to the file within the config dir.""" + """Generate path to the file within the config dir. + + Async friendly. + """ if self.config_dir is None: raise HomeAssistantError("config_dir is not set") return os.path.join(self.config_dir, *path) def as_dict(self): - """Create a dict representation of this dict.""" + """Create a dict representation of this dict. + + Async friendly. + """ time_zone = self.time_zone or dt_util.UTC return { @@ -1147,7 +1144,7 @@ class Config(object): } -def async_create_timer(hass, interval=TIMER_INTERVAL): +def _async_create_timer(hass, interval=TIMER_INTERVAL): """Create a timer that will start on HOMEASSISTANT_START.""" stop_event = asyncio.Event(loop=hass.loop) @@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None): return util.ThreadPool(job_handler, worker_count) -def async_monitor_worker_pool(hass): +def _async_monitor_worker_pool(hass): """Create a monitor for the thread pool to check if pool is misbehaving.""" busy_threshold = hass.pool.worker_count * 3 diff --git a/homeassistant/remote.py b/homeassistant/remote.py index 8725990f146..15a84e08ffe 100644 --- a/homeassistant/remote.py +++ b/homeassistant/remote.py @@ -124,9 +124,9 @@ class HomeAssistant(ha.HomeAssistant): self.remote_api = remote_api self.loop = loop or asyncio.get_event_loop() - self.pool = pool = ha.create_worker_pool() + self.pool = ha.create_worker_pool() - self.bus = EventBus(remote_api, pool, self.loop) + self.bus = EventBus(remote_api, self) self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop) self.states = StateMachine(self.bus, self.loop, self.remote_api) self.config = ha.Config() @@ -143,7 +143,7 @@ class HomeAssistant(ha.HomeAssistant): 'Unable to setup local API to receive events') self.state = ha.CoreState.starting - ha.async_create_timer(self) + ha._async_create_timer(self) # pylint: disable=protected-access self.bus.fire(ha.EVENT_HOMEASSISTANT_START, origin=ha.EventOrigin.remote) @@ -180,9 +180,9 @@ class EventBus(ha.EventBus): """EventBus implementation that forwards fire_event to remote API.""" # pylint: disable=too-few-public-methods - def __init__(self, api, pool, loop): + def __init__(self, api, hass): """Initalize the eventbus.""" - super().__init__(pool, loop) + super().__init__(hass) self._api = api def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local): diff --git a/tests/common.py b/tests/common.py index 9dc98d2f4b4..b185a47e66c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -76,8 +76,8 @@ def get_test_home_assistant(num_threads=None): """Fake stop.""" yield None - @patch.object(ha, 'async_create_timer') - @patch.object(ha, 'async_monitor_worker_pool') + @patch.object(ha, '_async_create_timer') + @patch.object(ha, '_async_monitor_worker_pool') @patch.object(hass.loop, 'add_signal_handler') @patch.object(hass.loop, 'run_forever') @patch.object(hass.loop, 'close') diff --git a/tests/components/test_api.py b/tests/components/test_api.py index dee4320824b..ca494305073 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -145,14 +145,14 @@ class TestAPI(unittest.TestCase): requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), data=json.dumps({"state": "not_to_be_set"}), headers=HA_HEADERS) - hass.bus._pool.block_till_done() + hass.block_till_done() self.assertEqual(0, len(events)) requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), data=json.dumps({"state": "not_to_be_set", "force_update": True}), headers=HA_HEADERS) - hass.bus._pool.block_till_done() + hass.block_till_done() self.assertEqual(1, len(events)) # pylint: disable=invalid-name diff --git a/tests/test_core.py b/tests/test_core.py index 39301b5614a..b3ab2ba4dbd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -179,19 +179,16 @@ class TestEventBus(unittest.TestCase): def listener(_): pass - self.bus.listen('test', listener) + unsub = self.bus.listen('test', listener) self.assertEqual(old_count + 1, len(self.bus.listeners)) - # Try deleting a non registered listener, nothing should happen - self.bus._remove_listener('test', lambda x: len) - # Remove listener - self.bus._remove_listener('test', listener) + unsub() self.assertEqual(old_count, len(self.bus.listeners)) - # Try deleting listener while category doesn't exist either - self.bus._remove_listener('test', listener) + # Should do nothing now + unsub() def test_unsubscribe_listener(self): """Test unsubscribe listener from returned function.""" @@ -215,11 +212,48 @@ class TestEventBus(unittest.TestCase): assert len(calls) == 1 - def test_listen_once_event(self): + def test_listen_once_event_with_callback(self): """Test listen_once_event method.""" runs = [] - self.bus.listen_once('test_event', lambda x: runs.append(1)) + @ha.callback + def event_handler(event): + runs.append(event) + + self.bus.listen_once('test_event', event_handler) + + self.bus.fire('test_event') + # Second time it should not increase runs + self.bus.fire('test_event') + + self.hass.block_till_done() + self.assertEqual(1, len(runs)) + + def test_listen_once_event_with_coroutine(self): + """Test listen_once_event method.""" + runs = [] + + @asyncio.coroutine + def event_handler(event): + runs.append(event) + + self.bus.listen_once('test_event', event_handler) + + self.bus.fire('test_event') + # Second time it should not increase runs + self.bus.fire('test_event') + + self.hass.block_till_done() + self.assertEqual(1, len(runs)) + + def test_listen_once_event_with_thread(self): + """Test listen_once_event method.""" + runs = [] + + def event_handler(event): + runs.append(event) + + self.bus.listen_once('test_event', event_handler) self.bus.fire('test_event') # Second time it should not increase runs @@ -604,7 +638,7 @@ class TestWorkerPoolMonitor(object): schedule_handle = MagicMock() hass.loop.call_later.return_value = schedule_handle - ha.async_monitor_worker_pool(hass) + ha._async_monitor_worker_pool(hass) assert hass.loop.call_later.called assert hass.bus.async_listen_once.called assert not schedule_handle.called @@ -650,7 +684,7 @@ class TestAsyncCreateTimer(object): now.second = 1 mock_utcnow.reset_mock() - ha.async_create_timer(hass) + ha._async_create_timer(hass) assert len(hass.bus.async_listen_once.mock_calls) == 2 start_timer = hass.bus.async_listen_once.mock_calls[1][1][1] diff --git a/tests/test_remote.py b/tests/test_remote.py index 653971f8bc1..316f13c5fc2 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -69,7 +69,7 @@ def setUpModule(): # pylint: disable=invalid-name {http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD, http.CONF_SERVER_PORT: SLAVE_PORT}}) - with patch.object(ha, 'async_create_timer', return_value=None): + with patch.object(ha, '_async_create_timer', return_value=None): slave.start()