diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index 6824c32bf07..f774991c547 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -4,6 +4,7 @@ Offer MQTT listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#mqtt-trigger """ +import asyncio import voluptuous as vol import homeassistant.components.mqtt as mqtt @@ -26,10 +27,11 @@ def trigger(hass, config, action): topic = config.get(CONF_TOPIC) payload = config.get(CONF_PAYLOAD) + @asyncio.coroutine def mqtt_automation_listener(msg_topic, msg_payload, qos): """Listen for MQTT messages.""" if payload is None or payload == msg_payload: - action({ + hass.async_add_job(action, { 'trigger': { 'platform': 'mqtt', 'topic': msg_topic, diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 168ca05b62b..780ab4400b0 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -4,6 +4,7 @@ Offer numeric state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#numeric-state-trigger """ +import asyncio import logging import voluptuous as vol @@ -34,7 +35,7 @@ def trigger(hass, config, action): if value_template is not None: value_template.hass = hass - # pylint: disable=unused-argument + @asyncio.coroutine def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" if to_s is None: @@ -50,19 +51,19 @@ def trigger(hass, config, action): } # If new one doesn't match, nothing to do - if not condition.numeric_state( + if not condition.async_numeric_state( hass, to_s, below, above, value_template, variables): return # Only match if old didn't exist or existed but didn't match # Written as: skip if old one did exist and matched - if from_s is not None and condition.numeric_state( + if from_s is not None and condition.async_numeric_state( hass, from_s, below, above, value_template, variables): return variables['trigger']['from_state'] = from_s variables['trigger']['to_state'] = to_s - action(variables) + hass.async_add_job(action, variables) return track_state_change(hass, entity_id, state_automation_listener) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 8e0eb5231a5..dbe74479070 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -4,12 +4,15 @@ Offer state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#state-trigger """ +import asyncio import voluptuous as vol import homeassistant.util.dt as dt_util from homeassistant.const import MATCH_ALL, CONF_PLATFORM -from homeassistant.helpers.event import track_state_change, track_point_in_time +from homeassistant.helpers.event import ( + async_track_state_change, async_track_point_in_utc_time) import homeassistant.helpers.config_validation as cv +from homeassistant.util.async import run_callback_threadsafe CONF_ENTITY_ID = "entity_id" CONF_FROM = "from" @@ -38,16 +41,17 @@ def trigger(hass, config, action): from_state = config.get(CONF_FROM, MATCH_ALL) to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL time_delta = config.get(CONF_FOR) - remove_state_for_cancel = None - remove_state_for_listener = None + async_remove_state_for_cancel = None + async_remove_state_for_listener = None + @asyncio.coroutine def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" - nonlocal remove_state_for_cancel, remove_state_for_listener + nonlocal async_remove_state_for_cancel, async_remove_state_for_listener def call_action(): """Call action with right context.""" - action({ + hass.async_add_job(action, { 'trigger': { 'platform': 'state', 'entity_id': entity, @@ -61,35 +65,41 @@ def trigger(hass, config, action): call_action() return + @asyncio.coroutine def state_for_listener(now): """Fire on state changes after a delay and calls action.""" - remove_state_for_cancel() + async_remove_state_for_cancel() call_action() + @asyncio.coroutine def state_for_cancel_listener(entity, inner_from_s, inner_to_s): """Fire on changes and cancel for listener if changed.""" if inner_to_s.state == to_s.state: return - remove_state_for_listener() - remove_state_for_cancel() + async_remove_state_for_listener() + async_remove_state_for_cancel() - remove_state_for_listener = track_point_in_time( + async_remove_state_for_listener = async_track_point_in_utc_time( hass, state_for_listener, dt_util.utcnow() + time_delta) - remove_state_for_cancel = track_state_change( + async_remove_state_for_cancel = async_track_state_change( hass, entity, state_for_cancel_listener) - unsub = track_state_change(hass, entity_id, state_automation_listener, - from_state, to_state) + unsub = async_track_state_change( + hass, entity_id, state_automation_listener, from_state, to_state) + + def async_remove(): + """Remove state listeners async.""" + unsub() + # pylint: disable=not-callable + if async_remove_state_for_cancel is not None: + async_remove_state_for_cancel() + + if async_remove_state_for_listener is not None: + async_remove_state_for_listener() def remove(): """Remove state listeners.""" - unsub() - # pylint: disable=not-callable - if remove_state_for_cancel is not None: - remove_state_for_cancel() - - if remove_state_for_listener is not None: - remove_state_for_listener() + run_callback_threadsafe(hass.loop, async_remove).result() return remove diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index 991f9b3b385..faa628f572a 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -4,6 +4,7 @@ Offer sun based automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#sun-trigger """ +import asyncio from datetime import timedelta import logging @@ -30,9 +31,10 @@ def trigger(hass, config, action): event = config.get(CONF_EVENT) offset = config.get(CONF_OFFSET) + @asyncio.coroutine def call_action(): """Call action with right context.""" - action({ + hass.async_add_job(action, { 'trigger': { 'platform': 'sun', 'event': event, diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index 0732e2b212c..91f196eaf3f 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -4,6 +4,7 @@ Offer time listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#time-trigger """ +import asyncio import logging import voluptuous as vol @@ -38,9 +39,10 @@ def trigger(hass, config, action): minutes = config.get(CONF_MINUTES) seconds = config.get(CONF_SECONDS) + @asyncio.coroutine def time_automation_listener(now): """Listen for time changes and calls action.""" - action({ + hass.async_add_job(action, { 'trigger': { 'platform': 'time', 'now': now, diff --git a/homeassistant/components/automation/zone.py b/homeassistant/components/automation/zone.py index ec948684805..971257350e3 100644 --- a/homeassistant/components/automation/zone.py +++ b/homeassistant/components/automation/zone.py @@ -4,6 +4,7 @@ Offer zone automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#zone-trigger """ +import asyncio import voluptuous as vol from homeassistant.const import ( @@ -31,6 +32,7 @@ def trigger(hass, config, action): zone_entity_id = config.get(CONF_ZONE) event = config.get(CONF_EVENT) + @asyncio.coroutine def zone_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" if from_s and not location.has_location(from_s) or \ @@ -47,7 +49,7 @@ def trigger(hass, config, action): # pylint: disable=too-many-boolean-expressions if event == EVENT_ENTER and not from_match and to_match or \ event == EVENT_LEAVE and from_match and not to_match: - action({ + hass.async_add_job(action, { 'trigger': { 'platform': 'zone', 'entity_id': entity, diff --git a/homeassistant/components/logbook.py b/homeassistant/components/logbook.py index e65232b28a9..9100c098413 100644 --- a/homeassistant/components/logbook.py +++ b/homeassistant/components/logbook.py @@ -4,6 +4,7 @@ Event parser and human readable log generator. For more details about this component, please refer to the documentation at https://home-assistant.io/components/logbook/ """ +import asyncio import logging from datetime import timedelta from itertools import groupby @@ -20,6 +21,7 @@ from homeassistant.const import (EVENT_HOMEASSISTANT_START, STATE_NOT_HOME, STATE_OFF, STATE_ON, ATTR_HIDDEN) from homeassistant.core import State, split_entity_id, DOMAIN as HA_DOMAIN +from homeassistant.util.async import run_callback_threadsafe DOMAIN = "logbook" DEPENDENCIES = ['recorder', 'frontend'] @@ -57,6 +59,13 @@ LOG_MESSAGE_SCHEMA = vol.Schema({ def log_entry(hass, name, message, domain=None, entity_id=None): + """Add an entry to the logbook.""" + run_callback_threadsafe( + hass.loop, async_log_entry, hass, name, message, domain, entity_id + ).result() + + +def async_log_entry(hass, name, message, domain=None, entity_id=None): """Add an entry to the logbook.""" data = { ATTR_NAME: name, @@ -67,11 +76,12 @@ def log_entry(hass, name, message, domain=None, entity_id=None): data[ATTR_DOMAIN] = domain if entity_id is not None: data[ATTR_ENTITY_ID] = entity_id - hass.bus.fire(EVENT_LOGBOOK_ENTRY, data) + hass.bus.async_fire(EVENT_LOGBOOK_ENTRY, data) def setup(hass, config): """Listen for download events to download files.""" + @asyncio.coroutine def log_message(service): """Handle sending notification message service calls.""" message = service.data[ATTR_MESSAGE] @@ -80,8 +90,8 @@ def setup(hass, config): entity_id = service.data.get(ATTR_ENTITY_ID) message.hass = hass - message = message.render() - log_entry(hass, name, message, domain, entity_id) + message = message.async_render() + async_log_entry(hass, name, message, domain, entity_id) hass.wsgi.register_view(LogbookView(hass, config)) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index abf52da4359..01956a85c36 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -4,6 +4,7 @@ Support for MQTT message handling. For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ +import asyncio import logging import os import socket @@ -11,6 +12,7 @@ import time import voluptuous as vol +from homeassistant.core import JobPriority from homeassistant.bootstrap import prepare_setup_platform from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError @@ -164,11 +166,20 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): def subscribe(hass, topic, callback, qos=DEFAULT_QOS): """Subscribe to an MQTT topic.""" + @asyncio.coroutine def mqtt_topic_subscriber(event): """Match subscribed MQTT topic.""" - if _match_topic(topic, event.data[ATTR_TOPIC]): - callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], - event.data[ATTR_QOS]) + if not _match_topic(topic, event.data[ATTR_TOPIC]): + return + + if asyncio.iscoroutinefunction(callback): + yield from callback( + event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], + event.data[ATTR_QOS]) + else: + hass.add_job(callback, event.data[ATTR_TOPIC], + event.data[ATTR_PAYLOAD], event.data[ATTR_QOS], + priority=JobPriority.EVENT_CALLBACK) remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber) diff --git a/homeassistant/core.py b/homeassistant/core.py index bcea24246ca..2a6372dbf6f 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -248,12 +248,16 @@ class HomeAssistant(object): def notify_when_done(): """Notify event loop when pool done.""" + count = 0 while True: # Wait for the work queue to empty self.pool.block_till_done() # Verify the loop is empty if self._loop_empty(): + count += 1 + + if count == 2: break # sleep in the loop executor, this forces execution back into @@ -675,40 +679,29 @@ class StateMachine(object): return list(self._states.values()) def get(self, entity_id): - """Retrieve state of entity_id or None if not found.""" + """Retrieve state of entity_id or None if not found. + + Async friendly. + """ return self._states.get(entity_id.lower()) def is_state(self, entity_id, state): - """Test if entity exists and is specified state.""" - return run_callback_threadsafe( - self._loop, self.async_is_state, entity_id, state - ).result() - - def async_is_state(self, entity_id, state): """Test if entity exists and is specified state. - This method must be run in the event loop. + Async friendly. """ - entity_id = entity_id.lower() + state_obj = self.get(entity_id) - return (entity_id in self._states and - self._states[entity_id].state == state) + return state_obj and state_obj.state == state def is_state_attr(self, entity_id, name, value): - """Test if entity exists and has a state attribute set to value.""" - return run_callback_threadsafe( - self._loop, self.async_is_state_attr, entity_id, name, value - ).result() - - def async_is_state_attr(self, entity_id, name, value): """Test if entity exists and has a state attribute set to value. - This method must be run in the event loop. + Async friendly. """ - entity_id = entity_id.lower() + state_obj = self.get(entity_id) - return (entity_id in self._states and - self._states[entity_id].attributes.get(name, None) == value) + return state_obj and state_obj.attributes.get(name, None) == value def remove(self, entity_id): """Remove the state of an entity. @@ -799,7 +792,8 @@ class StateMachine(object): class Service(object): """Represents a callable service.""" - __slots__ = ['func', 'description', 'fields', 'schema'] + __slots__ = ['func', 'description', 'fields', 'schema', + 'iscoroutinefunction'] def __init__(self, func, description, fields, schema): """Initialize a service.""" @@ -807,6 +801,7 @@ class Service(object): self.description = description or '' self.fields = fields or {} self.schema = schema + self.iscoroutinefunction = asyncio.iscoroutinefunction(func) def as_dict(self): """Return dictionary representation of this service.""" @@ -815,19 +810,6 @@ class Service(object): 'fields': self.fields, } - def __call__(self, call): - """Execute the service.""" - try: - if self.schema: - call.data = self.schema(call.data) - call.data = MappingProxyType(call.data) - - self.func(call) - except vol.MultipleInvalid as ex: - _LOGGER.error('Invalid service data for %s.%s: %s', - call.domain, call.service, - humanize_error(call.data, ex)) - # pylint: disable=too-few-public-methods class ServiceCall(object): @@ -839,7 +821,7 @@ class ServiceCall(object): """Initialize a service call.""" self.domain = domain.lower() self.service = service.lower() - self.data = data or {} + self.data = MappingProxyType(data or {}) self.call_id = call_id def __repr__(self): @@ -983,9 +965,9 @@ class ServiceRegistry(object): fut = asyncio.Future(loop=self._loop) @asyncio.coroutine - def service_executed(call): + def service_executed(event): """Callback method that is called when service is executed.""" - if call.data[ATTR_SERVICE_CALL_ID] == call_id: + if event.data[ATTR_SERVICE_CALL_ID] == call_id: fut.set_result(True) unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED, @@ -1000,9 +982,10 @@ class ServiceRegistry(object): unsub() return success + @asyncio.coroutine def _event_to_service_call(self, event): """Callback for SERVICE_CALLED events from the event bus.""" - service_data = event.data.get(ATTR_SERVICE_DATA) + service_data = event.data.get(ATTR_SERVICE_DATA) or {} domain = event.data.get(ATTR_DOMAIN).lower() service = event.data.get(ATTR_SERVICE).lower() call_id = event.data.get(ATTR_SERVICE_CALL_ID) @@ -1014,19 +997,41 @@ class ServiceRegistry(object): return service_handler = self._services[domain][service] + + def fire_service_executed(): + """Fire service executed event.""" + if not call_id: + return + + data = {ATTR_SERVICE_CALL_ID: call_id} + + if service_handler.iscoroutinefunction: + self._bus.async_fire(EVENT_SERVICE_EXECUTED, data) + else: + self._bus.fire(EVENT_SERVICE_EXECUTED, data) + + try: + if service_handler.schema: + service_data = service_handler.schema(service_data) + except vol.Invalid as ex: + _LOGGER.error('Invalid service data for %s.%s: %s', + domain, service, humanize_error(service_data, ex)) + fire_service_executed() + return + service_call = ServiceCall(domain, service, service_data, call_id) - # Add a job to the pool that calls _execute_service - self._add_job(self._execute_service, service_handler, service_call, - priority=JobPriority.EVENT_SERVICE) + if not service_handler.iscoroutinefunction: + def execute_service(): + """Execute a service and fires a SERVICE_EXECUTED event.""" + service_handler.func(service_call) + fire_service_executed() - def _execute_service(self, service, call): - """Execute a service and fires a SERVICE_EXECUTED event.""" - service(call) + self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE) + return - if call.call_id is not None: - self._bus.fire( - EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id}) + yield from service_handler.func(service_call) + fire_service_executed() def _generate_unique_id(self): """Generate a unique service call id.""" diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index f4ce02c0846..041f514aeda 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -84,6 +84,15 @@ def or_from_config(config: ConfigType, config_validation: bool=True): def numeric_state(hass: HomeAssistant, entity, below=None, above=None, value_template=None, variables=None): """Test a numeric state condition.""" + return run_callback_threadsafe( + hass.loop, async_numeric_state, hass, entity, below, above, + value_template, variables, + ).result() + + +def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None, + value_template=None, variables=None): + """Test a numeric state condition.""" if isinstance(entity, str): entity = hass.states.get(entity) @@ -96,7 +105,7 @@ def numeric_state(hass: HomeAssistant, entity, below=None, above=None, variables = dict(variables or {}) variables['state'] = entity try: - value = value_template.render(variables) + value = value_template.async_render(variables) except TemplateError as ex: _LOGGER.error("Template error: %s", ex) return False @@ -290,7 +299,10 @@ def time_from_config(config, config_validation=True): def zone(hass, zone_ent, entity): - """Test if zone-condition matches.""" + """Test if zone-condition matches. + + Can be run async. + """ if isinstance(zone_ent, str): zone_ent = hass.states.get(zone_ent) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 0b4768b809d..7529d6288ab 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -1,4 +1,5 @@ """An abstract class for entities.""" +import asyncio import logging from typing import Any, Optional, List, Dict @@ -11,6 +12,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.util import ensure_unique_string, slugify +from homeassistant.util.async import run_coroutine_threadsafe # Entity attributes that we will overwrite _OVERWRITE = {} # type: Dict[str, Any] @@ -143,6 +145,23 @@ class Entity(object): If force_refresh == True will update entity before setting state. """ + # We're already in a thread, do the force refresh here. + if force_refresh and not hasattr(self, 'async_update'): + self.update() + force_refresh = False + + run_coroutine_threadsafe( + self.async_update_ha_state(force_refresh), self.hass.loop + ).result() + + @asyncio.coroutine + def async_update_ha_state(self, force_refresh=False): + """Update Home Assistant with current state of entity. + + If force_refresh == True will update entity before setting state. + + This method must be run in the event loop. + """ if self.hass is None: raise RuntimeError("Attribute hass is None for {}".format(self)) @@ -151,7 +170,13 @@ class Entity(object): "No entity id specified for entity {}".format(self.name)) if force_refresh: - self.update() + if hasattr(self, 'async_update'): + # pylint: disable=no-member + self.async_update() + else: + # PS: Run this in our own thread pool once we have + # future support? + yield from self.hass.loop.run_in_executor(None, self.update) state = STATE_UNKNOWN if self.state is None else str(self.state) attr = self.state_attributes or {} @@ -192,7 +217,7 @@ class Entity(object): # Could not convert state to float pass - return self.hass.states.set( + self.hass.states.async_set( self.entity_id, state, attr, self.force_update) def remove(self) -> None: diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 7331525c052..e27f711afda 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -18,6 +18,28 @@ def track_state_change(hass, entity_ids, action, from_state=None, Returns a function that can be called to remove the listener. """ + async_unsub = run_callback_threadsafe( + hass.loop, async_track_state_change, hass, entity_ids, action, + from_state, to_state).result() + + def remove(): + """Remove listener.""" + run_callback_threadsafe(hass.loop, async_unsub).result() + + return remove + + +def async_track_state_change(hass, entity_ids, action, from_state=None, + to_state=None): + """Track specific state changes. + + entity_ids, from_state and to_state can be string or list. + Use list to match multiple. + + Returns a function that can be called to remove the listener. + + Must be run within the event loop. + """ from_state = _process_state_match(from_state) to_state = _process_state_match(to_state) @@ -52,7 +74,7 @@ def track_state_change(hass, entity_ids, action, from_state=None, event.data.get('old_state'), event.data.get('new_state')) - return hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener) + return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) def track_point_in_time(hass, action, point_in_time): @@ -69,6 +91,19 @@ def track_point_in_time(hass, action, point_in_time): def track_point_in_utc_time(hass, action, point_in_time): + """Add a listener that fires once after a specific point in UTC time.""" + async_unsub = run_callback_threadsafe( + hass.loop, async_track_point_in_utc_time, hass, action, point_in_time + ).result() + + def remove(): + """Remove listener.""" + run_callback_threadsafe(hass.loop, async_unsub).result() + + return remove + + +def async_track_point_in_utc_time(hass, action, point_in_time): """Add a listener that fires once after a specific point in UTC time.""" # Ensure point_in_time is UTC point_in_time = dt_util.as_utc(point_in_time) @@ -88,20 +123,14 @@ def track_point_in_utc_time(hass, action, point_in_time): # listener gets lined up twice to be executed. This will make # sure the second time it does nothing. point_in_time_listener.run = True - async_remove() + async_unsub() hass.async_add_job(action, now) - future = run_callback_threadsafe( - hass.loop, hass.bus.async_listen, EVENT_TIME_CHANGED, - point_in_time_listener) - async_remove = future.result() + async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, + point_in_time_listener) - def remove(): - """Remove listener.""" - run_callback_threadsafe(hass.loop, async_remove).result() - - return remove + return async_unsub def track_sunrise(hass, action, offset=None): @@ -118,19 +147,21 @@ def track_sunrise(hass, action, offset=None): return next_time + @asyncio.coroutine def sunrise_automation_listener(now): """Called when it's time for action.""" nonlocal remove - remove = track_point_in_utc_time(hass, sunrise_automation_listener, - next_rise()) - action() + remove = async_track_point_in_utc_time( + hass, sunrise_automation_listener, next_rise()) + hass.async_add_job(action) - remove = track_point_in_utc_time(hass, sunrise_automation_listener, - next_rise()) + remove = run_callback_threadsafe( + hass.loop, async_track_point_in_utc_time, hass, + sunrise_automation_listener, next_rise()).result() def remove_listener(): - """Remove sunrise listener.""" - remove() + """Remove sunset listener.""" + run_callback_threadsafe(hass.loop, remove).result() return remove_listener @@ -149,19 +180,21 @@ def track_sunset(hass, action, offset=None): return next_time + @asyncio.coroutine def sunset_automation_listener(now): """Called when it's time for action.""" nonlocal remove - remove = track_point_in_utc_time(hass, sunset_automation_listener, - next_set()) - action() + remove = async_track_point_in_utc_time( + hass, sunset_automation_listener, next_set()) + hass.async_add_job(action) - remove = track_point_in_utc_time(hass, sunset_automation_listener, - next_set()) + remove = run_callback_threadsafe( + hass.loop, async_track_point_in_utc_time, hass, + sunset_automation_listener, next_set()).result() def remove_listener(): """Remove sunset listener.""" - remove() + run_callback_threadsafe(hass.loop, remove).result() return remove_listener diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index d8005858a1e..6193d7f9ab8 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -149,8 +149,8 @@ class Template(object): global_vars = ENV.make_globals({ 'closest': location_methods.closest, 'distance': location_methods.distance, - 'is_state': self.hass.states.async_is_state, - 'is_state_attr': self.hass.states.async_is_state_attr, + 'is_state': self.hass.states.is_state, + 'is_state_attr': self.hass.states.is_state_attr, 'states': AllStates(self.hass), }) diff --git a/tests/components/test_init.py b/tests/components/test_init.py index 62467c14a2f..76878432ecd 100644 --- a/tests/components/test_init.py +++ b/tests/components/test_init.py @@ -77,7 +77,8 @@ class TestComponentsCore(unittest.TestCase): service_call = ha.ServiceCall('homeassistant', 'turn_on', { 'entity_id': ['light.test', 'sensor.bla', 'light.bla'] }) - self.hass.services._services['homeassistant']['turn_on'](service_call) + service = self.hass.services._services['homeassistant']['turn_on'] + service.func(service_call) self.assertEqual(2, mock_call.call_count) self.assertEqual( diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index a2cbd7094ca..539622d9296 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -1,7 +1,8 @@ """The tests for the logbook component.""" # pylint: disable=protected-access,too-many-public-methods -import unittest from datetime import timedelta +import unittest +from unittest.mock import patch from homeassistant.components import sun import homeassistant.core as ha @@ -18,13 +19,17 @@ from tests.common import mock_http_component, get_test_home_assistant class TestComponentLogbook(unittest.TestCase): """Test the History component.""" - EMPTY_CONFIG = logbook.CONFIG_SCHEMA({ha.DOMAIN: {}, logbook.DOMAIN: {}}) + EMPTY_CONFIG = logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}}) def setUp(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() mock_http_component(self.hass) - assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG) + self.hass.config.components += ['frontend', 'recorder', 'api'] + with patch('homeassistant.components.logbook.' + 'register_built_in_panel'): + assert setup_component(self.hass, logbook.DOMAIN, + self.EMPTY_CONFIG) def tearDown(self): """Stop everything that was started.""" @@ -44,7 +49,6 @@ class TestComponentLogbook(unittest.TestCase): logbook.ATTR_DOMAIN: 'switch', logbook.ATTR_ENTITY_ID: 'switch.test_switch' }, True) - self.hass.block_till_done() self.assertEqual(1, len(calls)) last_call = calls[-1] @@ -65,7 +69,6 @@ class TestComponentLogbook(unittest.TestCase): self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) self.hass.services.call(logbook.DOMAIN, 'log', {}, True) - self.hass.block_till_done() self.assertEqual(0, len(calls)) diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 593e8b433c0..81ef17ff0fd 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -1,6 +1,9 @@ """Test the entity helper.""" # pylint: disable=protected-access,too-many-public-methods -import unittest +import asyncio +from unittest.mock import MagicMock + +import pytest import homeassistant.helpers.entity as entity from homeassistant.const import ATTR_HIDDEN @@ -8,26 +11,75 @@ from homeassistant.const import ATTR_HIDDEN from tests.common import get_test_home_assistant -class TestHelpersEntity(unittest.TestCase): +def test_generate_entity_id_requires_hass_or_ids(): + """Ensure we require at least hass or current ids.""" + fmt = 'test.{}' + with pytest.raises(ValueError): + entity.generate_entity_id(fmt, 'hello world') + + +def test_generate_entity_id_given_keys(): + """Test generating an entity id given current ids.""" + fmt = 'test.{}' + assert entity.generate_entity_id( + fmt, 'overwrite hidden true', current_ids=[ + 'test.overwrite_hidden_true']) == 'test.overwrite_hidden_true_2' + assert entity.generate_entity_id( + fmt, 'overwrite hidden true', current_ids=[ + 'test.another_entity']) == 'test.overwrite_hidden_true' + + +def test_async_update_support(event_loop): + """Test async update getting called.""" + sync_update = [] + async_update = [] + + class AsyncEntity(entity.Entity): + hass = MagicMock() + entity_id = 'sensor.test' + + def update(self): + sync_update.append([1]) + + ent = AsyncEntity() + ent.hass.loop = event_loop + + @asyncio.coroutine + def test(): + yield from ent.async_update_ha_state(True) + + event_loop.run_until_complete(test()) + + assert len(sync_update) == 1 + assert len(async_update) == 0 + + ent.async_update = lambda: async_update.append(1) + + event_loop.run_until_complete(test()) + + assert len(sync_update) == 1 + assert len(async_update) == 1 + + +class TestHelpersEntity(object): """Test homeassistant.helpers.entity module.""" - def setUp(self): # pylint: disable=invalid-name + def setup_method(self, method): """Setup things to be run when tests are started.""" self.entity = entity.Entity() self.entity.entity_id = 'test.overwrite_hidden_true' self.hass = self.entity.hass = get_test_home_assistant() self.entity.update_ha_state() - def tearDown(self): # pylint: disable=invalid-name + def teardown_method(self, method): """Stop everything that was started.""" - self.hass.stop() entity.set_customize({}) + self.hass.stop() def test_default_hidden_not_in_attributes(self): """Test that the default hidden property is set to False.""" - self.assertNotIn( - ATTR_HIDDEN, - self.hass.states.get(self.entity.entity_id).attributes) + assert ATTR_HIDDEN not in self.hass.states.get( + self.entity.entity_id).attributes def test_overwriting_hidden_property_to_true(self): """Test we can overwrite hidden property to True.""" @@ -35,31 +87,11 @@ class TestHelpersEntity(unittest.TestCase): self.entity.update_ha_state() state = self.hass.states.get(self.entity.entity_id) - self.assertTrue(state.attributes.get(ATTR_HIDDEN)) - - def test_generate_entity_id_requires_hass_or_ids(self): - """Ensure we require at least hass or current ids.""" - fmt = 'test.{}' - with self.assertRaises(ValueError): - entity.generate_entity_id(fmt, 'hello world') + assert state.attributes.get(ATTR_HIDDEN) def test_generate_entity_id_given_hass(self): """Test generating an entity id given hass object.""" fmt = 'test.{}' - self.assertEqual( - 'test.overwrite_hidden_true_2', - entity.generate_entity_id(fmt, 'overwrite hidden true', - hass=self.hass)) - - def test_generate_entity_id_given_keys(self): - """Test generating an entity id given current ids.""" - fmt = 'test.{}' - self.assertEqual( - 'test.overwrite_hidden_true_2', - entity.generate_entity_id( - fmt, 'overwrite hidden true', - current_ids=['test.overwrite_hidden_true'])) - self.assertEqual( - 'test.overwrite_hidden_true', - entity.generate_entity_id(fmt, 'overwrite hidden true', - current_ids=['test.another_entity'])) + assert entity.generate_entity_id( + fmt, 'overwrite hidden true', + hass=self.hass) == 'test.overwrite_hidden_true_2' diff --git a/tests/test_core.py b/tests/test_core.py index 9b57f07e9e6..9fa742985c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,7 @@ """Test to verify that Home Assistant core works.""" # pylint: disable=protected-access,too-many-public-methods # pylint: disable=too-few-public-methods +import asyncio import os import signal import unittest @@ -362,7 +363,6 @@ class TestServiceRegistry(unittest.TestCase): self.hass = get_test_home_assistant() self.services = self.hass.services self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None) - self.hass.block_till_done() def tearDown(self): # pylint: disable=invalid-name """Stop down stuff we started.""" @@ -387,8 +387,13 @@ class TestServiceRegistry(unittest.TestCase): def test_call_with_blocking_done_in_time(self): """Test call with blocking.""" calls = [] + + def service_handler(call): + """Service handler.""" + calls.append(call) + self.services.register("test_domain", "register_calls", - lambda x: calls.append(1)) + service_handler) self.assertTrue( self.services.call('test_domain', 'REGISTER_CALLS', blocking=True)) @@ -404,6 +409,22 @@ class TestServiceRegistry(unittest.TestCase): finally: ha.SERVICE_CALL_LIMIT = prior + def test_async_service(self): + """Test registering and calling an async service.""" + calls = [] + + @asyncio.coroutine + def service_handler(call): + """Service handler coroutine.""" + calls.append(call) + + self.services.register('test_domain', 'register_calls', + service_handler) + self.assertTrue( + self.services.call('test_domain', 'REGISTER_CALLS', blocking=True)) + self.hass.block_till_done() + self.assertEqual(1, len(calls)) + class TestConfig(unittest.TestCase): """Test configuration methods."""