Add async_safe annotation (#3688)

* Add async_safe annotation

* More async_run_job

* coroutine -> async_save

* Lint

* Rename async_safe -> callback

* Add tests to core for different job types

* Add one more test with different type of callbacks

* Fix typing signature for callback methods

* Fix callback service executed method

* Fix method signatures for callback
This commit is contained in:
Paulus Schoutsen 2016-10-04 20:44:32 -07:00 committed by GitHub
parent be7401f4a2
commit 5085cdb0f7
19 changed files with 231 additions and 87 deletions

View File

@ -4,11 +4,11 @@ Offer event listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#event-trigger at https://home-assistant.io/components/automation/#event-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_PLATFORM
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
@ -29,12 +29,12 @@ def async_trigger(hass, config, action):
event_type = config.get(CONF_EVENT_TYPE) event_type = config.get(CONF_EVENT_TYPE)
event_data = config.get(CONF_EVENT_DATA) event_data = config.get(CONF_EVENT_DATA)
@asyncio.coroutine @callback
def handle_event(event): def handle_event(event):
"""Listen for events and calls the action when data matches.""" """Listen for events and calls the action when data matches."""
if not event_data or all(val == event.data.get(key) for key, val if not event_data or all(val == event.data.get(key) for key, val
in event_data.items()): in event_data.items()):
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
'event': event, 'event': event,

View File

@ -4,9 +4,9 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger at https://home-assistant.io/components/automation/#mqtt-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from homeassistant.const import (CONF_PLATFORM, CONF_PAYLOAD) from homeassistant.const import (CONF_PLATFORM, CONF_PAYLOAD)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -27,11 +27,11 @@ def async_trigger(hass, config, action):
topic = config.get(CONF_TOPIC) topic = config.get(CONF_TOPIC)
payload = config.get(CONF_PAYLOAD) payload = config.get(CONF_PAYLOAD)
@asyncio.coroutine @callback
def mqtt_automation_listener(msg_topic, msg_payload, qos): def mqtt_automation_listener(msg_topic, msg_payload, qos):
"""Listen for MQTT messages.""" """Listen for MQTT messages."""
if payload is None or payload == msg_payload: if payload is None or payload == msg_payload:
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'mqtt', 'platform': 'mqtt',
'topic': msg_topic, 'topic': msg_topic,

View File

@ -4,11 +4,11 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger at https://home-assistant.io/components/automation/#numeric-state-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import ( from homeassistant.const import (
CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID, CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID,
CONF_BELOW, CONF_ABOVE) CONF_BELOW, CONF_ABOVE)
@ -35,7 +35,7 @@ def async_trigger(hass, config, action):
if value_template is not None: if value_template is not None:
value_template.hass = hass value_template.hass = hass
@asyncio.coroutine @callback
def state_automation_listener(entity, from_s, to_s): def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
if to_s is None: if to_s is None:
@ -64,6 +64,6 @@ def async_trigger(hass, config, action):
variables['trigger']['from_state'] = from_s variables['trigger']['from_state'] = from_s
variables['trigger']['to_state'] = to_s variables['trigger']['to_state'] = to_s
hass.async_add_job(action, variables) hass.async_run_job(action, variables)
return async_track_state_change(hass, entity_id, state_automation_listener) return async_track_state_change(hass, entity_id, state_automation_listener)

View File

@ -4,9 +4,9 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger at https://home-assistant.io/components/automation/#state-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.const import MATCH_ALL, CONF_PLATFORM from homeassistant.const import MATCH_ALL, CONF_PLATFORM
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
@ -43,14 +43,14 @@ def async_trigger(hass, config, action):
async_remove_state_for_cancel = None async_remove_state_for_cancel = None
async_remove_state_for_listener = None async_remove_state_for_listener = None
@asyncio.coroutine @callback
def state_automation_listener(entity, from_s, to_s): def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'state', 'platform': 'state',
'entity_id': entity, 'entity_id': entity,
@ -64,13 +64,13 @@ def async_trigger(hass, config, action):
call_action() call_action()
return return
@asyncio.coroutine @callback
def state_for_listener(now): def state_for_listener(now):
"""Fire on state changes after a delay and calls action.""" """Fire on state changes after a delay and calls action."""
async_remove_state_for_cancel() async_remove_state_for_cancel()
call_action() call_action()
@asyncio.coroutine @callback
def state_for_cancel_listener(entity, inner_from_s, inner_to_s): def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
"""Fire on changes and cancel for listener if changed.""" """Fire on changes and cancel for listener if changed."""
if inner_to_s.state == to_s.state: if inner_to_s.state == to_s.state:

View File

@ -4,12 +4,12 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger at https://home-assistant.io/components/automation/#sun-trigger
""" """
import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import ( from homeassistant.const import (
CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE) CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE)
from homeassistant.helpers.event import async_track_sunrise, async_track_sunset from homeassistant.helpers.event import async_track_sunrise, async_track_sunset
@ -31,10 +31,10 @@ def async_trigger(hass, config, action):
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)
offset = config.get(CONF_OFFSET) offset = config.get(CONF_OFFSET)
@asyncio.coroutine @callback
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'sun', 'platform': 'sun',
'event': event, 'event': event,

View File

@ -4,11 +4,11 @@ Offer template automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#template-trigger at https://home-assistant.io/components/automation/#template-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM
from homeassistant.helpers import condition from homeassistant.helpers import condition
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
@ -31,7 +31,7 @@ def async_trigger(hass, config, action):
# Local variable to keep track of if the action has already been triggered # Local variable to keep track of if the action has already been triggered
already_triggered = False already_triggered = False
@asyncio.coroutine @callback
def state_changed_listener(entity_id, from_s, to_s): def state_changed_listener(entity_id, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
nonlocal already_triggered nonlocal already_triggered
@ -40,7 +40,7 @@ def async_trigger(hass, config, action):
# Check to see if template returns true # Check to see if template returns true
if template_result and not already_triggered: if template_result and not already_triggered:
already_triggered = True already_triggered = True
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'template', 'platform': 'template',
'entity_id': entity_id, 'entity_id': entity_id,

View File

@ -4,11 +4,11 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger at https://home-assistant.io/components/automation/#time-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import CONF_AFTER, CONF_PLATFORM from homeassistant.const import CONF_AFTER, CONF_PLATFORM
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.event import async_track_time_change from homeassistant.helpers.event import async_track_time_change
@ -39,10 +39,10 @@ def async_trigger(hass, config, action):
minutes = config.get(CONF_MINUTES) minutes = config.get(CONF_MINUTES)
seconds = config.get(CONF_SECONDS) seconds = config.get(CONF_SECONDS)
@asyncio.coroutine @callback
def time_automation_listener(now): def time_automation_listener(now):
"""Listen for time changes and calls action.""" """Listen for time changes and calls action."""
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'time', 'platform': 'time',
'now': now, 'now': now,

View File

@ -4,9 +4,9 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger at https://home-assistant.io/components/automation/#zone-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import ( from homeassistant.const import (
CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM) CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM)
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
@ -32,7 +32,7 @@ def async_trigger(hass, config, action):
zone_entity_id = config.get(CONF_ZONE) zone_entity_id = config.get(CONF_ZONE)
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)
@asyncio.coroutine @callback
def zone_automation_listener(entity, from_s, to_s): def zone_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
if from_s and not location.has_location(from_s) or \ if from_s and not location.has_location(from_s) or \
@ -49,7 +49,7 @@ def async_trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions # pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \ if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match: event == EVENT_LEAVE and from_match and not to_match:
hass.async_add_job(action, { hass.async_run_job(action, {
'trigger': { 'trigger': {
'platform': 'zone', 'platform': 'zone',
'entity_id': entity, 'entity_id': entity,

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.binary_sensor import ( from homeassistant.components.binary_sensor import (
BinarySensorDevice, ENTITY_ID_FORMAT, PLATFORM_SCHEMA, BinarySensorDevice, ENTITY_ID_FORMAT, PLATFORM_SCHEMA,
SENSOR_CLASSES_SCHEMA) SENSOR_CLASSES_SCHEMA)
@ -82,7 +83,7 @@ class BinarySensorTemplate(BinarySensorDevice):
self.update() self.update()
@asyncio.coroutine @callback
def template_bsensor_state_listener(entity, old_state, new_state): def template_bsensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True)) hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -171,7 +171,7 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
if not _match_topic(topic, event.data[ATTR_TOPIC]): if not _match_topic(topic, event.data[ATTR_TOPIC]):
return return
hass.async_add_job(callback, event.data[ATTR_TOPIC], hass.async_run_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.sensor import ENTITY_ID_FORMAT, PLATFORM_SCHEMA from homeassistant.components.sensor import ENTITY_ID_FORMAT, PLATFORM_SCHEMA
from homeassistant.const import ( from homeassistant.const import (
ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE,
@ -79,7 +80,7 @@ class SensorTemplate(Entity):
self.update() self.update()
@asyncio.coroutine @callback
def template_sensor_state_listener(entity, old_state, new_state): def template_sensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True)) hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -9,6 +9,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.switch import ( from homeassistant.components.switch import (
ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA) ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA)
from homeassistant.const import ( from homeassistant.const import (
@ -88,7 +89,7 @@ class SwitchTemplate(SwitchDevice):
self.update() self.update()
@asyncio.coroutine @callback
def template_switch_state_listener(entity, old_state, new_state): def template_switch_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
hass.loop.create_task(self.async_update_ha_state(True)) hass.loop.create_task(self.async_update_ha_state(True))

View File

@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool:
return ENTITY_ID_PATTERN.match(entity_id) is not None return ENTITY_ID_PATTERN.match(entity_id) is not None
def callback(func: Callable[..., None]) -> Callable[..., None]:
"""Annotation to mark method as safe to call from within the event loop."""
# pylint: disable=protected-access
func._hass_callback = True
return func
def is_callback(func: Callable[..., Any]) -> bool:
"""Check if function is safe to be called in the event loop."""
return '_hass_callback' in func.__dict__
class CoreState(enum.Enum): class CoreState(enum.Enum):
"""Represent the current state of Home Assistant.""" """Represent the current state of Home Assistant."""
@ -224,11 +236,24 @@ class HomeAssistant(object):
target: target to call. target: target to call.
args: parameters for method to call. args: parameters for method to call.
""" """
if asyncio.iscoroutinefunction(target): if is_callback(target):
self.loop.call_soon(target, *args)
elif asyncio.iscoroutinefunction(target):
self.loop.create_task(target(*args)) self.loop.create_task(target(*args))
else: else:
self.add_job(target, *args) self.add_job(target, *args)
def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop.
target: target to call.
args: parameters for method to call.
"""
if is_callback(target):
target(*args)
else:
self.async_add_job(target, *args)
def _loop_empty(self): def _loop_empty(self):
"""Python 3.4.2 empty loop compatibility function.""" """Python 3.4.2 empty loop compatibility function."""
# pylint: disable=protected-access # pylint: disable=protected-access
@ -380,7 +405,6 @@ class EventBus(object):
self._loop.call_soon_threadsafe(self.async_fire, event_type, self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin) event_data, origin)
return
def async_fire(self, event_type: str, event_data=None, def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False): origin=EventOrigin.local, wait=False):
@ -408,6 +432,8 @@ class EventBus(object):
for func in listeners: for func in listeners:
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
self._loop.create_task(func(event)) self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else: else:
sync_jobs.append((job_priority, (func, event))) sync_jobs.append((job_priority, (func, event)))
@ -795,7 +821,7 @@ class Service(object):
"""Represents a callable service.""" """Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema', __slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction'] 'is_callback', 'is_coroutinefunction']
def __init__(self, func, description, fields, schema): def __init__(self, func, description, fields, schema):
"""Initialize a service.""" """Initialize a service."""
@ -803,7 +829,8 @@ class Service(object):
self.description = description or '' self.description = description or ''
self.fields = fields or {} self.fields = fields or {}
self.schema = schema self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func) self.is_callback = is_callback(func)
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self): def as_dict(self):
"""Return dictionary representation of this service.""" """Return dictionary representation of this service."""
@ -934,7 +961,7 @@ class ServiceRegistry(object):
self._loop self._loop
).result() ).result()
@asyncio.coroutine @callback
def async_call(self, domain, service, service_data=None, blocking=False): def async_call(self, domain, service, service_data=None, blocking=False):
""" """
Call a service. Call a service.
@ -966,7 +993,7 @@ class ServiceRegistry(object):
if blocking: if blocking:
fut = asyncio.Future(loop=self._loop) fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine @callback
def service_executed(event): def service_executed(event):
"""Callback method that is called when service is executed.""" """Callback method that is called when service is executed."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id: if event.data[ATTR_SERVICE_CALL_ID] == call_id:
@ -1007,7 +1034,8 @@ class ServiceRegistry(object):
data = {ATTR_SERVICE_CALL_ID: call_id} data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction: if (service_handler.is_coroutinefunction or
service_handler.is_callback):
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data) self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else: else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data) self._bus.fire(EVENT_SERVICE_EXECUTED, data)
@ -1023,17 +1051,19 @@ class ServiceRegistry(object):
service_call = ServiceCall(domain, service, service_data, call_id) service_call = ServiceCall(domain, service, service_data, call_id)
if not service_handler.iscoroutinefunction: if service_handler.is_callback:
service_handler.func(service_call)
fire_service_executed()
elif service_handler.is_coroutinefunction:
yield from service_handler.func(service_call)
fire_service_executed()
else:
def execute_service(): def execute_service():
"""Execute a service and fires a SERVICE_EXECUTED event.""" """Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call) service_handler.func(service_call)
fire_service_executed() fire_service_executed()
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE) self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
return
yield from service_handler.func(service_call)
fire_service_executed()
def _generate_unique_id(self): def _generate_unique_id(self):
"""Generate a unique service call id.""" """Generate a unique service call id."""
@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL):
stop_event = asyncio.Event(loop=hass.loop) stop_event = asyncio.Event(loop=hass.loop)
# Setting the Event inside the loop by marking it as a coroutine # Setting the Event inside the loop by marking it as a coroutine
@asyncio.coroutine @callback
def stop_timer(event): def stop_timer(event):
"""Stop the timer.""" """Stop the timer."""
stop_event.set() stop_event.set()
@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass):
schedule() schedule()
@asyncio.coroutine @callback
def stop_monitor(event): def stop_monitor(event):
"""Stop the monitor.""" """Stop the monitor."""
handle.cancel() handle.cancel()

View File

@ -1,9 +1,8 @@
"""Helpers for listening to events.""" """Helpers for listening to events."""
import asyncio
import functools as ft import functools as ft
from datetime import timedelta from datetime import timedelta
from ..core import HomeAssistant from ..core import HomeAssistant, callback
from ..const import ( from ..const import (
ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from ..util import dt as dt_util from ..util import dt as dt_util
@ -57,8 +56,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
else: else:
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
@ft.wraps(action) @callback
@asyncio.coroutine
def state_change_listener(event): def state_change_listener(event):
"""The listener that listens for specific state changes.""" """The listener that listens for specific state changes."""
if entity_ids != MATCH_ALL and \ if entity_ids != MATCH_ALL and \
@ -76,7 +74,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
new_state = None new_state = None
if _matcher(old_state, from_state) and _matcher(new_state, to_state): if _matcher(old_state, from_state) and _matcher(new_state, to_state):
hass.async_add_job(action, event.data.get('entity_id'), hass.async_run_job(action, event.data.get('entity_id'),
event.data.get('old_state'), event.data.get('old_state'),
event.data.get('new_state')) event.data.get('new_state'))
@ -90,11 +88,10 @@ def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a spefic point in time.""" """Add a listener that fires once after a spefic point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time) utc_point_in_time = dt_util.as_utc(point_in_time)
@ft.wraps(action) @callback
@asyncio.coroutine
def utc_converter(utc_now): def utc_converter(utc_now):
"""Convert passed in UTC now to local now.""" """Convert passed in UTC now to local now."""
hass.async_add_job(action, dt_util.as_local(utc_now)) hass.async_run_job(action, dt_util.as_local(utc_now))
return async_track_point_in_utc_time(hass, utc_converter, return async_track_point_in_utc_time(hass, utc_converter,
utc_point_in_time) utc_point_in_time)
@ -108,8 +105,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
# Ensure point_in_time is UTC # Ensure point_in_time is UTC
point_in_time = dt_util.as_utc(point_in_time) point_in_time = dt_util.as_utc(point_in_time)
@ft.wraps(action) @callback
@asyncio.coroutine
def point_in_time_listener(event): def point_in_time_listener(event):
"""Listen for matching time_changed events.""" """Listen for matching time_changed events."""
now = event.data[ATTR_NOW] now = event.data[ATTR_NOW]
@ -125,7 +121,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
point_in_time_listener.run = True point_in_time_listener.run = True
async_unsub() async_unsub()
hass.async_add_job(action, now) hass.async_run_job(action, now)
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
point_in_time_listener) point_in_time_listener)
@ -151,14 +147,13 @@ def async_track_sunrise(hass, action, offset=None):
return next_time return next_time
@ft.wraps(action) @callback
@asyncio.coroutine
def sunrise_automation_listener(now): def sunrise_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
nonlocal remove nonlocal remove
remove = async_track_point_in_utc_time( remove = async_track_point_in_utc_time(
hass, sunrise_automation_listener, next_rise()) hass, sunrise_automation_listener, next_rise())
hass.async_add_job(action) hass.async_run_job(action)
remove = async_track_point_in_utc_time( remove = async_track_point_in_utc_time(
hass, sunrise_automation_listener, next_rise()) hass, sunrise_automation_listener, next_rise())
@ -187,14 +182,13 @@ def async_track_sunset(hass, action, offset=None):
return next_time return next_time
@ft.wraps(action) @callback
@asyncio.coroutine
def sunset_automation_listener(now): def sunset_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
nonlocal remove nonlocal remove
remove = async_track_point_in_utc_time( remove = async_track_point_in_utc_time(
hass, sunset_automation_listener, next_set()) hass, sunset_automation_listener, next_set())
hass.async_add_job(action) hass.async_run_job(action)
remove = async_track_point_in_utc_time( remove = async_track_point_in_utc_time(
hass, sunset_automation_listener, next_set()) hass, sunset_automation_listener, next_set())
@ -217,10 +211,10 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
# We do not have to wrap the function with time pattern matching logic # We do not have to wrap the function with time pattern matching logic
# if no pattern given # if no pattern given
if all(val is None for val in (year, month, day, hour, minute, second)): if all(val is None for val in (year, month, day, hour, minute, second)):
@ft.wraps(action) @callback
def time_change_listener(event): def time_change_listener(event):
"""Fire every time event that comes in.""" """Fire every time event that comes in."""
action(event.data[ATTR_NOW]) hass.async_run_job(action, event.data[ATTR_NOW])
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
@ -228,8 +222,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
year, month, day = pmp(year), pmp(month), pmp(day) year, month, day = pmp(year), pmp(month), pmp(day)
hour, minute, second = pmp(hour), pmp(minute), pmp(second) hour, minute, second = pmp(hour), pmp(minute), pmp(second)
@ft.wraps(action) @callback
@asyncio.coroutine
def pattern_time_change_listener(event): def pattern_time_change_listener(event):
"""Listen for matching time_changed events.""" """Listen for matching time_changed events."""
now = event.data[ATTR_NOW] now = event.data[ATTR_NOW]
@ -246,7 +239,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
mat(now.minute, minute) and \ mat(now.minute, minute) and \
mat(now.second, second): mat(now.second, second):
hass.async_add_job(action, now) hass.async_run_job(action, now)
return hass.bus.async_listen(EVENT_TIME_CHANGED, return hass.bus.async_listen(EVENT_TIME_CHANGED,
pattern_time_change_listener) pattern_time_change_listener)

View File

@ -111,7 +111,7 @@ def mock_service(hass, domain, service):
""" """
calls = [] calls = []
hass.services.register(domain, service, calls.append) hass.services.register(domain, service, lambda call: calls.append(call))
return calls return calls

View File

@ -139,7 +139,8 @@ class TestAPI(unittest.TestCase):
hass.states.set("test.test", "not_to_be_set") hass.states.set("test.test", "not_to_be_set")
events = [] events = []
hass.bus.listen(const.EVENT_STATE_CHANGED, events.append) hass.bus.listen(const.EVENT_STATE_CHANGED,
lambda ev: events.append(ev))
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")), requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set"}), data=json.dumps({"state": "not_to_be_set"}),

View File

@ -1,6 +1,7 @@
"""Test event helpers.""" """Test event helpers."""
# pylint: disable=protected-access,too-many-public-methods # pylint: disable=protected-access,too-many-public-methods
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
import asyncio
import unittest import unittest
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -113,17 +114,23 @@ class TestEventHelpers(unittest.TestCase):
wildcard_runs = [] wildcard_runs = []
wildercard_runs = [] wildercard_runs = []
track_state_change( def specific_run_callback(entity_id, old_state, new_state):
self.hass, 'light.Bowl', lambda a, b, c: specific_runs.append(1), specific_runs.append(1)
'on', 'off')
track_state_change( track_state_change(
self.hass, 'light.Bowl', self.hass, 'light.Bowl', specific_run_callback, 'on', 'off')
lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s)))
track_state_change( @ha.callback
self.hass, MATCH_ALL, def wildcard_run_callback(entity_id, old_state, new_state):
lambda _, old_s, new_s: wildercard_runs.append((old_s, new_s))) wildcard_runs.append((old_state, new_state))
track_state_change(self.hass, 'light.Bowl', wildcard_run_callback)
@asyncio.coroutine
def wildercard_run_callback(entity_id, old_state, new_state):
wildercard_runs.append((old_state, new_state))
track_state_change(self.hass, MATCH_ALL, wildercard_run_callback)
# Adding state to state machine # Adding state to state machine
self.hass.states.set("light.Bowl", "on") self.hass.states.set("light.Bowl", "on")

View File

@ -23,13 +23,70 @@ from tests.common import get_test_home_assistant
PST = pytz.timezone('America/Los_Angeles') PST = pytz.timezone('America/Los_Angeles')
class TestMethods(unittest.TestCase): def test_split_entity_id():
"""Test the Home Assistant helper methods.""" """Test split_entity_id."""
assert ha.split_entity_id('domain.object_id') == ['domain', 'object_id']
def test_split_entity_id(self):
"""Test split_entity_id.""" def test_async_add_job_schedule_callback():
self.assertEqual(['domain', 'object_id'], """Test that we schedule coroutines and add jobs to the job pool."""
ha.split_entity_id('domain.object_id')) hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, ha.callback(job))
assert len(hass.loop.call_soon.mock_calls) == 1
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
@patch('asyncio.iscoroutinefunction', return_value=True)
def test_async_add_job_schedule_coroutinefunction(mock_iscoro):
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, job)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
@patch('asyncio.iscoroutinefunction', return_value=False)
def test_async_add_job_add_threaded_job_to_pool(mock_iscoro):
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, job)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 1
def test_async_run_job_calls_callback():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, ha.callback(job))
assert len(calls) == 1
assert len(hass.async_add_job.mock_calls) == 0
def test_async_run_job_delegates_non_async():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, job)
assert len(calls) == 0
assert len(hass.async_add_job.mock_calls) == 1
class TestHomeAssistant(unittest.TestCase): class TestHomeAssistant(unittest.TestCase):
@ -173,6 +230,44 @@ class TestEventBus(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(runs)) self.assertEqual(1, len(runs))
def test_thread_event_listener(self):
"""Test a event listener listeners."""
thread_calls = []
def thread_listener(event):
thread_calls.append(event)
self.bus.listen('test_thread', thread_listener)
self.bus.fire('test_thread')
self.hass.block_till_done()
assert len(thread_calls) == 1
def test_callback_event_listener(self):
"""Test a event listener listeners."""
callback_calls = []
@ha.callback
def callback_listener(event):
callback_calls.append(event)
self.bus.listen('test_callback', callback_listener)
self.bus.fire('test_callback')
self.hass.block_till_done()
assert len(callback_calls) == 1
def test_coroutine_event_listener(self):
"""Test a event listener listeners."""
coroutine_calls = []
@asyncio.coroutine
def coroutine_listener(event):
coroutine_calls.append(event)
self.bus.listen('test_coroutine', coroutine_listener)
self.bus.fire('test_coroutine')
self.hass.block_till_done()
assert len(coroutine_calls) == 1
class TestState(unittest.TestCase): class TestState(unittest.TestCase):
"""Test State methods.""" """Test State methods."""
@ -330,7 +425,7 @@ class TestStateMachine(unittest.TestCase):
def test_force_update(self): def test_force_update(self):
"""Test force update option.""" """Test force update option."""
events = [] events = []
self.hass.bus.listen(EVENT_STATE_CHANGED, events.append) self.hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
self.states.set('light.bowl', 'on') self.states.set('light.bowl', 'on')
self.hass.block_till_done() self.hass.block_till_done()
@ -425,6 +520,22 @@ class TestServiceRegistry(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(calls)) self.assertEqual(1, len(calls))
def test_callback_service(self):
"""Test registering and calling an async service."""
calls = []
@ha.callback
def service_handler(call):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
self.assertEqual(1, len(calls))
class TestConfig(unittest.TestCase): class TestConfig(unittest.TestCase):
"""Test configuration methods.""" """Test configuration methods."""
@ -524,8 +635,7 @@ class TestWorkerPoolMonitor(object):
check_threshold() check_threshold()
assert mock_warning.called assert mock_warning.called
event_loop.run_until_complete( hass.bus.async_listen_once.mock_calls[0][1][1](None)
hass.bus.async_listen_once.mock_calls[0][1][1](None))
assert schedule_handle.cancel.called assert schedule_handle.cancel.called
@ -561,5 +671,5 @@ class TestAsyncCreateTimer(object):
assert {ha.ATTR_NOW: now} == event_data assert {ha.ATTR_NOW: now} == event_data
stop_timer = hass.bus.async_listen_once.mock_calls[0][1][1] stop_timer = hass.bus.async_listen_once.mock_calls[0][1][1]
event_loop.run_until_complete(stop_timer(None)) stop_timer(None)
assert event.set.called assert event.set.called

View File

@ -163,7 +163,7 @@ class TestRemoteMethods(unittest.TestCase):
def test_set_state_with_push(self): def test_set_state_with_push(self):
"""Test Python API set_state with push option.""" """Test Python API set_state with push option."""
events = [] events = []
hass.bus.listen(EVENT_STATE_CHANGED, events.append) hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
remote.set_state(master_api, 'test.test', 'set_test_2') remote.set_state(master_api, 'test.test', 'set_test_2')
remote.set_state(master_api, 'test.test', 'set_test_2') remote.set_state(master_api, 'test.test', 'set_test_2')