mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Clean up some async stuff (#3915)
* Clean up some async stuff * Adjust comments * Pass hass instance to eventbus
This commit is contained in:
parent
daea93d9f9
commit
4c8d1d9d2f
@ -79,8 +79,7 @@ class NuimoThread(threading.Thread):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._hass_is_running = True
|
self._hass_is_running = True
|
||||||
self._nuimo = None
|
self._nuimo = None
|
||||||
self._listener = hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP,
|
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.stop)
|
||||||
self.stop)
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Setup connection or be idle."""
|
"""Setup connection or be idle."""
|
||||||
@ -99,8 +98,6 @@ class NuimoThread(threading.Thread):
|
|||||||
"""Terminate Thread by unsetting flag."""
|
"""Terminate Thread by unsetting flag."""
|
||||||
_LOGGER.debug('Stopping thread for Nuimo %s', self._mac)
|
_LOGGER.debug('Stopping thread for Nuimo %s', self._mac)
|
||||||
self._hass_is_running = False
|
self._hass_is_running = False
|
||||||
self._hass.bus.remove_listener(EVENT_HOMEASSISTANT_STOP,
|
|
||||||
self._listener)
|
|
||||||
|
|
||||||
def _attach(self):
|
def _attach(self):
|
||||||
"""Create a nuimo object from mac address or discovery."""
|
"""Create a nuimo object from mac address or discovery."""
|
||||||
|
@ -8,7 +8,6 @@ of entities and react to changes.
|
|||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import enum
|
import enum
|
||||||
import functools as ft
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@ -137,8 +136,8 @@ class HomeAssistant(object):
|
|||||||
self.executor = ThreadPoolExecutor(max_workers=5)
|
self.executor = ThreadPoolExecutor(max_workers=5)
|
||||||
self.loop.set_default_executor(self.executor)
|
self.loop.set_default_executor(self.executor)
|
||||||
self.loop.set_exception_handler(self._async_exception_handler)
|
self.loop.set_exception_handler(self._async_exception_handler)
|
||||||
self.pool = pool = create_worker_pool()
|
self.pool = create_worker_pool()
|
||||||
self.bus = EventBus(pool, self.loop)
|
self.bus = EventBus(self)
|
||||||
self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
|
self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
|
||||||
self.states = StateMachine(self.bus, self.loop)
|
self.states = StateMachine(self.bus, self.loop)
|
||||||
self.config = Config() # type: Config
|
self.config = Config() # type: Config
|
||||||
@ -218,8 +217,8 @@ class HomeAssistant(object):
|
|||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self.loop._thread_ident = threading.get_ident()
|
self.loop._thread_ident = threading.get_ident()
|
||||||
async_create_timer(self)
|
_async_create_timer(self)
|
||||||
async_monitor_worker_pool(self)
|
_async_monitor_worker_pool(self)
|
||||||
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||||
yield from self.loop.run_in_executor(None, self.pool.block_till_done)
|
yield from self.loop.run_in_executor(None, self.pool.block_till_done)
|
||||||
self.state = CoreState.running
|
self.state = CoreState.running
|
||||||
@ -235,9 +234,12 @@ class HomeAssistant(object):
|
|||||||
"""
|
"""
|
||||||
self.pool.add_job(priority, (target,) + args)
|
self.pool.add_job(priority, (target,) + args)
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_add_job(self, target: Callable[..., None], *args: Any):
|
def async_add_job(self, target: Callable[..., None], *args: Any):
|
||||||
"""Add a job from within the eventloop.
|
"""Add a job from within the eventloop.
|
||||||
|
|
||||||
|
This method must be run in the event loop.
|
||||||
|
|
||||||
target: target to call.
|
target: target to call.
|
||||||
args: parameters for method to call.
|
args: parameters for method to call.
|
||||||
"""
|
"""
|
||||||
@ -248,9 +250,12 @@ class HomeAssistant(object):
|
|||||||
else:
|
else:
|
||||||
self.add_job(target, *args)
|
self.add_job(target, *args)
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_run_job(self, target: Callable[..., None], *args: Any):
|
def async_run_job(self, target: Callable[..., None], *args: Any):
|
||||||
"""Run a job from within the event loop.
|
"""Run a job from within the event loop.
|
||||||
|
|
||||||
|
This method must be run in the event loop.
|
||||||
|
|
||||||
target: target to call.
|
target: target to call.
|
||||||
args: parameters for method to call.
|
args: parameters for method to call.
|
||||||
"""
|
"""
|
||||||
@ -369,7 +374,10 @@ class Event(object):
|
|||||||
self.time_fired = time_fired or dt_util.utcnow()
|
self.time_fired = time_fired or dt_util.utcnow()
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
"""Create a dict representation of this Event."""
|
"""Create a dict representation of this Event.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
'event_type': self.event_type,
|
'event_type': self.event_type,
|
||||||
'data': dict(self.data),
|
'data': dict(self.data),
|
||||||
@ -400,13 +408,12 @@ class Event(object):
|
|||||||
class EventBus(object):
|
class EventBus(object):
|
||||||
"""Allows firing of and listening for events."""
|
"""Allows firing of and listening for events."""
|
||||||
|
|
||||||
def __init__(self, pool: util.ThreadPool,
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
"""Initialize a new event bus."""
|
"""Initialize a new event bus."""
|
||||||
self._listeners = {}
|
self._listeners = {}
|
||||||
self._pool = pool
|
self._hass = hass
|
||||||
self._loop = loop
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_listeners(self):
|
def async_listeners(self):
|
||||||
"""Dict with events and the number of listeners.
|
"""Dict with events and the number of listeners.
|
||||||
|
|
||||||
@ -419,23 +426,25 @@ class EventBus(object):
|
|||||||
def listeners(self):
|
def listeners(self):
|
||||||
"""Dict with events and the number of listeners."""
|
"""Dict with events and the number of listeners."""
|
||||||
return run_callback_threadsafe(
|
return run_callback_threadsafe(
|
||||||
self._loop, self.async_listeners
|
self._hass.loop, self.async_listeners
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
|
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
|
||||||
"""Fire an event."""
|
"""Fire an event."""
|
||||||
if not self._pool.running:
|
self._hass.loop.call_soon_threadsafe(self.async_fire, event_type,
|
||||||
raise HomeAssistantError('Home Assistant has shut down.')
|
event_data, origin)
|
||||||
|
|
||||||
self._loop.call_soon_threadsafe(self.async_fire, event_type,
|
|
||||||
event_data, origin)
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_fire(self, event_type: str, event_data=None,
|
def async_fire(self, event_type: str, event_data=None,
|
||||||
origin=EventOrigin.local, wait=False):
|
origin=EventOrigin.local, wait=False):
|
||||||
"""Fire an event.
|
"""Fire an event.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
|
if event_type != EVENT_HOMEASSISTANT_STOP and \
|
||||||
|
self._hass.state == CoreState.stopping:
|
||||||
|
raise HomeAssistantError('Home Assistant is shutting down.')
|
||||||
|
|
||||||
# Copy the list of the current listeners because some listeners
|
# Copy the list of the current listeners because some listeners
|
||||||
# remove themselves as a listener while being executed which
|
# remove themselves as a listener while being executed which
|
||||||
# causes the iterator to be confused.
|
# causes the iterator to be confused.
|
||||||
@ -450,20 +459,8 @@ class EventBus(object):
|
|||||||
if not listeners:
|
if not listeners:
|
||||||
return
|
return
|
||||||
|
|
||||||
job_priority = JobPriority.from_event_type(event_type)
|
|
||||||
|
|
||||||
sync_jobs = []
|
|
||||||
for func in listeners:
|
for func in listeners:
|
||||||
if asyncio.iscoroutinefunction(func):
|
self._hass.async_add_job(func, event)
|
||||||
self._loop.create_task(func(event))
|
|
||||||
elif is_callback(func):
|
|
||||||
self._loop.call_soon(func, event)
|
|
||||||
else:
|
|
||||||
sync_jobs.append((job_priority, (func, event)))
|
|
||||||
|
|
||||||
# Send all the sync jobs at once
|
|
||||||
if sync_jobs:
|
|
||||||
self._pool.add_many_jobs(sync_jobs)
|
|
||||||
|
|
||||||
def listen(self, event_type, listener):
|
def listen(self, event_type, listener):
|
||||||
"""Listen for all events or events of a specific type.
|
"""Listen for all events or events of a specific type.
|
||||||
@ -471,16 +468,17 @@ class EventBus(object):
|
|||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
as event_type.
|
as event_type.
|
||||||
"""
|
"""
|
||||||
future = run_callback_threadsafe(
|
async_remove_listener = run_callback_threadsafe(
|
||||||
self._loop, self.async_listen, event_type, listener)
|
self._hass.loop, self.async_listen, event_type, listener).result()
|
||||||
future.result()
|
|
||||||
|
|
||||||
def remove_listener():
|
def remove_listener():
|
||||||
"""Remove the listener."""
|
"""Remove the listener."""
|
||||||
self._remove_listener(event_type, listener)
|
run_callback_threadsafe(
|
||||||
|
self._hass.loop, async_remove_listener).result()
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_listen(self, event_type, listener):
|
def async_listen(self, event_type, listener):
|
||||||
"""Listen for all events or events of a specific type.
|
"""Listen for all events or events of a specific type.
|
||||||
|
|
||||||
@ -496,7 +494,7 @@ class EventBus(object):
|
|||||||
|
|
||||||
def remove_listener():
|
def remove_listener():
|
||||||
"""Remove the listener."""
|
"""Remove the listener."""
|
||||||
self.async_remove_listener(event_type, listener)
|
self._async_remove_listener(event_type, listener)
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
@ -508,26 +506,18 @@ class EventBus(object):
|
|||||||
|
|
||||||
Returns function to unsubscribe the listener.
|
Returns function to unsubscribe the listener.
|
||||||
"""
|
"""
|
||||||
@ft.wraps(listener)
|
async_remove_listener = run_callback_threadsafe(
|
||||||
def onetime_listener(event):
|
self._hass.loop, self.async_listen_once, event_type, listener,
|
||||||
"""Remove listener from eventbus and then fire listener."""
|
).result()
|
||||||
if hasattr(onetime_listener, 'run'):
|
|
||||||
return
|
|
||||||
# Set variable so that we will never run twice.
|
|
||||||
# Because the event bus might have to wait till a thread comes
|
|
||||||
# available to execute this listener it might occur that the
|
|
||||||
# listener gets lined up twice to be executed.
|
|
||||||
# This will make sure the second time it does nothing.
|
|
||||||
setattr(onetime_listener, 'run', True)
|
|
||||||
|
|
||||||
remove_listener()
|
def remove_listener():
|
||||||
|
"""Remove the listener."""
|
||||||
listener(event)
|
run_callback_threadsafe(
|
||||||
|
self._hass.loop, async_remove_listener).result()
|
||||||
remove_listener = self.listen(event_type, onetime_listener)
|
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_listen_once(self, event_type, listener):
|
def async_listen_once(self, event_type, listener):
|
||||||
"""Listen once for event of a specific type.
|
"""Listen once for event of a specific type.
|
||||||
|
|
||||||
@ -538,8 +528,7 @@ class EventBus(object):
|
|||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
@ft.wraps(listener)
|
@callback
|
||||||
@asyncio.coroutine
|
|
||||||
def onetime_listener(event):
|
def onetime_listener(event):
|
||||||
"""Remove listener from eventbus and then fire listener."""
|
"""Remove listener from eventbus and then fire listener."""
|
||||||
if hasattr(onetime_listener, 'run'):
|
if hasattr(onetime_listener, 'run'):
|
||||||
@ -550,34 +539,14 @@ class EventBus(object):
|
|||||||
# multiple times as well.
|
# multiple times as well.
|
||||||
# This will make sure the second time it does nothing.
|
# This will make sure the second time it does nothing.
|
||||||
setattr(onetime_listener, 'run', True)
|
setattr(onetime_listener, 'run', True)
|
||||||
|
self._async_remove_listener(event_type, onetime_listener)
|
||||||
|
|
||||||
self.async_remove_listener(event_type, onetime_listener)
|
self._hass.async_run_job(listener, event)
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(listener):
|
return self.async_listen(event_type, onetime_listener)
|
||||||
yield from listener(event)
|
|
||||||
else:
|
|
||||||
job_priority = JobPriority.from_event_type(event.event_type)
|
|
||||||
self._pool.add_job(job_priority, (listener, event))
|
|
||||||
|
|
||||||
self.async_listen(event_type, onetime_listener)
|
@callback
|
||||||
|
def _async_remove_listener(self, event_type, listener):
|
||||||
return onetime_listener
|
|
||||||
|
|
||||||
def remove_listener(self, event_type, listener):
|
|
||||||
"""Remove a listener of a specific event_type. (DEPRECATED 0.28)."""
|
|
||||||
_LOGGER.warning('bus.remove_listener has been deprecated. Please use '
|
|
||||||
'the function returned from calling listen.')
|
|
||||||
self._remove_listener(event_type, listener)
|
|
||||||
|
|
||||||
def _remove_listener(self, event_type, listener):
|
|
||||||
"""Remove a listener of a specific event_type."""
|
|
||||||
future = run_callback_threadsafe(
|
|
||||||
self._loop,
|
|
||||||
self.async_remove_listener, event_type, listener
|
|
||||||
)
|
|
||||||
future.result()
|
|
||||||
|
|
||||||
def async_remove_listener(self, event_type, listener):
|
|
||||||
"""Remove a listener of a specific event_type.
|
"""Remove a listener of a specific event_type.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
@ -644,6 +613,8 @@ class State(object):
|
|||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
"""Return a dict representation of the State.
|
"""Return a dict representation of the State.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
|
||||||
To be used for JSON serialization.
|
To be used for JSON serialization.
|
||||||
Ensures: state == State.from_dict(state.as_dict())
|
Ensures: state == State.from_dict(state.as_dict())
|
||||||
"""
|
"""
|
||||||
@ -657,6 +628,8 @@ class State(object):
|
|||||||
def from_dict(cls, json_dict):
|
def from_dict(cls, json_dict):
|
||||||
"""Initialize a state from a dict.
|
"""Initialize a state from a dict.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
|
||||||
Ensures: state == State.from_json_dict(state.to_json_dict())
|
Ensures: state == State.from_json_dict(state.to_json_dict())
|
||||||
"""
|
"""
|
||||||
if not (json_dict and 'entity_id' in json_dict and
|
if not (json_dict and 'entity_id' in json_dict and
|
||||||
@ -709,8 +682,12 @@ class StateMachine(object):
|
|||||||
)
|
)
|
||||||
return future.result()
|
return future.result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_entity_ids(self, domain_filter=None):
|
def async_entity_ids(self, domain_filter=None):
|
||||||
"""List of entity ids that are being tracked."""
|
"""List of entity ids that are being tracked.
|
||||||
|
|
||||||
|
This method must be run in the event loop.
|
||||||
|
"""
|
||||||
if domain_filter is None:
|
if domain_filter is None:
|
||||||
return list(self._states.keys())
|
return list(self._states.keys())
|
||||||
|
|
||||||
@ -723,6 +700,7 @@ class StateMachine(object):
|
|||||||
"""Create a list of all states."""
|
"""Create a list of all states."""
|
||||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
return run_callback_threadsafe(self._loop, self.async_all).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_all(self):
|
def async_all(self):
|
||||||
"""Create a list of all states.
|
"""Create a list of all states.
|
||||||
|
|
||||||
@ -763,6 +741,7 @@ class StateMachine(object):
|
|||||||
return run_callback_threadsafe(
|
return run_callback_threadsafe(
|
||||||
self._loop, self.async_remove, entity_id).result()
|
self._loop, self.async_remove, entity_id).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_remove(self, entity_id):
|
def async_remove(self, entity_id):
|
||||||
"""Remove the state of an entity.
|
"""Remove the state of an entity.
|
||||||
|
|
||||||
@ -800,6 +779,7 @@ class StateMachine(object):
|
|||||||
self.async_set, entity_id, new_state, attributes, force_update,
|
self.async_set, entity_id, new_state, attributes, force_update,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_set(self, entity_id, new_state, attributes=None,
|
def async_set(self, entity_id, new_state, attributes=None,
|
||||||
force_update=False):
|
force_update=False):
|
||||||
"""Set the state of an entity, add entity if it does not exist.
|
"""Set the state of an entity, add entity if it does not exist.
|
||||||
@ -908,14 +888,21 @@ class ServiceRegistry(object):
|
|||||||
self._loop, self.async_services,
|
self._loop, self.async_services,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_services(self):
|
def async_services(self):
|
||||||
"""Dict with per domain a list of available services."""
|
"""Dict with per domain a list of available services.
|
||||||
|
|
||||||
|
This method must be run in the event loop.
|
||||||
|
"""
|
||||||
return {domain: {key: value.as_dict() for key, value
|
return {domain: {key: value.as_dict() for key, value
|
||||||
in self._services[domain].items()}
|
in self._services[domain].items()}
|
||||||
for domain in self._services}
|
for domain in self._services}
|
||||||
|
|
||||||
def has_service(self, domain, service):
|
def has_service(self, domain, service):
|
||||||
"""Test if specified service exists."""
|
"""Test if specified service exists.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
return service.lower() in self._services.get(domain.lower(), [])
|
return service.lower() in self._services.get(domain.lower(), [])
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
@ -935,6 +922,7 @@ class ServiceRegistry(object):
|
|||||||
schema
|
schema
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_register(self, domain, service, service_func, description=None,
|
def async_register(self, domain, service, service_func, description=None,
|
||||||
schema=None):
|
schema=None):
|
||||||
"""
|
"""
|
||||||
@ -985,7 +973,7 @@ class ServiceRegistry(object):
|
|||||||
self._loop
|
self._loop
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@callback
|
@asyncio.coroutine
|
||||||
def async_call(self, domain, service, service_data=None, blocking=False):
|
def async_call(self, domain, service, service_data=None, blocking=False):
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
@ -1121,18 +1109,27 @@ class Config(object):
|
|||||||
self.config_dir = None
|
self.config_dir = None
|
||||||
|
|
||||||
def distance(self: object, lat: float, lon: float) -> float:
|
def distance(self: object, lat: float, lon: float) -> float:
|
||||||
"""Calculate distance from Home Assistant."""
|
"""Calculate distance from Home Assistant.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
return self.units.length(
|
return self.units.length(
|
||||||
location.distance(self.latitude, self.longitude, lat, lon), 'm')
|
location.distance(self.latitude, self.longitude, lat, lon), 'm')
|
||||||
|
|
||||||
def path(self, *path):
|
def path(self, *path):
|
||||||
"""Generate path to the file within the config dir."""
|
"""Generate path to the file within the config dir.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
if self.config_dir is None:
|
if self.config_dir is None:
|
||||||
raise HomeAssistantError("config_dir is not set")
|
raise HomeAssistantError("config_dir is not set")
|
||||||
return os.path.join(self.config_dir, *path)
|
return os.path.join(self.config_dir, *path)
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
"""Create a dict representation of this dict."""
|
"""Create a dict representation of this dict.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
time_zone = self.time_zone or dt_util.UTC
|
time_zone = self.time_zone or dt_util.UTC
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -1147,7 +1144,7 @@ class Config(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def async_create_timer(hass, interval=TIMER_INTERVAL):
|
def _async_create_timer(hass, interval=TIMER_INTERVAL):
|
||||||
"""Create a timer that will start on HOMEASSISTANT_START."""
|
"""Create a timer that will start on HOMEASSISTANT_START."""
|
||||||
stop_event = asyncio.Event(loop=hass.loop)
|
stop_event = asyncio.Event(loop=hass.loop)
|
||||||
|
|
||||||
@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None):
|
|||||||
return util.ThreadPool(job_handler, worker_count)
|
return util.ThreadPool(job_handler, worker_count)
|
||||||
|
|
||||||
|
|
||||||
def async_monitor_worker_pool(hass):
|
def _async_monitor_worker_pool(hass):
|
||||||
"""Create a monitor for the thread pool to check if pool is misbehaving."""
|
"""Create a monitor for the thread pool to check if pool is misbehaving."""
|
||||||
busy_threshold = hass.pool.worker_count * 3
|
busy_threshold = hass.pool.worker_count * 3
|
||||||
|
|
||||||
|
@ -124,9 +124,9 @@ class HomeAssistant(ha.HomeAssistant):
|
|||||||
self.remote_api = remote_api
|
self.remote_api = remote_api
|
||||||
|
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
self.pool = pool = ha.create_worker_pool()
|
self.pool = ha.create_worker_pool()
|
||||||
|
|
||||||
self.bus = EventBus(remote_api, pool, self.loop)
|
self.bus = EventBus(remote_api, self)
|
||||||
self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop)
|
self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop)
|
||||||
self.states = StateMachine(self.bus, self.loop, self.remote_api)
|
self.states = StateMachine(self.bus, self.loop, self.remote_api)
|
||||||
self.config = ha.Config()
|
self.config = ha.Config()
|
||||||
@ -143,7 +143,7 @@ class HomeAssistant(ha.HomeAssistant):
|
|||||||
'Unable to setup local API to receive events')
|
'Unable to setup local API to receive events')
|
||||||
|
|
||||||
self.state = ha.CoreState.starting
|
self.state = ha.CoreState.starting
|
||||||
ha.async_create_timer(self)
|
ha._async_create_timer(self) # pylint: disable=protected-access
|
||||||
|
|
||||||
self.bus.fire(ha.EVENT_HOMEASSISTANT_START,
|
self.bus.fire(ha.EVENT_HOMEASSISTANT_START,
|
||||||
origin=ha.EventOrigin.remote)
|
origin=ha.EventOrigin.remote)
|
||||||
@ -180,9 +180,9 @@ class EventBus(ha.EventBus):
|
|||||||
"""EventBus implementation that forwards fire_event to remote API."""
|
"""EventBus implementation that forwards fire_event to remote API."""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
def __init__(self, api, pool, loop):
|
def __init__(self, api, hass):
|
||||||
"""Initalize the eventbus."""
|
"""Initalize the eventbus."""
|
||||||
super().__init__(pool, loop)
|
super().__init__(hass)
|
||||||
self._api = api
|
self._api = api
|
||||||
|
|
||||||
def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local):
|
def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local):
|
||||||
|
@ -76,8 +76,8 @@ def get_test_home_assistant(num_threads=None):
|
|||||||
"""Fake stop."""
|
"""Fake stop."""
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
@patch.object(ha, 'async_create_timer')
|
@patch.object(ha, '_async_create_timer')
|
||||||
@patch.object(ha, 'async_monitor_worker_pool')
|
@patch.object(ha, '_async_monitor_worker_pool')
|
||||||
@patch.object(hass.loop, 'add_signal_handler')
|
@patch.object(hass.loop, 'add_signal_handler')
|
||||||
@patch.object(hass.loop, 'run_forever')
|
@patch.object(hass.loop, 'run_forever')
|
||||||
@patch.object(hass.loop, 'close')
|
@patch.object(hass.loop, 'close')
|
||||||
|
@ -145,14 +145,14 @@ class TestAPI(unittest.TestCase):
|
|||||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
||||||
data=json.dumps({"state": "not_to_be_set"}),
|
data=json.dumps({"state": "not_to_be_set"}),
|
||||||
headers=HA_HEADERS)
|
headers=HA_HEADERS)
|
||||||
hass.bus._pool.block_till_done()
|
hass.block_till_done()
|
||||||
self.assertEqual(0, len(events))
|
self.assertEqual(0, len(events))
|
||||||
|
|
||||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
||||||
data=json.dumps({"state": "not_to_be_set",
|
data=json.dumps({"state": "not_to_be_set",
|
||||||
"force_update": True}),
|
"force_update": True}),
|
||||||
headers=HA_HEADERS)
|
headers=HA_HEADERS)
|
||||||
hass.bus._pool.block_till_done()
|
hass.block_till_done()
|
||||||
self.assertEqual(1, len(events))
|
self.assertEqual(1, len(events))
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -179,19 +179,16 @@ class TestEventBus(unittest.TestCase):
|
|||||||
|
|
||||||
def listener(_): pass
|
def listener(_): pass
|
||||||
|
|
||||||
self.bus.listen('test', listener)
|
unsub = self.bus.listen('test', listener)
|
||||||
|
|
||||||
self.assertEqual(old_count + 1, len(self.bus.listeners))
|
self.assertEqual(old_count + 1, len(self.bus.listeners))
|
||||||
|
|
||||||
# Try deleting a non registered listener, nothing should happen
|
|
||||||
self.bus._remove_listener('test', lambda x: len)
|
|
||||||
|
|
||||||
# Remove listener
|
# Remove listener
|
||||||
self.bus._remove_listener('test', listener)
|
unsub()
|
||||||
self.assertEqual(old_count, len(self.bus.listeners))
|
self.assertEqual(old_count, len(self.bus.listeners))
|
||||||
|
|
||||||
# Try deleting listener while category doesn't exist either
|
# Should do nothing now
|
||||||
self.bus._remove_listener('test', listener)
|
unsub()
|
||||||
|
|
||||||
def test_unsubscribe_listener(self):
|
def test_unsubscribe_listener(self):
|
||||||
"""Test unsubscribe listener from returned function."""
|
"""Test unsubscribe listener from returned function."""
|
||||||
@ -215,11 +212,48 @@ class TestEventBus(unittest.TestCase):
|
|||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
|
||||||
def test_listen_once_event(self):
|
def test_listen_once_event_with_callback(self):
|
||||||
"""Test listen_once_event method."""
|
"""Test listen_once_event method."""
|
||||||
runs = []
|
runs = []
|
||||||
|
|
||||||
self.bus.listen_once('test_event', lambda x: runs.append(1))
|
@ha.callback
|
||||||
|
def event_handler(event):
|
||||||
|
runs.append(event)
|
||||||
|
|
||||||
|
self.bus.listen_once('test_event', event_handler)
|
||||||
|
|
||||||
|
self.bus.fire('test_event')
|
||||||
|
# Second time it should not increase runs
|
||||||
|
self.bus.fire('test_event')
|
||||||
|
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertEqual(1, len(runs))
|
||||||
|
|
||||||
|
def test_listen_once_event_with_coroutine(self):
|
||||||
|
"""Test listen_once_event method."""
|
||||||
|
runs = []
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def event_handler(event):
|
||||||
|
runs.append(event)
|
||||||
|
|
||||||
|
self.bus.listen_once('test_event', event_handler)
|
||||||
|
|
||||||
|
self.bus.fire('test_event')
|
||||||
|
# Second time it should not increase runs
|
||||||
|
self.bus.fire('test_event')
|
||||||
|
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertEqual(1, len(runs))
|
||||||
|
|
||||||
|
def test_listen_once_event_with_thread(self):
|
||||||
|
"""Test listen_once_event method."""
|
||||||
|
runs = []
|
||||||
|
|
||||||
|
def event_handler(event):
|
||||||
|
runs.append(event)
|
||||||
|
|
||||||
|
self.bus.listen_once('test_event', event_handler)
|
||||||
|
|
||||||
self.bus.fire('test_event')
|
self.bus.fire('test_event')
|
||||||
# Second time it should not increase runs
|
# Second time it should not increase runs
|
||||||
@ -604,7 +638,7 @@ class TestWorkerPoolMonitor(object):
|
|||||||
schedule_handle = MagicMock()
|
schedule_handle = MagicMock()
|
||||||
hass.loop.call_later.return_value = schedule_handle
|
hass.loop.call_later.return_value = schedule_handle
|
||||||
|
|
||||||
ha.async_monitor_worker_pool(hass)
|
ha._async_monitor_worker_pool(hass)
|
||||||
assert hass.loop.call_later.called
|
assert hass.loop.call_later.called
|
||||||
assert hass.bus.async_listen_once.called
|
assert hass.bus.async_listen_once.called
|
||||||
assert not schedule_handle.called
|
assert not schedule_handle.called
|
||||||
@ -650,7 +684,7 @@ class TestAsyncCreateTimer(object):
|
|||||||
now.second = 1
|
now.second = 1
|
||||||
mock_utcnow.reset_mock()
|
mock_utcnow.reset_mock()
|
||||||
|
|
||||||
ha.async_create_timer(hass)
|
ha._async_create_timer(hass)
|
||||||
assert len(hass.bus.async_listen_once.mock_calls) == 2
|
assert len(hass.bus.async_listen_once.mock_calls) == 2
|
||||||
start_timer = hass.bus.async_listen_once.mock_calls[1][1][1]
|
start_timer = hass.bus.async_listen_once.mock_calls[1][1][1]
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ def setUpModule(): # pylint: disable=invalid-name
|
|||||||
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
|
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
http.CONF_SERVER_PORT: SLAVE_PORT}})
|
http.CONF_SERVER_PORT: SLAVE_PORT}})
|
||||||
|
|
||||||
with patch.object(ha, 'async_create_timer', return_value=None):
|
with patch.object(ha, '_async_create_timer', return_value=None):
|
||||||
slave.start()
|
slave.start()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user