diff --git a/homeassistant/components/automation/event.py b/homeassistant/components/automation/event.py index ae0f219a01a..a51f9fa8187 100644 --- a/homeassistant/components/automation/event.py +++ b/homeassistant/components/automation/event.py @@ -4,11 +4,11 @@ Offer event listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#event-trigger """ -import asyncio import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import CONF_PLATFORM from homeassistant.helpers import config_validation as cv @@ -29,12 +29,12 @@ def async_trigger(hass, config, action): event_type = config.get(CONF_EVENT_TYPE) event_data = config.get(CONF_EVENT_DATA) - @asyncio.coroutine + @callback def handle_event(event): """Listen for events and calls the action when data matches.""" if not event_data or all(val == event.data.get(key) for key, val in event_data.items()): - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'event', 'event': event, diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index 7897a9bc221..39deae3d66e 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -4,9 +4,9 @@ Offer MQTT listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#mqtt-trigger """ -import asyncio import voluptuous as vol +from homeassistant.core import callback import homeassistant.components.mqtt as mqtt from homeassistant.const import (CONF_PLATFORM, CONF_PAYLOAD) import homeassistant.helpers.config_validation as cv @@ -27,11 +27,11 @@ def async_trigger(hass, config, action): topic = config.get(CONF_TOPIC) payload = config.get(CONF_PAYLOAD) - @asyncio.coroutine + @callback def mqtt_automation_listener(msg_topic, msg_payload, qos): """Listen for MQTT messages.""" if payload is None or payload == msg_payload: - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'mqtt', 'topic': msg_topic, diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 4d6cdc21190..9c3ac7d8396 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -4,11 +4,11 @@ Offer numeric state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#numeric-state-trigger """ -import asyncio import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import ( CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID, CONF_BELOW, CONF_ABOVE) @@ -35,7 +35,7 @@ def async_trigger(hass, config, action): if value_template is not None: value_template.hass = hass - @asyncio.coroutine + @callback def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" if to_s is None: @@ -64,6 +64,6 @@ def async_trigger(hass, config, action): variables['trigger']['from_state'] = from_s variables['trigger']['to_state'] = to_s - hass.async_add_job(action, variables) + hass.async_run_job(action, variables) return async_track_state_change(hass, entity_id, state_automation_listener) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 0649834ff33..fb146991602 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -4,9 +4,9 @@ Offer state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#state-trigger """ -import asyncio import voluptuous as vol +from homeassistant.core import callback import homeassistant.util.dt as dt_util from homeassistant.const import MATCH_ALL, CONF_PLATFORM from homeassistant.helpers.event import ( @@ -43,14 +43,14 @@ def async_trigger(hass, config, action): async_remove_state_for_cancel = None async_remove_state_for_listener = None - @asyncio.coroutine + @callback def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" nonlocal async_remove_state_for_cancel, async_remove_state_for_listener def call_action(): """Call action with right context.""" - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'state', 'entity_id': entity, @@ -64,13 +64,13 @@ def async_trigger(hass, config, action): call_action() return - @asyncio.coroutine + @callback def state_for_listener(now): """Fire on state changes after a delay and calls action.""" async_remove_state_for_cancel() call_action() - @asyncio.coroutine + @callback def state_for_cancel_listener(entity, inner_from_s, inner_to_s): """Fire on changes and cancel for listener if changed.""" if inner_to_s.state == to_s.state: diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index 9892707a139..2baa0726813 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -4,12 +4,12 @@ Offer sun based automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#sun-trigger """ -import asyncio from datetime import timedelta import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import ( CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE) from homeassistant.helpers.event import async_track_sunrise, async_track_sunset @@ -31,10 +31,10 @@ def async_trigger(hass, config, action): event = config.get(CONF_EVENT) offset = config.get(CONF_OFFSET) - @asyncio.coroutine + @callback def call_action(): """Call action with right context.""" - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'sun', 'event': event, diff --git a/homeassistant/components/automation/template.py b/homeassistant/components/automation/template.py index 94f57dbbc02..90d75d0d982 100644 --- a/homeassistant/components/automation/template.py +++ b/homeassistant/components/automation/template.py @@ -4,11 +4,11 @@ Offer template automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#template-trigger """ -import asyncio import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM from homeassistant.helpers import condition from homeassistant.helpers.event import async_track_state_change @@ -31,7 +31,7 @@ def async_trigger(hass, config, action): # Local variable to keep track of if the action has already been triggered already_triggered = False - @asyncio.coroutine + @callback def state_changed_listener(entity_id, from_s, to_s): """Listen for state changes and calls action.""" nonlocal already_triggered @@ -40,7 +40,7 @@ def async_trigger(hass, config, action): # Check to see if template returns true if template_result and not already_triggered: already_triggered = True - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'template', 'entity_id': entity_id, diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index 190a6519278..d0315f26de0 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -4,11 +4,11 @@ Offer time listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#time-trigger """ -import asyncio import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import CONF_AFTER, CONF_PLATFORM from homeassistant.helpers import config_validation as cv from homeassistant.helpers.event import async_track_time_change @@ -39,10 +39,10 @@ def async_trigger(hass, config, action): minutes = config.get(CONF_MINUTES) seconds = config.get(CONF_SECONDS) - @asyncio.coroutine + @callback def time_automation_listener(now): """Listen for time changes and calls action.""" - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'time', 'now': now, diff --git a/homeassistant/components/automation/zone.py b/homeassistant/components/automation/zone.py index 59812738692..935dc3cf24c 100644 --- a/homeassistant/components/automation/zone.py +++ b/homeassistant/components/automation/zone.py @@ -4,9 +4,9 @@ Offer zone automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#zone-trigger """ -import asyncio import voluptuous as vol +from homeassistant.core import callback from homeassistant.const import ( CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM) from homeassistant.helpers.event import async_track_state_change @@ -32,7 +32,7 @@ def async_trigger(hass, config, action): zone_entity_id = config.get(CONF_ZONE) event = config.get(CONF_EVENT) - @asyncio.coroutine + @callback def zone_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" if from_s and not location.has_location(from_s) or \ @@ -49,7 +49,7 @@ def async_trigger(hass, config, action): # pylint: disable=too-many-boolean-expressions if event == EVENT_ENTER and not from_match and to_match or \ event == EVENT_LEAVE and from_match and not to_match: - hass.async_add_job(action, { + hass.async_run_job(action, { 'trigger': { 'platform': 'zone', 'entity_id': entity, diff --git a/homeassistant/components/binary_sensor/template.py b/homeassistant/components/binary_sensor/template.py index 339a5cb9ba1..d179edfc1d8 100644 --- a/homeassistant/components/binary_sensor/template.py +++ b/homeassistant/components/binary_sensor/template.py @@ -9,6 +9,7 @@ import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.components.binary_sensor import ( BinarySensorDevice, ENTITY_ID_FORMAT, PLATFORM_SCHEMA, SENSOR_CLASSES_SCHEMA) @@ -82,7 +83,7 @@ class BinarySensorTemplate(BinarySensorDevice): self.update() - @asyncio.coroutine + @callback def template_bsensor_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" hass.loop.create_task(self.async_update_ha_state(True)) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 7995d9bf39a..3edd0ffc500 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -171,7 +171,7 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS): if not _match_topic(topic, event.data[ATTR_TOPIC]): return - hass.async_add_job(callback, event.data[ATTR_TOPIC], + hass.async_run_job(callback, event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, diff --git a/homeassistant/components/sensor/template.py b/homeassistant/components/sensor/template.py index ed905f44ebd..1abd1d2fd94 100644 --- a/homeassistant/components/sensor/template.py +++ b/homeassistant/components/sensor/template.py @@ -9,6 +9,7 @@ import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.components.sensor import ENTITY_ID_FORMAT, PLATFORM_SCHEMA from homeassistant.const import ( ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, @@ -79,7 +80,7 @@ class SensorTemplate(Entity): self.update() - @asyncio.coroutine + @callback def template_sensor_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" hass.loop.create_task(self.async_update_ha_state(True)) diff --git a/homeassistant/components/switch/template.py b/homeassistant/components/switch/template.py index bcd74454ce5..b6ce400d0ac 100644 --- a/homeassistant/components/switch/template.py +++ b/homeassistant/components/switch/template.py @@ -9,6 +9,7 @@ import logging import voluptuous as vol +from homeassistant.core import callback from homeassistant.components.switch import ( ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import ( @@ -88,7 +89,7 @@ class SwitchTemplate(SwitchDevice): self.update() - @asyncio.coroutine + @callback def template_switch_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" hass.loop.create_task(self.async_update_ha_state(True)) diff --git a/homeassistant/core.py b/homeassistant/core.py index 83056d6d7f2..0e0d1953992 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool: return ENTITY_ID_PATTERN.match(entity_id) is not None +def callback(func: Callable[..., None]) -> Callable[..., None]: + """Annotation to mark method as safe to call from within the event loop.""" + # pylint: disable=protected-access + func._hass_callback = True + return func + + +def is_callback(func: Callable[..., Any]) -> bool: + """Check if function is safe to be called in the event loop.""" + return '_hass_callback' in func.__dict__ + + class CoreState(enum.Enum): """Represent the current state of Home Assistant.""" @@ -224,11 +236,24 @@ class HomeAssistant(object): target: target to call. args: parameters for method to call. """ - if asyncio.iscoroutinefunction(target): + if is_callback(target): + self.loop.call_soon(target, *args) + elif asyncio.iscoroutinefunction(target): self.loop.create_task(target(*args)) else: self.add_job(target, *args) + def async_run_job(self, target: Callable[..., None], *args: Any): + """Run a job from within the event loop. + + target: target to call. + args: parameters for method to call. + """ + if is_callback(target): + target(*args) + else: + self.async_add_job(target, *args) + def _loop_empty(self): """Python 3.4.2 empty loop compatibility function.""" # pylint: disable=protected-access @@ -380,7 +405,6 @@ class EventBus(object): self._loop.call_soon_threadsafe(self.async_fire, event_type, event_data, origin) - return def async_fire(self, event_type: str, event_data=None, origin=EventOrigin.local, wait=False): @@ -408,6 +432,8 @@ class EventBus(object): for func in listeners: if asyncio.iscoroutinefunction(func): self._loop.create_task(func(event)) + elif is_callback(func): + self._loop.call_soon(func, event) else: sync_jobs.append((job_priority, (func, event))) @@ -795,7 +821,7 @@ class Service(object): """Represents a callable service.""" __slots__ = ['func', 'description', 'fields', 'schema', - 'iscoroutinefunction'] + 'is_callback', 'is_coroutinefunction'] def __init__(self, func, description, fields, schema): """Initialize a service.""" @@ -803,7 +829,8 @@ class Service(object): self.description = description or '' self.fields = fields or {} self.schema = schema - self.iscoroutinefunction = asyncio.iscoroutinefunction(func) + self.is_callback = is_callback(func) + self.is_coroutinefunction = asyncio.iscoroutinefunction(func) def as_dict(self): """Return dictionary representation of this service.""" @@ -934,7 +961,7 @@ class ServiceRegistry(object): self._loop ).result() - @asyncio.coroutine + @callback def async_call(self, domain, service, service_data=None, blocking=False): """ Call a service. @@ -966,7 +993,7 @@ class ServiceRegistry(object): if blocking: fut = asyncio.Future(loop=self._loop) - @asyncio.coroutine + @callback def service_executed(event): """Callback method that is called when service is executed.""" if event.data[ATTR_SERVICE_CALL_ID] == call_id: @@ -1007,7 +1034,8 @@ class ServiceRegistry(object): data = {ATTR_SERVICE_CALL_ID: call_id} - if service_handler.iscoroutinefunction: + if (service_handler.is_coroutinefunction or + service_handler.is_callback): self._bus.async_fire(EVENT_SERVICE_EXECUTED, data) else: self._bus.fire(EVENT_SERVICE_EXECUTED, data) @@ -1023,17 +1051,19 @@ class ServiceRegistry(object): service_call = ServiceCall(domain, service, service_data, call_id) - if not service_handler.iscoroutinefunction: + if service_handler.is_callback: + service_handler.func(service_call) + fire_service_executed() + elif service_handler.is_coroutinefunction: + yield from service_handler.func(service_call) + fire_service_executed() + else: def execute_service(): """Execute a service and fires a SERVICE_EXECUTED event.""" service_handler.func(service_call) fire_service_executed() self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE) - return - - yield from service_handler.func(service_call) - fire_service_executed() def _generate_unique_id(self): """Generate a unique service call id.""" @@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL): stop_event = asyncio.Event(loop=hass.loop) # Setting the Event inside the loop by marking it as a coroutine - @asyncio.coroutine + @callback def stop_timer(event): """Stop the timer.""" stop_event.set() @@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass): schedule() - @asyncio.coroutine + @callback def stop_monitor(event): """Stop the monitor.""" handle.cancel() diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 69f620adb82..390af3c7ad1 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -1,9 +1,8 @@ """Helpers for listening to events.""" -import asyncio import functools as ft from datetime import timedelta -from ..core import HomeAssistant +from ..core import HomeAssistant, callback from ..const import ( ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) from ..util import dt as dt_util @@ -57,8 +56,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, else: entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) - @ft.wraps(action) - @asyncio.coroutine + @callback def state_change_listener(event): """The listener that listens for specific state changes.""" if entity_ids != MATCH_ALL and \ @@ -76,7 +74,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, new_state = None if _matcher(old_state, from_state) and _matcher(new_state, to_state): - hass.async_add_job(action, event.data.get('entity_id'), + hass.async_run_job(action, event.data.get('entity_id'), event.data.get('old_state'), event.data.get('new_state')) @@ -90,11 +88,10 @@ def async_track_point_in_time(hass, action, point_in_time): """Add a listener that fires once after a spefic point in time.""" utc_point_in_time = dt_util.as_utc(point_in_time) - @ft.wraps(action) - @asyncio.coroutine + @callback def utc_converter(utc_now): """Convert passed in UTC now to local now.""" - hass.async_add_job(action, dt_util.as_local(utc_now)) + hass.async_run_job(action, dt_util.as_local(utc_now)) return async_track_point_in_utc_time(hass, utc_converter, utc_point_in_time) @@ -108,8 +105,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time): # Ensure point_in_time is UTC point_in_time = dt_util.as_utc(point_in_time) - @ft.wraps(action) - @asyncio.coroutine + @callback def point_in_time_listener(event): """Listen for matching time_changed events.""" now = event.data[ATTR_NOW] @@ -125,7 +121,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time): point_in_time_listener.run = True async_unsub() - hass.async_add_job(action, now) + hass.async_run_job(action, now) async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, point_in_time_listener) @@ -151,14 +147,13 @@ def async_track_sunrise(hass, action, offset=None): return next_time - @ft.wraps(action) - @asyncio.coroutine + @callback def sunrise_automation_listener(now): """Called when it's time for action.""" nonlocal remove remove = async_track_point_in_utc_time( hass, sunrise_automation_listener, next_rise()) - hass.async_add_job(action) + hass.async_run_job(action) remove = async_track_point_in_utc_time( hass, sunrise_automation_listener, next_rise()) @@ -187,14 +182,13 @@ def async_track_sunset(hass, action, offset=None): return next_time - @ft.wraps(action) - @asyncio.coroutine + @callback def sunset_automation_listener(now): """Called when it's time for action.""" nonlocal remove remove = async_track_point_in_utc_time( hass, sunset_automation_listener, next_set()) - hass.async_add_job(action) + hass.async_run_job(action) remove = async_track_point_in_utc_time( hass, sunset_automation_listener, next_set()) @@ -217,10 +211,10 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None, # We do not have to wrap the function with time pattern matching logic # if no pattern given if all(val is None for val in (year, month, day, hour, minute, second)): - @ft.wraps(action) + @callback def time_change_listener(event): """Fire every time event that comes in.""" - action(event.data[ATTR_NOW]) + hass.async_run_job(action, event.data[ATTR_NOW]) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) @@ -228,8 +222,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None, year, month, day = pmp(year), pmp(month), pmp(day) hour, minute, second = pmp(hour), pmp(minute), pmp(second) - @ft.wraps(action) - @asyncio.coroutine + @callback def pattern_time_change_listener(event): """Listen for matching time_changed events.""" now = event.data[ATTR_NOW] @@ -246,7 +239,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None, mat(now.minute, minute) and \ mat(now.second, second): - hass.async_add_job(action, now) + hass.async_run_job(action, now) return hass.bus.async_listen(EVENT_TIME_CHANGED, pattern_time_change_listener) diff --git a/tests/common.py b/tests/common.py index b44cbee4b6f..fb5dab7004b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -111,7 +111,7 @@ def mock_service(hass, domain, service): """ calls = [] - hass.services.register(domain, service, calls.append) + hass.services.register(domain, service, lambda call: calls.append(call)) return calls diff --git a/tests/components/test_api.py b/tests/components/test_api.py index 4e7d98cd6cc..dee4320824b 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -139,7 +139,8 @@ class TestAPI(unittest.TestCase): hass.states.set("test.test", "not_to_be_set") events = [] - hass.bus.listen(const.EVENT_STATE_CHANGED, events.append) + hass.bus.listen(const.EVENT_STATE_CHANGED, + lambda ev: events.append(ev)) requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), data=json.dumps({"state": "not_to_be_set"}), diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 4993ce92b3a..89c97434f8d 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -1,6 +1,7 @@ """Test event helpers.""" # pylint: disable=protected-access,too-many-public-methods # pylint: disable=too-few-public-methods +import asyncio import unittest from datetime import datetime, timedelta @@ -113,17 +114,23 @@ class TestEventHelpers(unittest.TestCase): wildcard_runs = [] wildercard_runs = [] - track_state_change( - self.hass, 'light.Bowl', lambda a, b, c: specific_runs.append(1), - 'on', 'off') + def specific_run_callback(entity_id, old_state, new_state): + specific_runs.append(1) track_state_change( - self.hass, 'light.Bowl', - lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s))) + self.hass, 'light.Bowl', specific_run_callback, 'on', 'off') - track_state_change( - self.hass, MATCH_ALL, - lambda _, old_s, new_s: wildercard_runs.append((old_s, new_s))) + @ha.callback + def wildcard_run_callback(entity_id, old_state, new_state): + wildcard_runs.append((old_state, new_state)) + + track_state_change(self.hass, 'light.Bowl', wildcard_run_callback) + + @asyncio.coroutine + def wildercard_run_callback(entity_id, old_state, new_state): + wildercard_runs.append((old_state, new_state)) + + track_state_change(self.hass, MATCH_ALL, wildercard_run_callback) # Adding state to state machine self.hass.states.set("light.Bowl", "on") diff --git a/tests/test_core.py b/tests/test_core.py index 80a8c6d4c5f..6f480baa71b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -23,13 +23,70 @@ from tests.common import get_test_home_assistant PST = pytz.timezone('America/Los_Angeles') -class TestMethods(unittest.TestCase): - """Test the Home Assistant helper methods.""" +def test_split_entity_id(): + """Test split_entity_id.""" + assert ha.split_entity_id('domain.object_id') == ['domain', 'object_id'] - def test_split_entity_id(self): - """Test split_entity_id.""" - self.assertEqual(['domain', 'object_id'], - ha.split_entity_id('domain.object_id')) + +def test_async_add_job_schedule_callback(): + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock() + job = MagicMock() + + ha.HomeAssistant.async_add_job(hass, ha.callback(job)) + assert len(hass.loop.call_soon.mock_calls) == 1 + assert len(hass.loop.create_task.mock_calls) == 0 + assert len(hass.add_job.mock_calls) == 0 + + +@patch('asyncio.iscoroutinefunction', return_value=True) +def test_async_add_job_schedule_coroutinefunction(mock_iscoro): + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock() + job = MagicMock() + + ha.HomeAssistant.async_add_job(hass, job) + assert len(hass.loop.call_soon.mock_calls) == 0 + assert len(hass.loop.create_task.mock_calls) == 1 + assert len(hass.add_job.mock_calls) == 0 + + +@patch('asyncio.iscoroutinefunction', return_value=False) +def test_async_add_job_add_threaded_job_to_pool(mock_iscoro): + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock() + job = MagicMock() + + ha.HomeAssistant.async_add_job(hass, job) + assert len(hass.loop.call_soon.mock_calls) == 0 + assert len(hass.loop.create_task.mock_calls) == 0 + assert len(hass.add_job.mock_calls) == 1 + + +def test_async_run_job_calls_callback(): + """Test that the callback annotation is respected.""" + hass = MagicMock() + calls = [] + + def job(): + calls.append(1) + + ha.HomeAssistant.async_run_job(hass, ha.callback(job)) + assert len(calls) == 1 + assert len(hass.async_add_job.mock_calls) == 0 + + +def test_async_run_job_delegates_non_async(): + """Test that the callback annotation is respected.""" + hass = MagicMock() + calls = [] + + def job(): + calls.append(1) + + ha.HomeAssistant.async_run_job(hass, job) + assert len(calls) == 0 + assert len(hass.async_add_job.mock_calls) == 1 class TestHomeAssistant(unittest.TestCase): @@ -173,6 +230,44 @@ class TestEventBus(unittest.TestCase): self.hass.block_till_done() self.assertEqual(1, len(runs)) + def test_thread_event_listener(self): + """Test a event listener listeners.""" + thread_calls = [] + + def thread_listener(event): + thread_calls.append(event) + + self.bus.listen('test_thread', thread_listener) + self.bus.fire('test_thread') + self.hass.block_till_done() + assert len(thread_calls) == 1 + + def test_callback_event_listener(self): + """Test a event listener listeners.""" + callback_calls = [] + + @ha.callback + def callback_listener(event): + callback_calls.append(event) + + self.bus.listen('test_callback', callback_listener) + self.bus.fire('test_callback') + self.hass.block_till_done() + assert len(callback_calls) == 1 + + def test_coroutine_event_listener(self): + """Test a event listener listeners.""" + coroutine_calls = [] + + @asyncio.coroutine + def coroutine_listener(event): + coroutine_calls.append(event) + + self.bus.listen('test_coroutine', coroutine_listener) + self.bus.fire('test_coroutine') + self.hass.block_till_done() + assert len(coroutine_calls) == 1 + class TestState(unittest.TestCase): """Test State methods.""" @@ -330,7 +425,7 @@ class TestStateMachine(unittest.TestCase): def test_force_update(self): """Test force update option.""" events = [] - self.hass.bus.listen(EVENT_STATE_CHANGED, events.append) + self.hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev)) self.states.set('light.bowl', 'on') self.hass.block_till_done() @@ -425,6 +520,22 @@ class TestServiceRegistry(unittest.TestCase): self.hass.block_till_done() self.assertEqual(1, len(calls)) + def test_callback_service(self): + """Test registering and calling an async service.""" + calls = [] + + @ha.callback + def service_handler(call): + """Service handler coroutine.""" + calls.append(call) + + self.services.register('test_domain', 'register_calls', + service_handler) + self.assertTrue( + self.services.call('test_domain', 'REGISTER_CALLS', blocking=True)) + self.hass.block_till_done() + self.assertEqual(1, len(calls)) + class TestConfig(unittest.TestCase): """Test configuration methods.""" @@ -524,8 +635,7 @@ class TestWorkerPoolMonitor(object): check_threshold() assert mock_warning.called - event_loop.run_until_complete( - hass.bus.async_listen_once.mock_calls[0][1][1](None)) + hass.bus.async_listen_once.mock_calls[0][1][1](None) assert schedule_handle.cancel.called @@ -561,5 +671,5 @@ class TestAsyncCreateTimer(object): assert {ha.ATTR_NOW: now} == event_data stop_timer = hass.bus.async_listen_once.mock_calls[0][1][1] - event_loop.run_until_complete(stop_timer(None)) + stop_timer(None) assert event.set.called diff --git a/tests/test_remote.py b/tests/test_remote.py index a5212face2c..653971f8bc1 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -163,7 +163,7 @@ class TestRemoteMethods(unittest.TestCase): def test_set_state_with_push(self): """Test Python API set_state with push option.""" events = [] - hass.bus.listen(EVENT_STATE_CHANGED, events.append) + hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev)) remote.set_state(master_api, 'test.test', 'set_test_2') remote.set_state(master_api, 'test.test', 'set_test_2')