Add async_safe annotation (#3688)

* Add async_safe annotation

* More async_run_job

* coroutine -> async_save

* Lint

* Rename async_safe -> callback

* Add tests to core for different job types

* Add one more test with different type of callbacks

* Fix typing signature for callback methods

* Fix callback service executed method

* Fix method signatures for callback
This commit is contained in:
Paulus Schoutsen 2016-10-04 20:44:32 -07:00 committed by GitHub
parent be7401f4a2
commit 5085cdb0f7
19 changed files with 231 additions and 87 deletions

View File

@ -4,11 +4,11 @@ Offer event listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#event-trigger
"""
import asyncio
import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_PLATFORM
from homeassistant.helpers import config_validation as cv
@ -29,12 +29,12 @@ def async_trigger(hass, config, action):
event_type = config.get(CONF_EVENT_TYPE)
event_data = config.get(CONF_EVENT_DATA)
@asyncio.coroutine
@callback
def handle_event(event):
"""Listen for events and calls the action when data matches."""
if not event_data or all(val == event.data.get(key) for key, val
in event_data.items()):
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'event',
'event': event,

View File

@ -4,9 +4,9 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.core import callback
import homeassistant.components.mqtt as mqtt
from homeassistant.const import (CONF_PLATFORM, CONF_PAYLOAD)
import homeassistant.helpers.config_validation as cv
@ -27,11 +27,11 @@ def async_trigger(hass, config, action):
topic = config.get(CONF_TOPIC)
payload = config.get(CONF_PAYLOAD)
@asyncio.coroutine
@callback
def mqtt_automation_listener(msg_topic, msg_payload, qos):
"""Listen for MQTT messages."""
if payload is None or payload == msg_payload:
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'mqtt',
'topic': msg_topic,

View File

@ -4,11 +4,11 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger
"""
import asyncio
import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import (
CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID,
CONF_BELOW, CONF_ABOVE)
@ -35,7 +35,7 @@ def async_trigger(hass, config, action):
if value_template is not None:
value_template.hass = hass
@asyncio.coroutine
@callback
def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
if to_s is None:
@ -64,6 +64,6 @@ def async_trigger(hass, config, action):
variables['trigger']['from_state'] = from_s
variables['trigger']['to_state'] = to_s
hass.async_add_job(action, variables)
hass.async_run_job(action, variables)
return async_track_state_change(hass, entity_id, state_automation_listener)

View File

@ -4,9 +4,9 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.core import callback
import homeassistant.util.dt as dt_util
from homeassistant.const import MATCH_ALL, CONF_PLATFORM
from homeassistant.helpers.event import (
@ -43,14 +43,14 @@ def async_trigger(hass, config, action):
async_remove_state_for_cancel = None
async_remove_state_for_listener = None
@asyncio.coroutine
@callback
def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
def call_action():
"""Call action with right context."""
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'state',
'entity_id': entity,
@ -64,13 +64,13 @@ def async_trigger(hass, config, action):
call_action()
return
@asyncio.coroutine
@callback
def state_for_listener(now):
"""Fire on state changes after a delay and calls action."""
async_remove_state_for_cancel()
call_action()
@asyncio.coroutine
@callback
def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
"""Fire on changes and cancel for listener if changed."""
if inner_to_s.state == to_s.state:

View File

@ -4,12 +4,12 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger
"""
import asyncio
from datetime import timedelta
import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import (
CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE)
from homeassistant.helpers.event import async_track_sunrise, async_track_sunset
@ -31,10 +31,10 @@ def async_trigger(hass, config, action):
event = config.get(CONF_EVENT)
offset = config.get(CONF_OFFSET)
@asyncio.coroutine
@callback
def call_action():
"""Call action with right context."""
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'sun',
'event': event,

View File

@ -4,11 +4,11 @@ Offer template automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#template-trigger
"""
import asyncio
import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM
from homeassistant.helpers import condition
from homeassistant.helpers.event import async_track_state_change
@ -31,7 +31,7 @@ def async_trigger(hass, config, action):
# Local variable to keep track of if the action has already been triggered
already_triggered = False
@asyncio.coroutine
@callback
def state_changed_listener(entity_id, from_s, to_s):
"""Listen for state changes and calls action."""
nonlocal already_triggered
@ -40,7 +40,7 @@ def async_trigger(hass, config, action):
# Check to see if template returns true
if template_result and not already_triggered:
already_triggered = True
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'template',
'entity_id': entity_id,

View File

@ -4,11 +4,11 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger
"""
import asyncio
import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_AFTER, CONF_PLATFORM
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.event import async_track_time_change
@ -39,10 +39,10 @@ def async_trigger(hass, config, action):
minutes = config.get(CONF_MINUTES)
seconds = config.get(CONF_SECONDS)
@asyncio.coroutine
@callback
def time_automation_listener(now):
"""Listen for time changes and calls action."""
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'time',
'now': now,

View File

@ -4,9 +4,9 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import (
CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM)
from homeassistant.helpers.event import async_track_state_change
@ -32,7 +32,7 @@ def async_trigger(hass, config, action):
zone_entity_id = config.get(CONF_ZONE)
event = config.get(CONF_EVENT)
@asyncio.coroutine
@callback
def zone_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
if from_s and not location.has_location(from_s) or \
@ -49,7 +49,7 @@ def async_trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match:
hass.async_add_job(action, {
hass.async_run_job(action, {
'trigger': {
'platform': 'zone',
'entity_id': entity,

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.binary_sensor import (
BinarySensorDevice, ENTITY_ID_FORMAT, PLATFORM_SCHEMA,
SENSOR_CLASSES_SCHEMA)
@ -82,7 +83,7 @@ class BinarySensorTemplate(BinarySensorDevice):
self.update()
@asyncio.coroutine
@callback
def template_bsensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -171,7 +171,7 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
if not _match_topic(topic, event.data[ATTR_TOPIC]):
return
hass.async_add_job(callback, event.data[ATTR_TOPIC],
hass.async_run_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.sensor import ENTITY_ID_FORMAT, PLATFORM_SCHEMA
from homeassistant.const import (
ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE,
@ -79,7 +80,7 @@ class SensorTemplate(Entity):
self.update()
@asyncio.coroutine
@callback
def template_sensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.switch import (
ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA)
from homeassistant.const import (
@ -88,7 +89,7 @@ class SwitchTemplate(SwitchDevice):
self.update()
@asyncio.coroutine
@callback
def template_switch_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool:
return ENTITY_ID_PATTERN.match(entity_id) is not None
def callback(func: Callable[..., None]) -> Callable[..., None]:
"""Annotation to mark method as safe to call from within the event loop."""
# pylint: disable=protected-access
func._hass_callback = True
return func
def is_callback(func: Callable[..., Any]) -> bool:
"""Check if function is safe to be called in the event loop."""
return '_hass_callback' in func.__dict__
class CoreState(enum.Enum):
"""Represent the current state of Home Assistant."""
@ -224,11 +236,24 @@ class HomeAssistant(object):
target: target to call.
args: parameters for method to call.
"""
if asyncio.iscoroutinefunction(target):
if is_callback(target):
self.loop.call_soon(target, *args)
elif asyncio.iscoroutinefunction(target):
self.loop.create_task(target(*args))
else:
self.add_job(target, *args)
def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop.
target: target to call.
args: parameters for method to call.
"""
if is_callback(target):
target(*args)
else:
self.async_add_job(target, *args)
def _loop_empty(self):
"""Python 3.4.2 empty loop compatibility function."""
# pylint: disable=protected-access
@ -380,7 +405,6 @@ class EventBus(object):
self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
return
def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False):
@ -408,6 +432,8 @@ class EventBus(object):
for func in listeners:
if asyncio.iscoroutinefunction(func):
self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else:
sync_jobs.append((job_priority, (func, event)))
@ -795,7 +821,7 @@ class Service(object):
"""Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction']
'is_callback', 'is_coroutinefunction']
def __init__(self, func, description, fields, schema):
"""Initialize a service."""
@ -803,7 +829,8 @@ class Service(object):
self.description = description or ''
self.fields = fields or {}
self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
self.is_callback = is_callback(func)
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self):
"""Return dictionary representation of this service."""
@ -934,7 +961,7 @@ class ServiceRegistry(object):
self._loop
).result()
@asyncio.coroutine
@callback
def async_call(self, domain, service, service_data=None, blocking=False):
"""
Call a service.
@ -966,7 +993,7 @@ class ServiceRegistry(object):
if blocking:
fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine
@callback
def service_executed(event):
"""Callback method that is called when service is executed."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
@ -1007,7 +1034,8 @@ class ServiceRegistry(object):
data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction:
if (service_handler.is_coroutinefunction or
service_handler.is_callback):
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
@ -1023,17 +1051,19 @@ class ServiceRegistry(object):
service_call = ServiceCall(domain, service, service_data, call_id)
if not service_handler.iscoroutinefunction:
if service_handler.is_callback:
service_handler.func(service_call)
fire_service_executed()
elif service_handler.is_coroutinefunction:
yield from service_handler.func(service_call)
fire_service_executed()
else:
def execute_service():
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
return
yield from service_handler.func(service_call)
fire_service_executed()
def _generate_unique_id(self):
"""Generate a unique service call id."""
@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL):
stop_event = asyncio.Event(loop=hass.loop)
# Setting the Event inside the loop by marking it as a coroutine
@asyncio.coroutine
@callback
def stop_timer(event):
"""Stop the timer."""
stop_event.set()
@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass):
schedule()
@asyncio.coroutine
@callback
def stop_monitor(event):
"""Stop the monitor."""
handle.cancel()

View File

@ -1,9 +1,8 @@
"""Helpers for listening to events."""
import asyncio
import functools as ft
from datetime import timedelta
from ..core import HomeAssistant
from ..core import HomeAssistant, callback
from ..const import (
ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from ..util import dt as dt_util
@ -57,8 +56,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
else:
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
@ft.wraps(action)
@asyncio.coroutine
@callback
def state_change_listener(event):
"""The listener that listens for specific state changes."""
if entity_ids != MATCH_ALL and \
@ -76,7 +74,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
new_state = None
if _matcher(old_state, from_state) and _matcher(new_state, to_state):
hass.async_add_job(action, event.data.get('entity_id'),
hass.async_run_job(action, event.data.get('entity_id'),
event.data.get('old_state'),
event.data.get('new_state'))
@ -90,11 +88,10 @@ def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a spefic point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time)
@ft.wraps(action)
@asyncio.coroutine
@callback
def utc_converter(utc_now):
"""Convert passed in UTC now to local now."""
hass.async_add_job(action, dt_util.as_local(utc_now))
hass.async_run_job(action, dt_util.as_local(utc_now))
return async_track_point_in_utc_time(hass, utc_converter,
utc_point_in_time)
@ -108,8 +105,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
# Ensure point_in_time is UTC
point_in_time = dt_util.as_utc(point_in_time)
@ft.wraps(action)
@asyncio.coroutine
@callback
def point_in_time_listener(event):
"""Listen for matching time_changed events."""
now = event.data[ATTR_NOW]
@ -125,7 +121,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
point_in_time_listener.run = True
async_unsub()
hass.async_add_job(action, now)
hass.async_run_job(action, now)
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
point_in_time_listener)
@ -151,14 +147,13 @@ def async_track_sunrise(hass, action, offset=None):
return next_time
@ft.wraps(action)
@asyncio.coroutine
@callback
def sunrise_automation_listener(now):
"""Called when it's time for action."""
nonlocal remove
remove = async_track_point_in_utc_time(
hass, sunrise_automation_listener, next_rise())
hass.async_add_job(action)
hass.async_run_job(action)
remove = async_track_point_in_utc_time(
hass, sunrise_automation_listener, next_rise())
@ -187,14 +182,13 @@ def async_track_sunset(hass, action, offset=None):
return next_time
@ft.wraps(action)
@asyncio.coroutine
@callback
def sunset_automation_listener(now):
"""Called when it's time for action."""
nonlocal remove
remove = async_track_point_in_utc_time(
hass, sunset_automation_listener, next_set())
hass.async_add_job(action)
hass.async_run_job(action)
remove = async_track_point_in_utc_time(
hass, sunset_automation_listener, next_set())
@ -217,10 +211,10 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
# We do not have to wrap the function with time pattern matching logic
# if no pattern given
if all(val is None for val in (year, month, day, hour, minute, second)):
@ft.wraps(action)
@callback
def time_change_listener(event):
"""Fire every time event that comes in."""
action(event.data[ATTR_NOW])
hass.async_run_job(action, event.data[ATTR_NOW])
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
@ -228,8 +222,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
year, month, day = pmp(year), pmp(month), pmp(day)
hour, minute, second = pmp(hour), pmp(minute), pmp(second)
@ft.wraps(action)
@asyncio.coroutine
@callback
def pattern_time_change_listener(event):
"""Listen for matching time_changed events."""
now = event.data[ATTR_NOW]
@ -246,7 +239,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
mat(now.minute, minute) and \
mat(now.second, second):
hass.async_add_job(action, now)
hass.async_run_job(action, now)
return hass.bus.async_listen(EVENT_TIME_CHANGED,
pattern_time_change_listener)

View File

@ -111,7 +111,7 @@ def mock_service(hass, domain, service):
"""
calls = []
hass.services.register(domain, service, calls.append)
hass.services.register(domain, service, lambda call: calls.append(call))
return calls

View File

@ -139,7 +139,8 @@ class TestAPI(unittest.TestCase):
hass.states.set("test.test", "not_to_be_set")
events = []
hass.bus.listen(const.EVENT_STATE_CHANGED, events.append)
hass.bus.listen(const.EVENT_STATE_CHANGED,
lambda ev: events.append(ev))
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set"}),

View File

@ -1,6 +1,7 @@
"""Test event helpers."""
# pylint: disable=protected-access,too-many-public-methods
# pylint: disable=too-few-public-methods
import asyncio
import unittest
from datetime import datetime, timedelta
@ -113,17 +114,23 @@ class TestEventHelpers(unittest.TestCase):
wildcard_runs = []
wildercard_runs = []
track_state_change(
self.hass, 'light.Bowl', lambda a, b, c: specific_runs.append(1),
'on', 'off')
def specific_run_callback(entity_id, old_state, new_state):
specific_runs.append(1)
track_state_change(
self.hass, 'light.Bowl',
lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s)))
self.hass, 'light.Bowl', specific_run_callback, 'on', 'off')
track_state_change(
self.hass, MATCH_ALL,
lambda _, old_s, new_s: wildercard_runs.append((old_s, new_s)))
@ha.callback
def wildcard_run_callback(entity_id, old_state, new_state):
wildcard_runs.append((old_state, new_state))
track_state_change(self.hass, 'light.Bowl', wildcard_run_callback)
@asyncio.coroutine
def wildercard_run_callback(entity_id, old_state, new_state):
wildercard_runs.append((old_state, new_state))
track_state_change(self.hass, MATCH_ALL, wildercard_run_callback)
# Adding state to state machine
self.hass.states.set("light.Bowl", "on")

View File

@ -23,13 +23,70 @@ from tests.common import get_test_home_assistant
PST = pytz.timezone('America/Los_Angeles')
class TestMethods(unittest.TestCase):
"""Test the Home Assistant helper methods."""
def test_split_entity_id():
"""Test split_entity_id."""
assert ha.split_entity_id('domain.object_id') == ['domain', 'object_id']
def test_split_entity_id(self):
"""Test split_entity_id."""
self.assertEqual(['domain', 'object_id'],
ha.split_entity_id('domain.object_id'))
def test_async_add_job_schedule_callback():
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, ha.callback(job))
assert len(hass.loop.call_soon.mock_calls) == 1
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
@patch('asyncio.iscoroutinefunction', return_value=True)
def test_async_add_job_schedule_coroutinefunction(mock_iscoro):
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, job)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
@patch('asyncio.iscoroutinefunction', return_value=False)
def test_async_add_job_add_threaded_job_to_pool(mock_iscoro):
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, job)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 1
def test_async_run_job_calls_callback():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, ha.callback(job))
assert len(calls) == 1
assert len(hass.async_add_job.mock_calls) == 0
def test_async_run_job_delegates_non_async():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, job)
assert len(calls) == 0
assert len(hass.async_add_job.mock_calls) == 1
class TestHomeAssistant(unittest.TestCase):
@ -173,6 +230,44 @@ class TestEventBus(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(1, len(runs))
def test_thread_event_listener(self):
"""Test a event listener listeners."""
thread_calls = []
def thread_listener(event):
thread_calls.append(event)
self.bus.listen('test_thread', thread_listener)
self.bus.fire('test_thread')
self.hass.block_till_done()
assert len(thread_calls) == 1
def test_callback_event_listener(self):
"""Test a event listener listeners."""
callback_calls = []
@ha.callback
def callback_listener(event):
callback_calls.append(event)
self.bus.listen('test_callback', callback_listener)
self.bus.fire('test_callback')
self.hass.block_till_done()
assert len(callback_calls) == 1
def test_coroutine_event_listener(self):
"""Test a event listener listeners."""
coroutine_calls = []
@asyncio.coroutine
def coroutine_listener(event):
coroutine_calls.append(event)
self.bus.listen('test_coroutine', coroutine_listener)
self.bus.fire('test_coroutine')
self.hass.block_till_done()
assert len(coroutine_calls) == 1
class TestState(unittest.TestCase):
"""Test State methods."""
@ -330,7 +425,7 @@ class TestStateMachine(unittest.TestCase):
def test_force_update(self):
"""Test force update option."""
events = []
self.hass.bus.listen(EVENT_STATE_CHANGED, events.append)
self.hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
self.states.set('light.bowl', 'on')
self.hass.block_till_done()
@ -425,6 +520,22 @@ class TestServiceRegistry(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(1, len(calls))
def test_callback_service(self):
"""Test registering and calling an async service."""
calls = []
@ha.callback
def service_handler(call):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
self.assertEqual(1, len(calls))
class TestConfig(unittest.TestCase):
"""Test configuration methods."""
@ -524,8 +635,7 @@ class TestWorkerPoolMonitor(object):
check_threshold()
assert mock_warning.called
event_loop.run_until_complete(
hass.bus.async_listen_once.mock_calls[0][1][1](None))
hass.bus.async_listen_once.mock_calls[0][1][1](None)
assert schedule_handle.cancel.called
@ -561,5 +671,5 @@ class TestAsyncCreateTimer(object):
assert {ha.ATTR_NOW: now} == event_data
stop_timer = hass.bus.async_listen_once.mock_calls[0][1][1]
event_loop.run_until_complete(stop_timer(None))
stop_timer(None)
assert event.set.called

View File

@ -163,7 +163,7 @@ class TestRemoteMethods(unittest.TestCase):
def test_set_state_with_push(self):
"""Test Python API set_state with push option."""
events = []
hass.bus.listen(EVENT_STATE_CHANGED, events.append)
hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
remote.set_state(master_api, 'test.test', 'set_test_2')
remote.set_state(master_api, 'test.test', 'set_test_2')