mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Add async_safe annotation (#3688)
* Add async_safe annotation * More async_run_job * coroutine -> async_save * Lint * Rename async_safe -> callback * Add tests to core for different job types * Add one more test with different type of callbacks * Fix typing signature for callback methods * Fix callback service executed method * Fix method signatures for callback
This commit is contained in:
parent
be7401f4a2
commit
5085cdb0f7
@ -4,11 +4,11 @@ Offer event listening automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#event-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
@ -29,12 +29,12 @@ def async_trigger(hass, config, action):
|
||||
event_type = config.get(CONF_EVENT_TYPE)
|
||||
event_data = config.get(CONF_EVENT_DATA)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def handle_event(event):
|
||||
"""Listen for events and calls the action when data matches."""
|
||||
if not event_data or all(val == event.data.get(key) for key, val
|
||||
in event_data.items()):
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'event',
|
||||
'event': event,
|
||||
|
@ -4,9 +4,9 @@ Offer MQTT listening automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#mqtt-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
import homeassistant.components.mqtt as mqtt
|
||||
from homeassistant.const import (CONF_PLATFORM, CONF_PAYLOAD)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
@ -27,11 +27,11 @@ def async_trigger(hass, config, action):
|
||||
topic = config.get(CONF_TOPIC)
|
||||
payload = config.get(CONF_PAYLOAD)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def mqtt_automation_listener(msg_topic, msg_payload, qos):
|
||||
"""Listen for MQTT messages."""
|
||||
if payload is None or payload == msg_payload:
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'mqtt',
|
||||
'topic': msg_topic,
|
||||
|
@ -4,11 +4,11 @@ Offer numeric state listening automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#numeric-state-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import (
|
||||
CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID,
|
||||
CONF_BELOW, CONF_ABOVE)
|
||||
@ -35,7 +35,7 @@ def async_trigger(hass, config, action):
|
||||
if value_template is not None:
|
||||
value_template.hass = hass
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
if to_s is None:
|
||||
@ -64,6 +64,6 @@ def async_trigger(hass, config, action):
|
||||
variables['trigger']['from_state'] = from_s
|
||||
variables['trigger']['to_state'] = to_s
|
||||
|
||||
hass.async_add_job(action, variables)
|
||||
hass.async_run_job(action, variables)
|
||||
|
||||
return async_track_state_change(hass, entity_id, state_automation_listener)
|
||||
|
@ -4,9 +4,9 @@ Offer state listening automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#state-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.const import MATCH_ALL, CONF_PLATFORM
|
||||
from homeassistant.helpers.event import (
|
||||
@ -43,14 +43,14 @@ def async_trigger(hass, config, action):
|
||||
async_remove_state_for_cancel = None
|
||||
async_remove_state_for_listener = None
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
|
||||
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'state',
|
||||
'entity_id': entity,
|
||||
@ -64,13 +64,13 @@ def async_trigger(hass, config, action):
|
||||
call_action()
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_for_listener(now):
|
||||
"""Fire on state changes after a delay and calls action."""
|
||||
async_remove_state_for_cancel()
|
||||
call_action()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
|
||||
"""Fire on changes and cancel for listener if changed."""
|
||||
if inner_to_s.state == to_s.state:
|
||||
|
@ -4,12 +4,12 @@ Offer sun based automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#sun-trigger
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import (
|
||||
CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE)
|
||||
from homeassistant.helpers.event import async_track_sunrise, async_track_sunset
|
||||
@ -31,10 +31,10 @@ def async_trigger(hass, config, action):
|
||||
event = config.get(CONF_EVENT)
|
||||
offset = config.get(CONF_OFFSET)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'sun',
|
||||
'event': event,
|
||||
|
@ -4,11 +4,11 @@ Offer template automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#template-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM
|
||||
from homeassistant.helpers import condition
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
@ -31,7 +31,7 @@ def async_trigger(hass, config, action):
|
||||
# Local variable to keep track of if the action has already been triggered
|
||||
already_triggered = False
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_changed_listener(entity_id, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
nonlocal already_triggered
|
||||
@ -40,7 +40,7 @@ def async_trigger(hass, config, action):
|
||||
# Check to see if template returns true
|
||||
if template_result and not already_triggered:
|
||||
already_triggered = True
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'template',
|
||||
'entity_id': entity_id,
|
||||
|
@ -4,11 +4,11 @@ Offer time listening automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#time-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_AFTER, CONF_PLATFORM
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.event import async_track_time_change
|
||||
@ -39,10 +39,10 @@ def async_trigger(hass, config, action):
|
||||
minutes = config.get(CONF_MINUTES)
|
||||
seconds = config.get(CONF_SECONDS)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def time_automation_listener(now):
|
||||
"""Listen for time changes and calls action."""
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'time',
|
||||
'now': now,
|
||||
|
@ -4,9 +4,9 @@ Offer zone automation rules.
|
||||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#zone-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import (
|
||||
CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM)
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
@ -32,7 +32,7 @@ def async_trigger(hass, config, action):
|
||||
zone_entity_id = config.get(CONF_ZONE)
|
||||
event = config.get(CONF_EVENT)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def zone_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
if from_s and not location.has_location(from_s) or \
|
||||
@ -49,7 +49,7 @@ def async_trigger(hass, config, action):
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if event == EVENT_ENTER and not from_match and to_match or \
|
||||
event == EVENT_LEAVE and from_match and not to_match:
|
||||
hass.async_add_job(action, {
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'zone',
|
||||
'entity_id': entity,
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.binary_sensor import (
|
||||
BinarySensorDevice, ENTITY_ID_FORMAT, PLATFORM_SCHEMA,
|
||||
SENSOR_CLASSES_SCHEMA)
|
||||
@ -82,7 +83,7 @@ class BinarySensorTemplate(BinarySensorDevice):
|
||||
|
||||
self.update()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def template_bsensor_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
@ -171,7 +171,7 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
||||
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||
return
|
||||
|
||||
hass.async_add_job(callback, event.data[ATTR_TOPIC],
|
||||
hass.async_run_job(callback, event.data[ATTR_TOPIC],
|
||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
|
||||
|
||||
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.sensor import ENTITY_ID_FORMAT, PLATFORM_SCHEMA
|
||||
from homeassistant.const import (
|
||||
ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE,
|
||||
@ -79,7 +80,7 @@ class SensorTemplate(Entity):
|
||||
|
||||
self.update()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def template_sensor_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.switch import (
|
||||
ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA)
|
||||
from homeassistant.const import (
|
||||
@ -88,7 +89,7 @@ class SwitchTemplate(SwitchDevice):
|
||||
|
||||
self.update()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def template_switch_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool:
|
||||
return ENTITY_ID_PATTERN.match(entity_id) is not None
|
||||
|
||||
|
||||
def callback(func: Callable[..., None]) -> Callable[..., None]:
|
||||
"""Annotation to mark method as safe to call from within the event loop."""
|
||||
# pylint: disable=protected-access
|
||||
func._hass_callback = True
|
||||
return func
|
||||
|
||||
|
||||
def is_callback(func: Callable[..., Any]) -> bool:
|
||||
"""Check if function is safe to be called in the event loop."""
|
||||
return '_hass_callback' in func.__dict__
|
||||
|
||||
|
||||
class CoreState(enum.Enum):
|
||||
"""Represent the current state of Home Assistant."""
|
||||
|
||||
@ -224,11 +236,24 @@ class HomeAssistant(object):
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(target):
|
||||
if is_callback(target):
|
||||
self.loop.call_soon(target, *args)
|
||||
elif asyncio.iscoroutinefunction(target):
|
||||
self.loop.create_task(target(*args))
|
||||
else:
|
||||
self.add_job(target, *args)
|
||||
|
||||
def async_run_job(self, target: Callable[..., None], *args: Any):
|
||||
"""Run a job from within the event loop.
|
||||
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if is_callback(target):
|
||||
target(*args)
|
||||
else:
|
||||
self.async_add_job(target, *args)
|
||||
|
||||
def _loop_empty(self):
|
||||
"""Python 3.4.2 empty loop compatibility function."""
|
||||
# pylint: disable=protected-access
|
||||
@ -380,7 +405,6 @@ class EventBus(object):
|
||||
|
||||
self._loop.call_soon_threadsafe(self.async_fire, event_type,
|
||||
event_data, origin)
|
||||
return
|
||||
|
||||
def async_fire(self, event_type: str, event_data=None,
|
||||
origin=EventOrigin.local, wait=False):
|
||||
@ -408,6 +432,8 @@ class EventBus(object):
|
||||
for func in listeners:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
self._loop.create_task(func(event))
|
||||
elif is_callback(func):
|
||||
self._loop.call_soon(func, event)
|
||||
else:
|
||||
sync_jobs.append((job_priority, (func, event)))
|
||||
|
||||
@ -795,7 +821,7 @@ class Service(object):
|
||||
"""Represents a callable service."""
|
||||
|
||||
__slots__ = ['func', 'description', 'fields', 'schema',
|
||||
'iscoroutinefunction']
|
||||
'is_callback', 'is_coroutinefunction']
|
||||
|
||||
def __init__(self, func, description, fields, schema):
|
||||
"""Initialize a service."""
|
||||
@ -803,7 +829,8 @@ class Service(object):
|
||||
self.description = description or ''
|
||||
self.fields = fields or {}
|
||||
self.schema = schema
|
||||
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
self.is_callback = is_callback(func)
|
||||
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
|
||||
def as_dict(self):
|
||||
"""Return dictionary representation of this service."""
|
||||
@ -934,7 +961,7 @@ class ServiceRegistry(object):
|
||||
self._loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def async_call(self, domain, service, service_data=None, blocking=False):
|
||||
"""
|
||||
Call a service.
|
||||
@ -966,7 +993,7 @@ class ServiceRegistry(object):
|
||||
if blocking:
|
||||
fut = asyncio.Future(loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def service_executed(event):
|
||||
"""Callback method that is called when service is executed."""
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
@ -1007,7 +1034,8 @@ class ServiceRegistry(object):
|
||||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if service_handler.iscoroutinefunction:
|
||||
if (service_handler.is_coroutinefunction or
|
||||
service_handler.is_callback):
|
||||
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
||||
else:
|
||||
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
|
||||
@ -1023,17 +1051,19 @@ class ServiceRegistry(object):
|
||||
|
||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||
|
||||
if not service_handler.iscoroutinefunction:
|
||||
if service_handler.is_callback:
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
elif service_handler.is_coroutinefunction:
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
else:
|
||||
def execute_service():
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
|
||||
return
|
||||
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _generate_unique_id(self):
|
||||
"""Generate a unique service call id."""
|
||||
@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL):
|
||||
stop_event = asyncio.Event(loop=hass.loop)
|
||||
|
||||
# Setting the Event inside the loop by marking it as a coroutine
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def stop_timer(event):
|
||||
"""Stop the timer."""
|
||||
stop_event.set()
|
||||
@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass):
|
||||
|
||||
schedule()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def stop_monitor(event):
|
||||
"""Stop the monitor."""
|
||||
handle.cancel()
|
||||
|
@ -1,9 +1,8 @@
|
||||
"""Helpers for listening to events."""
|
||||
import asyncio
|
||||
import functools as ft
|
||||
from datetime import timedelta
|
||||
|
||||
from ..core import HomeAssistant
|
||||
from ..core import HomeAssistant, callback
|
||||
from ..const import (
|
||||
ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
|
||||
from ..util import dt as dt_util
|
||||
@ -57,8 +56,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
|
||||
else:
|
||||
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def state_change_listener(event):
|
||||
"""The listener that listens for specific state changes."""
|
||||
if entity_ids != MATCH_ALL and \
|
||||
@ -76,7 +74,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
|
||||
new_state = None
|
||||
|
||||
if _matcher(old_state, from_state) and _matcher(new_state, to_state):
|
||||
hass.async_add_job(action, event.data.get('entity_id'),
|
||||
hass.async_run_job(action, event.data.get('entity_id'),
|
||||
event.data.get('old_state'),
|
||||
event.data.get('new_state'))
|
||||
|
||||
@ -90,11 +88,10 @@ def async_track_point_in_time(hass, action, point_in_time):
|
||||
"""Add a listener that fires once after a spefic point in time."""
|
||||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def utc_converter(utc_now):
|
||||
"""Convert passed in UTC now to local now."""
|
||||
hass.async_add_job(action, dt_util.as_local(utc_now))
|
||||
hass.async_run_job(action, dt_util.as_local(utc_now))
|
||||
|
||||
return async_track_point_in_utc_time(hass, utc_converter,
|
||||
utc_point_in_time)
|
||||
@ -108,8 +105,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
|
||||
# Ensure point_in_time is UTC
|
||||
point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def point_in_time_listener(event):
|
||||
"""Listen for matching time_changed events."""
|
||||
now = event.data[ATTR_NOW]
|
||||
@ -125,7 +121,7 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
|
||||
point_in_time_listener.run = True
|
||||
async_unsub()
|
||||
|
||||
hass.async_add_job(action, now)
|
||||
hass.async_run_job(action, now)
|
||||
|
||||
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
|
||||
point_in_time_listener)
|
||||
@ -151,14 +147,13 @@ def async_track_sunrise(hass, action, offset=None):
|
||||
|
||||
return next_time
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def sunrise_automation_listener(now):
|
||||
"""Called when it's time for action."""
|
||||
nonlocal remove
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunrise_automation_listener, next_rise())
|
||||
hass.async_add_job(action)
|
||||
hass.async_run_job(action)
|
||||
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunrise_automation_listener, next_rise())
|
||||
@ -187,14 +182,13 @@ def async_track_sunset(hass, action, offset=None):
|
||||
|
||||
return next_time
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def sunset_automation_listener(now):
|
||||
"""Called when it's time for action."""
|
||||
nonlocal remove
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunset_automation_listener, next_set())
|
||||
hass.async_add_job(action)
|
||||
hass.async_run_job(action)
|
||||
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunset_automation_listener, next_set())
|
||||
@ -217,10 +211,10 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
||||
# We do not have to wrap the function with time pattern matching logic
|
||||
# if no pattern given
|
||||
if all(val is None for val in (year, month, day, hour, minute, second)):
|
||||
@ft.wraps(action)
|
||||
@callback
|
||||
def time_change_listener(event):
|
||||
"""Fire every time event that comes in."""
|
||||
action(event.data[ATTR_NOW])
|
||||
hass.async_run_job(action, event.data[ATTR_NOW])
|
||||
|
||||
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
|
||||
|
||||
@ -228,8 +222,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
||||
year, month, day = pmp(year), pmp(month), pmp(day)
|
||||
hour, minute, second = pmp(hour), pmp(minute), pmp(second)
|
||||
|
||||
@ft.wraps(action)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def pattern_time_change_listener(event):
|
||||
"""Listen for matching time_changed events."""
|
||||
now = event.data[ATTR_NOW]
|
||||
@ -246,7 +239,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
||||
mat(now.minute, minute) and \
|
||||
mat(now.second, second):
|
||||
|
||||
hass.async_add_job(action, now)
|
||||
hass.async_run_job(action, now)
|
||||
|
||||
return hass.bus.async_listen(EVENT_TIME_CHANGED,
|
||||
pattern_time_change_listener)
|
||||
|
@ -111,7 +111,7 @@ def mock_service(hass, domain, service):
|
||||
"""
|
||||
calls = []
|
||||
|
||||
hass.services.register(domain, service, calls.append)
|
||||
hass.services.register(domain, service, lambda call: calls.append(call))
|
||||
|
||||
return calls
|
||||
|
||||
|
@ -139,7 +139,8 @@ class TestAPI(unittest.TestCase):
|
||||
hass.states.set("test.test", "not_to_be_set")
|
||||
|
||||
events = []
|
||||
hass.bus.listen(const.EVENT_STATE_CHANGED, events.append)
|
||||
hass.bus.listen(const.EVENT_STATE_CHANGED,
|
||||
lambda ev: events.append(ev))
|
||||
|
||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
||||
data=json.dumps({"state": "not_to_be_set"}),
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test event helpers."""
|
||||
# pylint: disable=protected-access,too-many-public-methods
|
||||
# pylint: disable=too-few-public-methods
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -113,17 +114,23 @@ class TestEventHelpers(unittest.TestCase):
|
||||
wildcard_runs = []
|
||||
wildercard_runs = []
|
||||
|
||||
track_state_change(
|
||||
self.hass, 'light.Bowl', lambda a, b, c: specific_runs.append(1),
|
||||
'on', 'off')
|
||||
def specific_run_callback(entity_id, old_state, new_state):
|
||||
specific_runs.append(1)
|
||||
|
||||
track_state_change(
|
||||
self.hass, 'light.Bowl',
|
||||
lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s)))
|
||||
self.hass, 'light.Bowl', specific_run_callback, 'on', 'off')
|
||||
|
||||
track_state_change(
|
||||
self.hass, MATCH_ALL,
|
||||
lambda _, old_s, new_s: wildercard_runs.append((old_s, new_s)))
|
||||
@ha.callback
|
||||
def wildcard_run_callback(entity_id, old_state, new_state):
|
||||
wildcard_runs.append((old_state, new_state))
|
||||
|
||||
track_state_change(self.hass, 'light.Bowl', wildcard_run_callback)
|
||||
|
||||
@asyncio.coroutine
|
||||
def wildercard_run_callback(entity_id, old_state, new_state):
|
||||
wildercard_runs.append((old_state, new_state))
|
||||
|
||||
track_state_change(self.hass, MATCH_ALL, wildercard_run_callback)
|
||||
|
||||
# Adding state to state machine
|
||||
self.hass.states.set("light.Bowl", "on")
|
||||
|
@ -23,13 +23,70 @@ from tests.common import get_test_home_assistant
|
||||
PST = pytz.timezone('America/Los_Angeles')
|
||||
|
||||
|
||||
class TestMethods(unittest.TestCase):
|
||||
"""Test the Home Assistant helper methods."""
|
||||
def test_split_entity_id():
|
||||
"""Test split_entity_id."""
|
||||
assert ha.split_entity_id('domain.object_id') == ['domain', 'object_id']
|
||||
|
||||
def test_split_entity_id(self):
|
||||
"""Test split_entity_id."""
|
||||
self.assertEqual(['domain', 'object_id'],
|
||||
ha.split_entity_id('domain.object_id'))
|
||||
|
||||
def test_async_add_job_schedule_callback():
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
job = MagicMock()
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, ha.callback(job))
|
||||
assert len(hass.loop.call_soon.mock_calls) == 1
|
||||
assert len(hass.loop.create_task.mock_calls) == 0
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
@patch('asyncio.iscoroutinefunction', return_value=True)
|
||||
def test_async_add_job_schedule_coroutinefunction(mock_iscoro):
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
job = MagicMock()
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, job)
|
||||
assert len(hass.loop.call_soon.mock_calls) == 0
|
||||
assert len(hass.loop.create_task.mock_calls) == 1
|
||||
assert len(hass.add_job.mock_calls) == 0
|
||||
|
||||
|
||||
@patch('asyncio.iscoroutinefunction', return_value=False)
|
||||
def test_async_add_job_add_threaded_job_to_pool(mock_iscoro):
|
||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||
hass = MagicMock()
|
||||
job = MagicMock()
|
||||
|
||||
ha.HomeAssistant.async_add_job(hass, job)
|
||||
assert len(hass.loop.call_soon.mock_calls) == 0
|
||||
assert len(hass.loop.create_task.mock_calls) == 0
|
||||
assert len(hass.add_job.mock_calls) == 1
|
||||
|
||||
|
||||
def test_async_run_job_calls_callback():
|
||||
"""Test that the callback annotation is respected."""
|
||||
hass = MagicMock()
|
||||
calls = []
|
||||
|
||||
def job():
|
||||
calls.append(1)
|
||||
|
||||
ha.HomeAssistant.async_run_job(hass, ha.callback(job))
|
||||
assert len(calls) == 1
|
||||
assert len(hass.async_add_job.mock_calls) == 0
|
||||
|
||||
|
||||
def test_async_run_job_delegates_non_async():
|
||||
"""Test that the callback annotation is respected."""
|
||||
hass = MagicMock()
|
||||
calls = []
|
||||
|
||||
def job():
|
||||
calls.append(1)
|
||||
|
||||
ha.HomeAssistant.async_run_job(hass, job)
|
||||
assert len(calls) == 0
|
||||
assert len(hass.async_add_job.mock_calls) == 1
|
||||
|
||||
|
||||
class TestHomeAssistant(unittest.TestCase):
|
||||
@ -173,6 +230,44 @@ class TestEventBus(unittest.TestCase):
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(runs))
|
||||
|
||||
def test_thread_event_listener(self):
|
||||
"""Test a event listener listeners."""
|
||||
thread_calls = []
|
||||
|
||||
def thread_listener(event):
|
||||
thread_calls.append(event)
|
||||
|
||||
self.bus.listen('test_thread', thread_listener)
|
||||
self.bus.fire('test_thread')
|
||||
self.hass.block_till_done()
|
||||
assert len(thread_calls) == 1
|
||||
|
||||
def test_callback_event_listener(self):
|
||||
"""Test a event listener listeners."""
|
||||
callback_calls = []
|
||||
|
||||
@ha.callback
|
||||
def callback_listener(event):
|
||||
callback_calls.append(event)
|
||||
|
||||
self.bus.listen('test_callback', callback_listener)
|
||||
self.bus.fire('test_callback')
|
||||
self.hass.block_till_done()
|
||||
assert len(callback_calls) == 1
|
||||
|
||||
def test_coroutine_event_listener(self):
|
||||
"""Test a event listener listeners."""
|
||||
coroutine_calls = []
|
||||
|
||||
@asyncio.coroutine
|
||||
def coroutine_listener(event):
|
||||
coroutine_calls.append(event)
|
||||
|
||||
self.bus.listen('test_coroutine', coroutine_listener)
|
||||
self.bus.fire('test_coroutine')
|
||||
self.hass.block_till_done()
|
||||
assert len(coroutine_calls) == 1
|
||||
|
||||
|
||||
class TestState(unittest.TestCase):
|
||||
"""Test State methods."""
|
||||
@ -330,7 +425,7 @@ class TestStateMachine(unittest.TestCase):
|
||||
def test_force_update(self):
|
||||
"""Test force update option."""
|
||||
events = []
|
||||
self.hass.bus.listen(EVENT_STATE_CHANGED, events.append)
|
||||
self.hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
|
||||
|
||||
self.states.set('light.bowl', 'on')
|
||||
self.hass.block_till_done()
|
||||
@ -425,6 +520,22 @@ class TestServiceRegistry(unittest.TestCase):
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(calls))
|
||||
|
||||
def test_callback_service(self):
|
||||
"""Test registering and calling an async service."""
|
||||
calls = []
|
||||
|
||||
@ha.callback
|
||||
def service_handler(call):
|
||||
"""Service handler coroutine."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register('test_domain', 'register_calls',
|
||||
service_handler)
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(calls))
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Test configuration methods."""
|
||||
@ -524,8 +635,7 @@ class TestWorkerPoolMonitor(object):
|
||||
check_threshold()
|
||||
assert mock_warning.called
|
||||
|
||||
event_loop.run_until_complete(
|
||||
hass.bus.async_listen_once.mock_calls[0][1][1](None))
|
||||
hass.bus.async_listen_once.mock_calls[0][1][1](None)
|
||||
assert schedule_handle.cancel.called
|
||||
|
||||
|
||||
@ -561,5 +671,5 @@ class TestAsyncCreateTimer(object):
|
||||
assert {ha.ATTR_NOW: now} == event_data
|
||||
|
||||
stop_timer = hass.bus.async_listen_once.mock_calls[0][1][1]
|
||||
event_loop.run_until_complete(stop_timer(None))
|
||||
stop_timer(None)
|
||||
assert event.set.called
|
||||
|
@ -163,7 +163,7 @@ class TestRemoteMethods(unittest.TestCase):
|
||||
def test_set_state_with_push(self):
|
||||
"""Test Python API set_state with push option."""
|
||||
events = []
|
||||
hass.bus.listen(EVENT_STATE_CHANGED, events.append)
|
||||
hass.bus.listen(EVENT_STATE_CHANGED, lambda ev: events.append(ev))
|
||||
|
||||
remote.set_state(master_api, 'test.test', 'set_test_2')
|
||||
remote.set_state(master_api, 'test.test', 'set_test_2')
|
||||
|
Loading…
x
Reference in New Issue
Block a user