Merge pull request #3627 from home-assistant/async-entity-update

Add async updates to entities
This commit is contained in:
Paulus Schoutsen 2016-10-01 16:20:48 -07:00 committed by GitHub
commit 996d7cf1cd
28 changed files with 477 additions and 466 deletions

View File

@ -4,6 +4,7 @@ Allow to setup simple automation rules via the config file.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/automation/ https://home-assistant.io/components/automation/
""" """
import asyncio
from functools import partial from functools import partial
import logging import logging
import os import os
@ -23,6 +24,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.loader import get_platform from homeassistant.loader import get_platform
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
DOMAIN = 'automation' DOMAIN = 'automation'
ENTITY_ID_FORMAT = DOMAIN + '.{}' ENTITY_ID_FORMAT = DOMAIN + '.{}'
@ -44,9 +46,6 @@ CONDITION_TYPE_OR = 'or'
DEFAULT_CONDITION_TYPE = CONDITION_TYPE_AND DEFAULT_CONDITION_TYPE = CONDITION_TYPE_AND
DEFAULT_HIDE_ENTITY = False DEFAULT_HIDE_ENTITY = False
METHOD_TRIGGER = 'trigger'
METHOD_IF_ACTION = 'if_action'
ATTR_LAST_TRIGGERED = 'last_triggered' ATTR_LAST_TRIGGERED = 'last_triggered'
ATTR_VARIABLES = 'variables' ATTR_VARIABLES = 'variables'
SERVICE_TRIGGER = 'trigger' SERVICE_TRIGGER = 'trigger'
@ -55,21 +54,14 @@ SERVICE_RELOAD = 'reload'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def _platform_validator(method, schema): def _platform_validator(config):
"""Generate platform validator for different steps.""" """Validate it is a valid platform."""
def validator(config): platform = get_platform(DOMAIN, config[CONF_PLATFORM])
"""Validate it is a valid platform."""
platform = get_platform(DOMAIN, config[CONF_PLATFORM])
if not hasattr(platform, method): if not hasattr(platform, 'TRIGGER_SCHEMA'):
raise vol.Invalid('invalid method platform') return config
if not hasattr(platform, schema): return getattr(platform, 'TRIGGER_SCHEMA')(config)
return config
return getattr(platform, schema)(config)
return validator
_TRIGGER_SCHEMA = vol.All( _TRIGGER_SCHEMA = vol.All(
cv.ensure_list, cv.ensure_list,
@ -78,33 +70,17 @@ _TRIGGER_SCHEMA = vol.All(
vol.Schema({ vol.Schema({
vol.Required(CONF_PLATFORM): cv.platform_validator(DOMAIN) vol.Required(CONF_PLATFORM): cv.platform_validator(DOMAIN)
}, extra=vol.ALLOW_EXTRA), }, extra=vol.ALLOW_EXTRA),
_platform_validator(METHOD_TRIGGER, 'TRIGGER_SCHEMA') _platform_validator
), ),
] ]
) )
_CONDITION_SCHEMA = vol.Any( _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
CONDITION_USE_TRIGGER_VALUES,
vol.All(
cv.ensure_list,
[
vol.All(
vol.Schema({
CONF_PLATFORM: str,
CONF_CONDITION: str,
}, extra=vol.ALLOW_EXTRA),
cv.has_at_least_one_key(CONF_PLATFORM, CONF_CONDITION),
),
]
)
)
PLATFORM_SCHEMA = vol.Schema({ PLATFORM_SCHEMA = vol.Schema({
CONF_ALIAS: cv.string, CONF_ALIAS: cv.string,
vol.Optional(CONF_HIDE_ENTITY, default=DEFAULT_HIDE_ENTITY): cv.boolean, vol.Optional(CONF_HIDE_ENTITY, default=DEFAULT_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA, vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA,
vol.Required(CONF_CONDITION_TYPE, default=DEFAULT_CONDITION_TYPE):
vol.All(vol.Lower, vol.Any(CONDITION_TYPE_AND, CONDITION_TYPE_OR)),
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
}) })
@ -165,7 +141,8 @@ def setup(hass, config):
"""Setup the automation.""" """Setup the automation."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
success = _process_config(hass, config, component) success = run_coroutine_threadsafe(
_async_process_config(hass, config, component), hass.loop).result()
if not success: if not success:
return False return False
@ -173,22 +150,28 @@ def setup(hass, config):
descriptions = conf_util.load_yaml_config_file( descriptions = conf_util.load_yaml_config_file(
os.path.join(os.path.dirname(__file__), 'services.yaml')) os.path.join(os.path.dirname(__file__), 'services.yaml'))
@asyncio.coroutine
def trigger_service_handler(service_call): def trigger_service_handler(service_call):
"""Handle automation triggers.""" """Handle automation triggers."""
for entity in component.extract_from_service(service_call): for entity in component.extract_from_service(service_call):
entity.trigger(service_call.data.get(ATTR_VARIABLES)) yield from entity.async_trigger(
service_call.data.get(ATTR_VARIABLES))
@asyncio.coroutine
def service_handler(service_call): def service_handler(service_call):
"""Handle automation service calls.""" """Handle automation service calls."""
method = 'async_{}'.format(service_call.service)
for entity in component.extract_from_service(service_call): for entity in component.extract_from_service(service_call):
getattr(entity, service_call.service)() yield from getattr(entity, method)()
@asyncio.coroutine
def reload_service_handler(service_call): def reload_service_handler(service_call):
"""Remove all automations and load new ones from config.""" """Remove all automations and load new ones from config."""
conf = component.prepare_reload() conf = yield from hass.loop.run_in_executor(
None, component.prepare_reload)
if conf is None: if conf is None:
return return
_process_config(hass, conf, component) yield from _async_process_config(hass, conf, component)
hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler, hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
descriptions.get(SERVICE_TRIGGER), descriptions.get(SERVICE_TRIGGER),
@ -209,15 +192,17 @@ def setup(hass, config):
class AutomationEntity(ToggleEntity): class AutomationEntity(ToggleEntity):
"""Entity to show status of entity.""" """Entity to show status of entity."""
# pylint: disable=abstract-method
# pylint: disable=too-many-arguments, too-many-instance-attributes # pylint: disable=too-many-arguments, too-many-instance-attributes
def __init__(self, name, attach_triggers, cond_func, action, hidden): def __init__(self, name, async_attach_triggers, cond_func, async_action,
hidden):
"""Initialize an automation entity.""" """Initialize an automation entity."""
self._name = name self._name = name
self._attach_triggers = attach_triggers self._async_attach_triggers = async_attach_triggers
self._detach_triggers = attach_triggers(self.trigger) self._async_detach_triggers = None
self._cond_func = cond_func self._cond_func = cond_func
self._action = action self._async_action = async_action
self._enabled = True self._enabled = False
self._last_triggered = None self._last_triggered = None
self._hidden = hidden self._hidden = hidden
@ -248,41 +233,60 @@ class AutomationEntity(ToggleEntity):
"""Return True if entity is on.""" """Return True if entity is on."""
return self._enabled return self._enabled
def turn_on(self, **kwargs) -> None: @asyncio.coroutine
"""Turn the entity on.""" def async_turn_on(self, **kwargs) -> None:
if self._enabled: """Turn the entity on and update the state."""
return yield from self.async_enable()
yield from self.async_update_ha_state()
self._detach_triggers = self._attach_triggers(self.trigger) @asyncio.coroutine
self._enabled = True def async_turn_off(self, **kwargs) -> None:
self.update_ha_state()
def turn_off(self, **kwargs) -> None:
"""Turn the entity off.""" """Turn the entity off."""
if not self._enabled: if not self._enabled:
return return
self._detach_triggers() self._async_detach_triggers()
self._detach_triggers = None self._async_detach_triggers = None
self._enabled = False self._enabled = False
self.update_ha_state() yield from self.async_update_ha_state()
def trigger(self, variables): @asyncio.coroutine
def async_toggle(self):
"""Toggle the state of the entity."""
if self._enabled:
yield from self.async_turn_off()
else:
yield from self.async_turn_on()
@asyncio.coroutine
def async_trigger(self, variables):
"""Trigger automation.""" """Trigger automation."""
if self._cond_func(variables): if self._cond_func(variables):
self._action(variables) yield from self._async_action(variables)
self._last_triggered = utcnow() self._last_triggered = utcnow()
self.update_ha_state() yield from self.async_update_ha_state()
def remove(self): def remove(self):
"""Remove automation from HASS.""" """Remove automation from HASS."""
self.turn_off() run_coroutine_threadsafe(self.async_turn_off(),
self.hass.loop).result()
super().remove() super().remove()
@asyncio.coroutine
def async_enable(self):
"""Enable this automation entity."""
if self._enabled:
return
def _process_config(hass, config, component): self._async_detach_triggers = yield from self._async_attach_triggers(
self.async_trigger)
self._enabled = True
@asyncio.coroutine
def _async_process_config(hass, config, component):
"""Process config and add automations.""" """Process config and add automations."""
success = False entities = []
for config_key in extract_domain_configs(config, DOMAIN): for config_key in extract_domain_configs(config, DOMAIN):
conf = config[config_key] conf = config[config_key]
@ -293,10 +297,11 @@ def _process_config(hass, config, component):
hidden = config_block[CONF_HIDE_ENTITY] hidden = config_block[CONF_HIDE_ENTITY]
action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) action = _async_get_action(hass, config_block.get(CONF_ACTION, {}),
name)
if CONF_CONDITION in config_block: if CONF_CONDITION in config_block:
cond_func = _process_if(hass, config, config_block) cond_func = _async_process_if(hass, config, config_block)
if cond_func is None: if cond_func is None:
continue continue
@ -305,101 +310,73 @@ def _process_config(hass, config, component):
"""Condition will always pass.""" """Condition will always pass."""
return True return True
attach_triggers = partial(_process_trigger, hass, config, async_attach_triggers = partial(
config_block.get(CONF_TRIGGER, []), name) _async_process_trigger, hass, config,
entity = AutomationEntity(name, attach_triggers, cond_func, action, config_block.get(CONF_TRIGGER, []), name)
hidden) entity = AutomationEntity(name, async_attach_triggers, cond_func,
component.add_entities((entity,)) action, hidden)
success = True yield from entity.async_enable()
entities.append(entity)
return success yield from hass.loop.run_in_executor(
None, component.add_entities, entities)
return len(entities) > 0
def _get_action(hass, config, name): def _async_get_action(hass, config, name):
"""Return an action based on a configuration.""" """Return an action based on a configuration."""
script_obj = script.Script(hass, config, name) script_obj = script.Script(hass, config, name)
@asyncio.coroutine
def action(variables=None): def action(variables=None):
"""Action to be executed.""" """Action to be executed."""
_LOGGER.info('Executing %s', name) _LOGGER.info('Executing %s', name)
logbook.log_entry(hass, name, 'has been triggered', DOMAIN) logbook.async_log_entry(hass, name, 'has been triggered', DOMAIN)
script_obj.run(variables) yield from script_obj.async_run(variables)
return action return action
def _process_if(hass, config, p_config): def _async_process_if(hass, config, p_config):
"""Process if checks.""" """Process if checks."""
cond_type = p_config.get(CONF_CONDITION_TYPE,
DEFAULT_CONDITION_TYPE).lower()
# Deprecated since 0.19 - 5/5/2016
if cond_type != DEFAULT_CONDITION_TYPE:
_LOGGER.warning('Using condition_type: "or" is deprecated. Please use '
'"condition: or" instead.')
if_configs = p_config.get(CONF_CONDITION) if_configs = p_config.get(CONF_CONDITION)
use_trigger = if_configs == CONDITION_USE_TRIGGER_VALUES
if use_trigger:
if_configs = p_config[CONF_TRIGGER]
checks = [] checks = []
for if_config in if_configs: for if_config in if_configs:
# Deprecated except for used by use_trigger_values
# since 0.19 - 5/5/2016
if CONF_PLATFORM in if_config:
if not use_trigger:
_LOGGER.warning("Please switch your condition configuration "
"to use 'condition' instead of 'platform'.")
if_config = dict(if_config)
if_config[CONF_CONDITION] = if_config.pop(CONF_PLATFORM)
# To support use_trigger_values with state trigger accepting
# multiple entity_ids to monitor.
if_entity_id = if_config.get(ATTR_ENTITY_ID)
if isinstance(if_entity_id, list) and len(if_entity_id) == 1:
if_config[ATTR_ENTITY_ID] = if_entity_id[0]
try: try:
checks.append(condition.from_config(if_config)) checks.append(condition.async_from_config(if_config, False))
except HomeAssistantError as ex: except HomeAssistantError as ex:
# Invalid conditions are allowed if we base it on trigger _LOGGER.warning('Invalid condition: %s', ex)
if use_trigger: return None
_LOGGER.warning('Ignoring invalid condition: %s', ex)
else:
_LOGGER.warning('Invalid condition: %s', ex)
return None
if cond_type == CONDITION_TYPE_AND: def if_action(variables=None):
def if_action(variables=None): """AND all conditions."""
"""AND all conditions.""" return all(check(hass, variables) for check in checks)
return all(check(hass, variables) for check in checks)
else:
def if_action(variables=None):
"""OR all conditions."""
return any(check(hass, variables) for check in checks)
return if_action return if_action
def _process_trigger(hass, config, trigger_configs, name, action): @asyncio.coroutine
def _async_process_trigger(hass, config, trigger_configs, name, action):
"""Setup the triggers.""" """Setup the triggers."""
removes = [] removes = []
for conf in trigger_configs: for conf in trigger_configs:
platform = _resolve_platform(METHOD_TRIGGER, hass, config, platform = yield from hass.loop.run_in_executor(
conf.get(CONF_PLATFORM)) None, prepare_setup_platform, hass, config, DOMAIN,
if platform is None: conf.get(CONF_PLATFORM))
continue
remove = platform.trigger(hass, conf, action) if platform is None:
return None
remove = platform.async_trigger(hass, conf, action)
if not remove: if not remove:
_LOGGER.error("Error setting up rule %s", name) _LOGGER.error("Error setting up trigger %s", name)
continue continue
_LOGGER.info("Initialized rule %s", name) _LOGGER.info("Initialized trigger %s", name)
removes.append(remove) removes.append(remove)
if not removes: if not removes:
@ -411,17 +388,3 @@ def _process_trigger(hass, config, trigger_configs, name, action):
remove() remove()
return remove_triggers return remove_triggers
def _resolve_platform(method, hass, config, platform):
"""Find the automation platform."""
if platform is None:
return None
platform = prepare_setup_platform(hass, config, DOMAIN, platform)
if platform is None or not hasattr(platform, method):
_LOGGER.error("Unknown automation platform specified for %s: %s",
method, platform)
return None
return platform

View File

@ -24,7 +24,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for events based on configuration.""" """Listen for events based on configuration."""
event_type = config.get(CONF_EVENT_TYPE) event_type = config.get(CONF_EVENT_TYPE)
event_data = config.get(CONF_EVENT_DATA) event_data = config.get(CONF_EVENT_DATA)
@ -41,4 +41,4 @@ def trigger(hass, config, action):
}, },
}) })
return hass.bus.listen(event_type, handle_event) return hass.bus.async_listen(event_type, handle_event)

View File

@ -22,7 +22,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
topic = config.get(CONF_TOPIC) topic = config.get(CONF_TOPIC)
payload = config.get(CONF_PAYLOAD) payload = config.get(CONF_PAYLOAD)
@ -40,4 +40,4 @@ def trigger(hass, config, action):
} }
}) })
return mqtt.subscribe(hass, topic, mqtt_automation_listener) return mqtt.async_subscribe(hass, topic, mqtt_automation_listener)

View File

@ -12,7 +12,7 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID, CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_ENTITY_ID,
CONF_BELOW, CONF_ABOVE) CONF_BELOW, CONF_ABOVE)
from homeassistant.helpers.event import track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers import condition, config_validation as cv from homeassistant.helpers import condition, config_validation as cv
TRIGGER_SCHEMA = vol.All(vol.Schema({ TRIGGER_SCHEMA = vol.All(vol.Schema({
@ -26,7 +26,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
below = config.get(CONF_BELOW) below = config.get(CONF_BELOW)
@ -66,4 +66,4 @@ def trigger(hass, config, action):
hass.async_add_job(action, variables) hass.async_add_job(action, variables)
return track_state_change(hass, entity_id, state_automation_listener) return async_track_state_change(hass, entity_id, state_automation_listener)

View File

@ -12,7 +12,6 @@ from homeassistant.const import MATCH_ALL, CONF_PLATFORM
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
async_track_state_change, async_track_point_in_utc_time) async_track_state_change, async_track_point_in_utc_time)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_callback_threadsafe
CONF_ENTITY_ID = "entity_id" CONF_ENTITY_ID = "entity_id"
CONF_FROM = "from" CONF_FROM = "from"
@ -35,7 +34,7 @@ TRIGGER_SCHEMA = vol.All(
) )
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
from_state = config.get(CONF_FROM, MATCH_ALL) from_state = config.get(CONF_FROM, MATCH_ALL)
@ -98,8 +97,4 @@ def trigger(hass, config, action):
if async_remove_state_for_listener is not None: if async_remove_state_for_listener is not None:
async_remove_state_for_listener() async_remove_state_for_listener()
def remove(): return async_remove
"""Remove state listeners."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove

View File

@ -12,7 +12,7 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE) CONF_EVENT, CONF_OFFSET, CONF_PLATFORM, SUN_EVENT_SUNRISE)
from homeassistant.helpers.event import track_sunrise, track_sunset from homeassistant.helpers.event import async_track_sunrise, async_track_sunset
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
DEPENDENCIES = ['sun'] DEPENDENCIES = ['sun']
@ -26,7 +26,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for events based on configuration.""" """Listen for events based on configuration."""
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)
offset = config.get(CONF_OFFSET) offset = config.get(CONF_OFFSET)
@ -44,6 +44,6 @@ def trigger(hass, config, action):
# Do something to call action # Do something to call action
if event == SUN_EVENT_SUNRISE: if event == SUN_EVENT_SUNRISE:
return track_sunrise(hass, call_action, offset) return async_track_sunrise(hass, call_action, offset)
else: else:
return track_sunset(hass, call_action, offset) return async_track_sunset(hass, call_action, offset)

View File

@ -11,7 +11,7 @@ import voluptuous as vol
from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM
from homeassistant.helpers import condition from homeassistant.helpers import condition
from homeassistant.helpers.event import track_state_change from homeassistant.helpers.event import async_track_state_change
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -23,7 +23,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({
}) })
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
value_template = config.get(CONF_VALUE_TEMPLATE) value_template = config.get(CONF_VALUE_TEMPLATE)
value_template.hass = hass value_template.hass = hass
@ -51,5 +51,5 @@ def trigger(hass, config, action):
elif not template_result: elif not template_result:
already_triggered = False already_triggered = False
return track_state_change(hass, value_template.extract_entities(), return async_track_state_change(hass, value_template.extract_entities(),
state_changed_listener) state_changed_listener)

View File

@ -11,7 +11,7 @@ import voluptuous as vol
from homeassistant.const import CONF_AFTER, CONF_PLATFORM from homeassistant.const import CONF_AFTER, CONF_PLATFORM
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.event import track_time_change from homeassistant.helpers.event import async_track_time_change
CONF_HOURS = "hours" CONF_HOURS = "hours"
CONF_MINUTES = "minutes" CONF_MINUTES = "minutes"
@ -29,7 +29,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
CONF_SECONDS, CONF_AFTER)) CONF_SECONDS, CONF_AFTER))
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
if CONF_AFTER in config: if CONF_AFTER in config:
after = config.get(CONF_AFTER) after = config.get(CONF_AFTER)
@ -49,5 +49,5 @@ def trigger(hass, config, action):
}, },
}) })
return track_time_change(hass, time_automation_listener, return async_track_time_change(hass, time_automation_listener,
hour=hours, minute=minutes, second=seconds) hour=hours, minute=minutes, second=seconds)

View File

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM) CONF_EVENT, CONF_ENTITY_ID, CONF_ZONE, MATCH_ALL, CONF_PLATFORM)
from homeassistant.helpers.event import track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers import ( from homeassistant.helpers import (
condition, config_validation as cv, location) condition, config_validation as cv, location)
@ -26,7 +26,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
def trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
zone_entity_id = config.get(CONF_ZONE) zone_entity_id = config.get(CONF_ZONE)
@ -60,5 +60,5 @@ def trigger(hass, config, action):
}, },
}) })
return track_state_change(hass, entity_id, zone_automation_listener, return async_track_state_change(hass, entity_id, zone_automation_listener,
MATCH_ALL, MATCH_ALL) MATCH_ALL, MATCH_ALL)

View File

@ -4,6 +4,7 @@ Support for exposing a templated binary sensor.
For more details about this platform, please refer to the documentation at For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/binary_sensor.template/ https://home-assistant.io/components/binary_sensor.template/
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -81,9 +82,10 @@ class BinarySensorTemplate(BinarySensorDevice):
self.update() self.update()
@asyncio.coroutine
def template_bsensor_state_listener(entity, old_state, new_state): def template_bsensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
self.update_ha_state(True) yield from self.async_update_ha_state(True)
track_state_change(hass, entity_ids, template_bsensor_state_listener) track_state_change(hass, entity_ids, template_bsensor_state_listener)
@ -107,10 +109,11 @@ class BinarySensorTemplate(BinarySensorDevice):
"""No polling needed.""" """No polling needed."""
return False return False
def update(self): @asyncio.coroutine
def async_update(self):
"""Get the latest data and update the state.""" """Get the latest data and update the state."""
try: try:
self._state = self._template.render().lower() == 'true' self._state = self._template.async_render().lower() == 'true'
except TemplateError as ex: except TemplateError as ex:
if ex.args and ex.args[0].startswith( if ex.args and ex.args[0].startswith(
"UndefinedError: 'None' has no attribute"): "UndefinedError: 'None' has no attribute"):

View File

@ -21,6 +21,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.const import ( from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
CONF_PLATFORM, CONF_SCAN_INTERVAL, CONF_VALUE_TEMPLATE) CONF_PLATFORM, CONF_SCAN_INTERVAL, CONF_VALUE_TEMPLATE)
from homeassistant.util.async import run_callback_threadsafe
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -165,6 +166,18 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
def subscribe(hass, topic, callback, qos=DEFAULT_QOS): def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
async_remove = run_callback_threadsafe(
hass.loop, async_subscribe, hass, topic, callback, qos).result()
def remove_mqtt():
"""Remove MQTT subscription."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove_mqtt
def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic.""" """Subscribe to an MQTT topic."""
@asyncio.coroutine @asyncio.coroutine
def mqtt_topic_subscriber(event): def mqtt_topic_subscriber(event):
@ -181,13 +194,13 @@ def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
priority=JobPriority.EVENT_CALLBACK) priority=JobPriority.EVENT_CALLBACK)
remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
mqtt_topic_subscriber) mqtt_topic_subscriber)
# Future: track subscriber count and unsubscribe in remove # Future: track subscriber count and unsubscribe in remove
MQTT_CLIENT.subscribe(topic, qos) MQTT_CLIENT.subscribe(topic, qos)
return remove return async_remove
def _setup_server(hass, config): def _setup_server(hass, config):

View File

@ -4,6 +4,7 @@ Allows the creation of a sensor that breaks out state_attributes.
For more details about this platform, please refer to the documentation at For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/sensor.template/ https://home-assistant.io/components/sensor.template/
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -78,9 +79,10 @@ class SensorTemplate(Entity):
self.update() self.update()
@asyncio.coroutine
def template_sensor_state_listener(entity, old_state, new_state): def template_sensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
self.update_ha_state(True) yield from self.async_update_ha_state(True)
track_state_change(hass, entity_ids, template_sensor_state_listener) track_state_change(hass, entity_ids, template_sensor_state_listener)
@ -104,10 +106,11 @@ class SensorTemplate(Entity):
"""No polling needed.""" """No polling needed."""
return False return False
def update(self): @asyncio.coroutine
def async_update(self):
"""Get the latest data and update the states.""" """Get the latest data and update the states."""
try: try:
self._state = self._template.render() self._state = self._template.async_render()
except TemplateError as ex: except TemplateError as ex:
if ex.args and ex.args[0].startswith( if ex.args and ex.args[0].startswith(
"UndefinedError: 'None' has no attribute"): "UndefinedError: 'None' has no attribute"):

View File

@ -4,6 +4,7 @@ Support for switches which integrates with other components.
For more details about this platform, please refer to the documentation at For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/switch.template/ https://home-assistant.io/components/switch.template/
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -87,9 +88,10 @@ class SwitchTemplate(SwitchDevice):
self.update() self.update()
@asyncio.coroutine
def template_switch_state_listener(entity, old_state, new_state): def template_switch_state_listener(entity, old_state, new_state):
"""Called when the target device changes state.""" """Called when the target device changes state."""
self.update_ha_state(True) yield from self.async_update_ha_state(True)
track_state_change(hass, entity_ids, template_switch_state_listener) track_state_change(hass, entity_ids, template_switch_state_listener)
@ -121,10 +123,11 @@ class SwitchTemplate(SwitchDevice):
"""Fire the off action.""" """Fire the off action."""
self._off_script.run() self._off_script.run()
def update(self): @asyncio.coroutine
def async_update(self):
"""Update the state from the template.""" """Update the state from the template."""
try: try:
state = self._template.render().lower() state = self._template.async_render().lower()
if state in _VALID_STATES: if state in _VALID_STATES:
self._state = state in ('true', STATE_ON) self._state = state in ('true', STATE_ON)

View File

@ -122,8 +122,8 @@ class HomeAssistant(object):
def __init__(self, loop=None): def __init__(self, loop=None):
"""Initialize new Home Assistant object.""" """Initialize new Home Assistant object."""
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.executer = ThreadPoolExecutor(max_workers=5) self.executor = ThreadPoolExecutor(max_workers=5)
self.loop.set_default_executor(self.executer) self.loop.set_default_executor(self.executor)
self.pool = pool = create_worker_pool() self.pool = pool = create_worker_pool()
self.bus = EventBus(pool, self.loop) self.bus = EventBus(pool, self.loop)
self.services = ServiceRegistry(self.bus, self.add_job, self.loop) self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
@ -287,7 +287,7 @@ class HomeAssistant(object):
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP) self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
yield from self.loop.run_in_executor(None, self.pool.block_till_done) yield from self.loop.run_in_executor(None, self.pool.block_till_done)
yield from self.loop.run_in_executor(None, self.pool.stop) yield from self.loop.run_in_executor(None, self.pool.stop)
self.executer.shutdown() self.executor.shutdown()
self.state = CoreState.not_running self.state = CoreState.not_running
self.loop.stop() self.loop.stop()

View File

@ -1,5 +1,6 @@
"""Offer reusable conditions.""" """Offer reusable conditions."""
from datetime import timedelta from datetime import timedelta
import functools as ft
import logging import logging
import sys import sys
@ -20,15 +21,44 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.async import run_callback_threadsafe from homeassistant.util.async import run_callback_threadsafe
FROM_CONFIG_FORMAT = '{}_from_config' FROM_CONFIG_FORMAT = '{}_from_config'
ASYNC_FROM_CONFIG_FORMAT = 'async_{}_from_config'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# PyLint does not like the use of _threaded_factory
# pylint: disable=invalid-name
def from_config(config: ConfigType, config_validation: bool=True):
"""Turn a condition configuration into a method.""" def _threaded_factory(async_factory):
factory = getattr( """Helper method to create threaded versions of async factories."""
sys.modules[__name__], @ft.wraps(async_factory)
FROM_CONFIG_FORMAT.format(config.get(CONF_CONDITION)), None) def factory(config, config_validation=True):
"""Threaded factory."""
async_check = async_factory(config, config_validation)
def condition_if(hass, variables=None):
"""Validate condition."""
return run_callback_threadsafe(
hass.loop, async_check, hass, variables,
).result()
return condition_if
return factory
def async_from_config(config: ConfigType, config_validation: bool=True):
"""Turn a condition configuration into a method.
Should be run on the event loop.
"""
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(
sys.modules[__name__],
fmt.format(config.get(CONF_CONDITION)), None)
if factory:
break
if factory is None: if factory is None:
raise HomeAssistantError('Invalid condition "{}" specified {}'.format( raise HomeAssistantError('Invalid condition "{}" specified {}'.format(
@ -37,49 +67,70 @@ def from_config(config: ConfigType, config_validation: bool=True):
return factory(config, config_validation) return factory(config, config_validation)
def and_from_config(config: ConfigType, config_validation: bool=True): from_config = _threaded_factory(async_from_config)
def async_and_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'AND'.""" """Create multi condition matcher using 'AND'."""
if config_validation: if config_validation:
config = cv.AND_CONDITION_SCHEMA(config) config = cv.AND_CONDITION_SCHEMA(config)
checks = [from_config(entry, False) for entry in config['conditions']] checks = None
def if_and_condition(hass: HomeAssistant, def if_and_condition(hass: HomeAssistant,
variables=None) -> bool: variables=None) -> bool:
"""Test and condition.""" """Test and condition."""
for check in checks: nonlocal checks
try:
if checks is None:
checks = [async_from_config(entry, False) for entry
in config['conditions']]
try:
for check in checks:
if not check(hass, variables): if not check(hass, variables):
return False return False
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning('Error during and-condition: %s', ex) _LOGGER.warning('Error during and-condition: %s', ex)
return False return False
return True return True
return if_and_condition return if_and_condition
def or_from_config(config: ConfigType, config_validation: bool=True): and_from_config = _threaded_factory(async_and_from_config)
def async_or_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'OR'.""" """Create multi condition matcher using 'OR'."""
if config_validation: if config_validation:
config = cv.OR_CONDITION_SCHEMA(config) config = cv.OR_CONDITION_SCHEMA(config)
checks = [from_config(entry, False) for entry in config['conditions']] checks = None
def if_or_condition(hass: HomeAssistant, def if_or_condition(hass: HomeAssistant,
variables=None) -> bool: variables=None) -> bool:
"""Test and condition.""" """Test and condition."""
for check in checks: nonlocal checks
try:
if checks is None:
checks = [async_from_config(entry, False) for entry
in config['conditions']]
try:
for check in checks:
if check(hass, variables): if check(hass, variables):
return True return True
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning('Error during or-condition: %s', ex) _LOGGER.warning('Error during or-condition: %s', ex)
return False return False
return if_or_condition return if_or_condition
or_from_config = _threaded_factory(async_or_from_config)
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def numeric_state(hass: HomeAssistant, entity, below=None, above=None, def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None): value_template=None, variables=None):
@ -125,7 +176,7 @@ def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
return True return True
def numeric_state_from_config(config, config_validation=True): def async_numeric_state_from_config(config, config_validation=True):
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
if config_validation: if config_validation:
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config) config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
@ -139,12 +190,15 @@ def numeric_state_from_config(config, config_validation=True):
if value_template is not None: if value_template is not None:
value_template.hass = hass value_template.hass = hass
return numeric_state(hass, entity_id, below, above, value_template, return async_numeric_state(
variables) hass, entity_id, below, above, value_template, variables)
return if_numeric_state return if_numeric_state
numeric_state_from_config = _threaded_factory(async_numeric_state_from_config)
def state(hass, entity, req_state, for_period=None): def state(hass, entity, req_state, for_period=None):
"""Test if state matches requirements.""" """Test if state matches requirements."""
if isinstance(entity, str): if isinstance(entity, str):
@ -235,7 +289,7 @@ def async_template(hass, value_template, variables=None):
return value.lower() == 'true' return value.lower() == 'true'
def template_from_config(config, config_validation=True): def async_template_from_config(config, config_validation=True):
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
if config_validation: if config_validation:
config = cv.TEMPLATE_CONDITION_SCHEMA(config) config = cv.TEMPLATE_CONDITION_SCHEMA(config)
@ -245,11 +299,14 @@ def template_from_config(config, config_validation=True):
"""Validate template based if-condition.""" """Validate template based if-condition."""
value_template.hass = hass value_template.hass = hass
return template(hass, value_template, variables) return async_template(hass, value_template, variables)
return template_if return template_if
template_from_config = _threaded_factory(async_template_from_config)
def time(before=None, after=None, weekday=None): def time(before=None, after=None, weekday=None):
"""Test if local time condition matches. """Test if local time condition matches.

View File

@ -24,13 +24,20 @@ def generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None, current_ids: Optional[List[str]]=None,
hass: Optional[HomeAssistant]=None) -> str: hass: Optional[HomeAssistant]=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs.""" """Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower()
if current_ids is None: if current_ids is None:
if hass is None: if hass is None:
raise ValueError("Missing required parameter currentids or hass") raise ValueError("Missing required parameter currentids or hass")
current_ids = hass.states.entity_ids() current_ids = hass.states.entity_ids()
return async_generate_entity_id(entity_id_format, name, current_ids)
def async_generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower()
return ensure_unique_string( return ensure_unique_string(
entity_id_format.format(slugify(name)), current_ids) entity_id_format.format(slugify(name)), current_ids)
@ -49,6 +56,11 @@ class Entity(object):
# SAFE TO OVERWRITE # SAFE TO OVERWRITE
# The properties and methods here are safe to overwrite when inheriting # The properties and methods here are safe to overwrite when inheriting
# this class. These may be used to customize the behavior of the entity. # this class. These may be used to customize the behavior of the entity.
entity_id = None # type: str
# Owning hass instance. Will be set by EntityComponent
hass = None # type: Optional[HomeAssistant]
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
"""Return True if entity has to be polled for state. """Return True if entity has to be polled for state.
@ -128,18 +140,22 @@ class Entity(object):
return False return False
def update(self): def update(self):
"""Retrieve latest state.""" """Retrieve latest state.
pass
entity_id = None # type: str When not implemented, will forward call to async version if available.
"""
async_update = getattr(self, 'async_update', None)
if async_update is None:
return
run_coroutine_threadsafe(async_update(), self.hass.loop).result()
# DO NOT OVERWRITE # DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they # These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may # are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation. # produce undesirable effects in the entity's operation.
hass = None # type: Optional[HomeAssistant]
def update_ha_state(self, force_refresh=False): def update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity. """Update Home Assistant with current state of entity.
@ -172,7 +188,7 @@ class Entity(object):
if force_refresh: if force_refresh:
if hasattr(self, 'async_update'): if hasattr(self, 'async_update'):
# pylint: disable=no-member # pylint: disable=no-member
self.async_update() yield from self.async_update()
else: else:
# PS: Run this in our own thread pool once we have # PS: Run this in our own thread pool once we have
# future support? # future support?

View File

@ -3,30 +3,36 @@ import asyncio
import functools as ft import functools as ft
from datetime import timedelta from datetime import timedelta
from ..core import HomeAssistant
from ..const import ( from ..const import (
ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from ..util import dt as dt_util from ..util import dt as dt_util
from ..util.async import run_callback_threadsafe from ..util.async import run_callback_threadsafe
# PyLint does not like the use of _threaded_factory
# pylint: disable=invalid-name
def track_state_change(hass, entity_ids, action, from_state=None,
to_state=None):
"""Track specific state changes.
entity_ids, from_state and to_state can be string or list. def _threaded_factory(async_factory):
Use list to match multiple. """Convert an async event helper to a threaded one."""
@ft.wraps(async_factory)
def factory(*args, **kwargs):
"""Call async event helper safely."""
hass = args[0]
Returns a function that can be called to remove the listener. if not isinstance(hass, HomeAssistant):
""" raise TypeError('First parameter needs to be a hass instance')
async_unsub = run_callback_threadsafe(
hass.loop, async_track_state_change, hass, entity_ids, action,
from_state, to_state).result()
def remove(): async_remove = run_callback_threadsafe(
"""Remove listener.""" hass.loop, ft.partial(async_factory, *args, **kwargs)).result()
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove def remove():
"""Threadsafe removal."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove
return factory
def async_track_state_change(hass, entity_ids, action, from_state=None, def async_track_state_change(hass, entity_ids, action, from_state=None,
@ -77,7 +83,10 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
def track_point_in_time(hass, action, point_in_time): track_state_change = _threaded_factory(async_track_state_change)
def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a spefic point in time.""" """Add a listener that fires once after a spefic point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time) utc_point_in_time = dt_util.as_utc(point_in_time)
@ -87,20 +96,11 @@ def track_point_in_time(hass, action, point_in_time):
"""Convert passed in UTC now to local now.""" """Convert passed in UTC now to local now."""
hass.async_add_job(action, dt_util.as_local(utc_now)) hass.async_add_job(action, dt_util.as_local(utc_now))
return track_point_in_utc_time(hass, utc_converter, utc_point_in_time) return async_track_point_in_utc_time(hass, utc_converter,
utc_point_in_time)
def track_point_in_utc_time(hass, action, point_in_time): track_point_in_time = _threaded_factory(async_track_point_in_time)
"""Add a listener that fires once after a specific point in UTC time."""
async_unsub = run_callback_threadsafe(
hass.loop, async_track_point_in_utc_time, hass, action, point_in_time
).result()
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove
def async_track_point_in_utc_time(hass, action, point_in_time): def async_track_point_in_utc_time(hass, action, point_in_time):
@ -133,7 +133,10 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
return async_unsub return async_unsub
def track_sunrise(hass, action, offset=None): track_point_in_utc_time = _threaded_factory(async_track_point_in_utc_time)
def async_track_sunrise(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunrise daily.""" """Add a listener that will fire a specified offset from sunrise daily."""
from homeassistant.components import sun from homeassistant.components import sun
offset = offset or timedelta() offset = offset or timedelta()
@ -147,6 +150,7 @@ def track_sunrise(hass, action, offset=None):
return next_time return next_time
@ft.wraps(action)
@asyncio.coroutine @asyncio.coroutine
def sunrise_automation_listener(now): def sunrise_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
@ -155,18 +159,20 @@ def track_sunrise(hass, action, offset=None):
hass, sunrise_automation_listener, next_rise()) hass, sunrise_automation_listener, next_rise())
hass.async_add_job(action) hass.async_add_job(action)
remove = run_callback_threadsafe( remove = async_track_point_in_utc_time(
hass.loop, async_track_point_in_utc_time, hass, hass, sunrise_automation_listener, next_rise())
sunrise_automation_listener, next_rise()).result()
def remove_listener(): def remove_listener():
"""Remove sunset listener.""" """Remove sunset listener."""
run_callback_threadsafe(hass.loop, remove).result() remove()
return remove_listener return remove_listener
def track_sunset(hass, action, offset=None): track_sunrise = _threaded_factory(async_track_sunrise)
def async_track_sunset(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunset daily.""" """Add a listener that will fire a specified offset from sunset daily."""
from homeassistant.components import sun from homeassistant.components import sun
offset = offset or timedelta() offset = offset or timedelta()
@ -180,6 +186,7 @@ def track_sunset(hass, action, offset=None):
return next_time return next_time
@ft.wraps(action)
@asyncio.coroutine @asyncio.coroutine
def sunset_automation_listener(now): def sunset_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
@ -188,20 +195,23 @@ def track_sunset(hass, action, offset=None):
hass, sunset_automation_listener, next_set()) hass, sunset_automation_listener, next_set())
hass.async_add_job(action) hass.async_add_job(action)
remove = run_callback_threadsafe( remove = async_track_point_in_utc_time(
hass.loop, async_track_point_in_utc_time, hass, hass, sunset_automation_listener, next_set())
sunset_automation_listener, next_set()).result()
def remove_listener(): def remove_listener():
"""Remove sunset listener.""" """Remove sunset listener."""
run_callback_threadsafe(hass.loop, remove).result() remove()
return remove_listener return remove_listener
track_sunset = _threaded_factory(async_track_sunset)
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def track_utc_time_change(hass, action, year=None, month=None, day=None, def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None, local=False): hour=None, minute=None, second=None,
local=False):
"""Add a listener that will fire if time matches a pattern.""" """Add a listener that will fire if time matches a pattern."""
# We do not have to wrap the function with time pattern matching logic # We do not have to wrap the function with time pattern matching logic
# if no pattern given # if no pattern given
@ -211,7 +221,7 @@ def track_utc_time_change(hass, action, year=None, month=None, day=None,
"""Fire every time event that comes in.""" """Fire every time event that comes in."""
action(event.data[ATTR_NOW]) action(event.data[ATTR_NOW])
return hass.bus.listen(EVENT_TIME_CHANGED, time_change_listener) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
pmp = _process_time_match pmp = _process_time_match
year, month, day = pmp(year), pmp(month), pmp(day) year, month, day = pmp(year), pmp(month), pmp(day)
@ -237,15 +247,22 @@ def track_utc_time_change(hass, action, year=None, month=None, day=None,
hass.async_add_job(action, now) hass.async_add_job(action, now)
return hass.bus.listen(EVENT_TIME_CHANGED, pattern_time_change_listener) return hass.bus.async_listen(EVENT_TIME_CHANGED,
pattern_time_change_listener)
track_utc_time_change = _threaded_factory(async_track_utc_time_change)
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def track_time_change(hass, action, year=None, month=None, day=None, def async_track_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None): hour=None, minute=None, second=None):
"""Add a listener that will fire if UTC time matches a pattern.""" """Add a listener that will fire if UTC time matches a pattern."""
return track_utc_time_change(hass, action, year, month, day, hour, minute, return async_track_utc_time_change(hass, action, year, month, day, hour,
second, local=True) minute, second, local=True)
track_time_change = _threaded_factory(async_track_time_change)
def _process_state_match(parameter): def _process_state_match(parameter):

View File

@ -1,6 +1,6 @@
"""Helpers to execute scripts.""" """Helpers to execute scripts."""
import asyncio
import logging import logging
import threading
from itertools import islice from itertools import islice
from typing import Optional, Sequence from typing import Optional, Sequence
@ -10,9 +10,11 @@ from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_CONDITION from homeassistant.const import CONF_CONDITION
from homeassistant.helpers import ( from homeassistant.helpers import (
service, condition, template, config_validation as cv) service, condition, template, config_validation as cv)
from homeassistant.helpers.event import track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as date_util import homeassistant.util.dt as date_util
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -47,8 +49,7 @@ class Script():
self.last_action = None self.last_action = None
self.can_cancel = any(CONF_DELAY in action for action self.can_cancel = any(CONF_DELAY in action for action
in self.sequence) in self.sequence)
self._lock = threading.Lock() self._async_unsub_delay_listener = None
self._unsub_delay_listener = None
self._template_cache = {} self._template_cache = {}
@property @property
@ -56,94 +57,107 @@ class Script():
"""Return true if script is on.""" """Return true if script is on."""
return self._cur != -1 return self._cur != -1
def run(self, variables: Optional[Sequence]=None) -> None: def run(self, variables=None):
"""Run script.""" """Run script."""
with self._lock: run_coroutine_threadsafe(
if self._cur == -1: self.async_run(variables), self.hass.loop).result()
self._log('Running script')
self._cur = 0
# Unregister callback if we were in a delay but turn on is called @asyncio.coroutine
# again. In that case we just continue execution. def async_run(self, variables: Optional[Sequence]=None) -> None:
self._remove_listener() """Run script.
for cur, action in islice(enumerate(self.sequence), self._cur, Returns a coroutine.
None): """
if self._cur == -1:
self._log('Running script')
self._cur = 0
if CONF_DELAY in action: # Unregister callback if we were in a delay but turn on is called
# Call ourselves in the future to continue work # again. In that case we just continue execution.
def script_delay(now): self._async_remove_listener()
"""Called after delay is done."""
self._unsub_delay_listener = None
self.run(variables)
delay = action[CONF_DELAY] for cur, action in islice(enumerate(self.sequence), self._cur,
None):
if isinstance(delay, template.Template): if CONF_DELAY in action:
delay = vol.All( # Call ourselves in the future to continue work
cv.time_period, @asyncio.coroutine
cv.positive_timedelta)( def script_delay(now):
delay.render()) """Called after delay is done."""
self._async_unsub_delay_listener = None
yield from self.async_run(variables)
self._unsub_delay_listener = track_point_in_utc_time( delay = action[CONF_DELAY]
if isinstance(delay, template.Template):
delay = vol.All(
cv.time_period,
cv.positive_timedelta)(
delay.async_render())
self._async_unsub_delay_listener = \
async_track_point_in_utc_time(
self.hass, script_delay, self.hass, script_delay,
date_util.utcnow() + delay) date_util.utcnow() + delay)
self._cur = cur + 1 self._cur = cur + 1
if self._change_listener: self._trigger_change_listener()
self._change_listener() return
return
elif CONF_CONDITION in action: elif CONF_CONDITION in action:
if not self._check_condition(action, variables): if not self._async_check_condition(action, variables):
break break
elif CONF_EVENT in action: elif CONF_EVENT in action:
self._fire_event(action) self._async_fire_event(action)
else: else:
self._call_service(action, variables) yield from self._async_call_service(action, variables)
self._cur = -1 self._cur = -1
self.last_action = None self.last_action = None
if self._change_listener: self._trigger_change_listener()
self._change_listener()
def stop(self) -> None: def stop(self) -> None:
"""Stop running script.""" """Stop running script."""
with self._lock: run_callback_threadsafe(self.hass.loop, self.async_stop).result()
if self._cur == -1:
return
self._cur = -1 def async_stop(self) -> None:
self._remove_listener() """Stop running script."""
if self._change_listener: if self._cur == -1:
self._change_listener() return
def _call_service(self, action, variables): self._cur = -1
self._async_remove_listener()
self._trigger_change_listener()
@asyncio.coroutine
def _async_call_service(self, action, variables):
"""Call the service specified in the action.""" """Call the service specified in the action."""
self.last_action = action.get(CONF_ALIAS, 'call service') self.last_action = action.get(CONF_ALIAS, 'call service')
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
service.call_from_config(self.hass, action, True, variables, yield from service.async_call_from_config(
validate_config=False) self.hass, action, True, variables, validate_config=False)
def _fire_event(self, action): def _async_fire_event(self, action):
"""Fire an event.""" """Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
self.hass.bus.fire(action[CONF_EVENT], action.get(CONF_EVENT_DATA)) self.hass.bus.async_fire(action[CONF_EVENT],
action.get(CONF_EVENT_DATA))
def _check_condition(self, action, variables): def _async_check_condition(self, action, variables):
"""Test if condition is matching.""" """Test if condition is matching."""
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
check = condition.from_config(action, False)(self.hass, variables) check = condition.async_from_config(action, False)(
self.hass, variables)
self._log("Test condition {}: {}".format(self.last_action, check)) self._log("Test condition {}: {}".format(self.last_action, check))
return check return check
def _remove_listener(self): def _async_remove_listener(self):
"""Remove point in time listener, if any.""" """Remove point in time listener, if any."""
if self._unsub_delay_listener: if self._async_unsub_delay_listener:
self._unsub_delay_listener() self._async_unsub_delay_listener()
self._unsub_delay_listener = None self._async_unsub_delay_listener = None
def _log(self, msg): def _log(self, msg):
"""Logger helper.""" """Logger helper."""
@ -151,3 +165,10 @@ class Script():
msg = "Script {}: {}".format(self.name, msg) msg = "Script {}: {}".format(self.name, msg)
_LOGGER.info(msg) _LOGGER.info(msg)
def _trigger_change_listener(self):
"""Trigger the change listener."""
if not self._change_listener:
return
self.hass.async_add_job(self._change_listener)

View File

@ -1,4 +1,5 @@
"""Service calling related helpers.""" """Service calling related helpers."""
import asyncio
import functools import functools
import logging import logging
# pylint: disable=unused-import # pylint: disable=unused-import
@ -11,6 +12,7 @@ from homeassistant.core import HomeAssistant # NOQA
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.loader import get_component from homeassistant.loader import get_component
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
HASS = None # type: Optional[HomeAssistant] HASS = None # type: Optional[HomeAssistant]
@ -37,6 +39,15 @@ def service(domain, service_name):
def call_from_config(hass, config, blocking=False, variables=None, def call_from_config(hass, config, blocking=False, variables=None,
validate_config=True): validate_config=True):
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
run_coroutine_threadsafe(
async_call_from_config(hass, config, blocking, variables,
validate_config), hass.loop).result()
@asyncio.coroutine
def async_call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
"""Call a service based on a config hash."""
if validate_config: if validate_config:
try: try:
config = cv.SERVICE_SCHEMA(config) config = cv.SERVICE_SCHEMA(config)
@ -49,7 +60,8 @@ def call_from_config(hass, config, blocking=False, variables=None,
else: else:
try: try:
config[CONF_SERVICE_TEMPLATE].hass = hass config[CONF_SERVICE_TEMPLATE].hass = hass
domain_service = config[CONF_SERVICE_TEMPLATE].render(variables) domain_service = config[CONF_SERVICE_TEMPLATE].async_render(
variables)
domain_service = cv.service(domain_service) domain_service = cv.service(domain_service)
except TemplateError as ex: except TemplateError as ex:
_LOGGER.error('Error rendering service name template: %s', ex) _LOGGER.error('Error rendering service name template: %s', ex)
@ -71,14 +83,15 @@ def call_from_config(hass, config, blocking=False, variables=None,
return {key: _data_template_creator(item) return {key: _data_template_creator(item)
for key, item in value.items()} for key, item in value.items()}
value.hass = hass value.hass = hass
return value.render(variables) return value.async_render(variables)
service_data.update(_data_template_creator( service_data.update(_data_template_creator(
config[CONF_SERVICE_DATA_TEMPLATE])) config[CONF_SERVICE_DATA_TEMPLATE]))
if CONF_SERVICE_ENTITY_ID in config: if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
hass.services.call(domain, service_name, service_data, blocking) yield from hass.services.async_call(
domain, service_name, service_data, blocking)
def extract_entity_ids(hass, service_call): def extract_entity_ids(hass, service_call):

View File

@ -2,7 +2,7 @@
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from homeassistant.bootstrap import _setup_component from homeassistant.bootstrap import setup_component
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -31,7 +31,7 @@ class TestAutomation(unittest.TestCase):
def test_service_data_not_a_dict(self): def test_service_data_not_a_dict(self):
"""Test service data not dict.""" """Test service data not dict."""
assert not _setup_component(self.hass, automation.DOMAIN, { assert not setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
@ -46,7 +46,7 @@ class TestAutomation(unittest.TestCase):
def test_service_specify_data(self): def test_service_specify_data(self):
"""Test service data.""" """Test service data."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'alias': 'hello', 'alias': 'hello',
'trigger': { 'trigger': {
@ -77,7 +77,7 @@ class TestAutomation(unittest.TestCase):
def test_service_specify_entity_id(self): def test_service_specify_entity_id(self):
"""Test service data.""" """Test service data."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
@ -98,7 +98,7 @@ class TestAutomation(unittest.TestCase):
def test_service_specify_entity_id_list(self): def test_service_specify_entity_id_list(self):
"""Test service data.""" """Test service data."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
@ -119,7 +119,7 @@ class TestAutomation(unittest.TestCase):
def test_two_triggers(self): def test_two_triggers(self):
"""Test triggers.""" """Test triggers."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': [ 'trigger': [
{ {
@ -147,7 +147,7 @@ class TestAutomation(unittest.TestCase):
def test_two_conditions_with_and(self): def test_two_conditions_with_and(self):
"""Test two and conditions.""" """Test two and conditions."""
entity_id = 'test.entity' entity_id = 'test.entity'
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': [ 'trigger': [
{ {
@ -188,123 +188,9 @@ class TestAutomation(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
def test_two_conditions_with_or(self):
"""Test two or conditions."""
entity_id = 'test.entity'
assert _setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': [
{
'platform': 'event',
'event_type': 'test_event',
},
],
'condition_type': 'OR',
'condition': [
{
'platform': 'state',
'entity_id': entity_id,
'state': '200'
},
{
'platform': 'numeric_state',
'entity_id': entity_id,
'below': 150
}
],
'action': {
'service': 'test.automation',
}
}
})
self.hass.states.set(entity_id, 200)
self.hass.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.hass.states.set(entity_id, 100)
self.hass.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(2, len(self.calls))
self.hass.states.set(entity_id, 250)
self.hass.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(2, len(self.calls))
def test_using_trigger_as_condition(self):
"""Test triggers as condition."""
entity_id = 'test.entity'
assert _setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': [
{
'platform': 'state',
'entity_id': entity_id,
'from': '120',
'state': '100'
},
{
'platform': 'numeric_state',
'entity_id': entity_id,
'below': 150
}
],
'condition': 'use_trigger_values',
'action': {
'service': 'test.automation',
}
}
})
self.hass.states.set(entity_id, 100)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.hass.states.set(entity_id, 120)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.hass.states.set(entity_id, 100)
self.hass.block_till_done()
self.assertEqual(2, len(self.calls))
self.hass.states.set(entity_id, 151)
self.hass.block_till_done()
self.assertEqual(2, len(self.calls))
def test_using_trigger_as_condition_with_invalid_condition(self):
"""Event is not a valid condition."""
entity_id = 'test.entity'
self.hass.states.set(entity_id, 100)
assert _setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': [
{
'platform': 'event',
'event_type': 'test_event',
},
{
'platform': 'numeric_state',
'entity_id': entity_id,
'below': 150
}
],
'condition': 'use_trigger_values',
'action': {
'service': 'test.automation',
}
}
})
self.hass.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
def test_automation_list_setting(self): def test_automation_list_setting(self):
"""Event is not a valid condition.""" """Event is not a valid condition."""
self.assertTrue(_setup_component(self.hass, automation.DOMAIN, { self.assertTrue(setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: [{ automation.DOMAIN: [{
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
@ -335,7 +221,7 @@ class TestAutomation(unittest.TestCase):
def test_automation_calling_two_actions(self): def test_automation_calling_two_actions(self):
"""Test if we can call two actions from automation definition.""" """Test if we can call two actions from automation definition."""
self.assertTrue(_setup_component(self.hass, automation.DOMAIN, { self.assertTrue(setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
@ -366,7 +252,7 @@ class TestAutomation(unittest.TestCase):
assert self.hass.states.get(entity_id) is None assert self.hass.states.get(entity_id) is None
assert not automation.is_on(self.hass, entity_id) assert not automation.is_on(self.hass, entity_id)
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'alias': 'hello', 'alias': 'hello',
'trigger': { 'trigger': {
@ -433,7 +319,7 @@ class TestAutomation(unittest.TestCase):
}) })
def test_reload_config_service(self, mock_load_yaml): def test_reload_config_service(self, mock_load_yaml):
"""Test the reload config service.""" """Test the reload config service."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'alias': 'hello', 'alias': 'hello',
'trigger': { 'trigger': {
@ -483,7 +369,7 @@ class TestAutomation(unittest.TestCase):
}) })
def test_reload_config_when_invalid_config(self, mock_load_yaml): def test_reload_config_when_invalid_config(self, mock_load_yaml):
"""Test the reload config service handling invalid config.""" """Test the reload config service handling invalid config."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'alias': 'hello', 'alias': 'hello',
'trigger': { 'trigger': {
@ -517,7 +403,7 @@ class TestAutomation(unittest.TestCase):
def test_reload_config_handles_load_fails(self): def test_reload_config_handles_load_fails(self):
"""Test the reload config service.""" """Test the reload config service."""
assert _setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'alias': 'hello', 'alias': 'hello',
'trigger': { 'trigger': {

View File

@ -499,7 +499,7 @@ class TestAutomationNumericState(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'numeric_state', 'condition': 'numeric_state',
'entity_id': entity_id, 'entity_id': entity_id,
'above': test_state, 'above': test_state,
'below': test_state + 2 'below': test_state + 2

View File

@ -213,7 +213,7 @@ class TestAutomationState(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': [{ 'condition': [{
'platform': 'state', 'condition': 'state',
'entity_id': entity_id, 'entity_id': entity_id,
'state': test_state 'state': test_state
}], }],
@ -360,7 +360,7 @@ class TestAutomationState(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'state', 'condition': 'state',
'entity_id': 'test.entity', 'entity_id': 'test.entity',
'state': 'on', 'state': 'on',
'for': { 'for': {

View File

@ -172,7 +172,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'before': 'sunrise', 'before': 'sunrise',
}, },
'action': { 'action': {
@ -208,7 +208,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'after': 'sunrise', 'after': 'sunrise',
}, },
'action': { 'action': {
@ -244,7 +244,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'before': 'sunrise', 'before': 'sunrise',
'before_offset': '+1:00:00' 'before_offset': '+1:00:00'
}, },
@ -281,7 +281,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'after': 'sunrise', 'after': 'sunrise',
'after_offset': '+1:00:00' 'after_offset': '+1:00:00'
}, },
@ -319,7 +319,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'after': 'sunrise', 'after': 'sunrise',
'before': 'sunset' 'before': 'sunset'
}, },
@ -365,7 +365,7 @@ class TestAutomationSun(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': { 'condition': {
'platform': 'sun', 'condition': 'sun',
'after': 'sunset', 'after': 'sunset',
}, },
'action': { 'action': {

View File

@ -339,7 +339,7 @@ class TestAutomationTemplate(unittest.TestCase):
'event_type': 'test_event', 'event_type': 'test_event',
}, },
'condition': [{ 'condition': [{
'platform': 'template', 'condition': 'template',
'value_template': '{{ is_state("test.entity", "world") }}' 'value_template': '{{ is_state("test.entity", "world") }}'
}], }],
'action': { 'action': {

View File

@ -250,7 +250,7 @@ class TestAutomationTime(unittest.TestCase):
'event_type': 'test_event' 'event_type': 'test_event'
}, },
'condition': { 'condition': {
'platform': 'time', 'condition': 'time',
'before': '10:00', 'before': '10:00',
}, },
'action': { 'action': {
@ -285,7 +285,7 @@ class TestAutomationTime(unittest.TestCase):
'event_type': 'test_event' 'event_type': 'test_event'
}, },
'condition': { 'condition': {
'platform': 'time', 'condition': 'time',
'after': '10:00', 'after': '10:00',
}, },
'action': { 'action': {
@ -320,7 +320,7 @@ class TestAutomationTime(unittest.TestCase):
'event_type': 'test_event' 'event_type': 'test_event'
}, },
'condition': { 'condition': {
'platform': 'time', 'condition': 'time',
'weekday': 'mon', 'weekday': 'mon',
}, },
'action': { 'action': {
@ -356,7 +356,7 @@ class TestAutomationTime(unittest.TestCase):
'event_type': 'test_event' 'event_type': 'test_event'
}, },
'condition': { 'condition': {
'platform': 'time', 'condition': 'time',
'weekday': ['mon', 'tue'], 'weekday': ['mon', 'tue'],
}, },
'action': { 'action': {

View File

@ -197,7 +197,7 @@ class TestAutomationZone(unittest.TestCase):
'event_type': 'test_event' 'event_type': 'test_event'
}, },
'condition': { 'condition': {
'platform': 'zone', 'condition': 'zone',
'entity_id': 'test.entity', 'entity_id': 'test.entity',
'zone': 'zone.test', 'zone': 'zone.test',
}, },

View File

@ -119,7 +119,7 @@ class TestBinarySensorTemplate(unittest.TestCase):
vs.update_ha_state() vs.update_ha_state()
self.hass.block_till_done() self.hass.block_till_done()
with mock.patch.object(vs, 'update') as mock_update: with mock.patch.object(vs, 'async_update') as mock_update:
self.hass.bus.fire(EVENT_STATE_CHANGED) self.hass.bus.fire(EVENT_STATE_CHANGED)
self.hass.block_till_done() self.hass.block_till_done()
assert mock_update.call_count == 1 assert mock_update.call_count == 1

View File

@ -53,7 +53,12 @@ def test_async_update_support(event_loop):
assert len(sync_update) == 1 assert len(sync_update) == 1
assert len(async_update) == 0 assert len(async_update) == 0
ent.async_update = lambda: async_update.append(1) @asyncio.coroutine
def async_update_func():
"""Async update."""
async_update.append(1)
ent.async_update = async_update_func
event_loop.run_until_complete(test()) event_loop.run_until_complete(test())
@ -95,3 +100,19 @@ class TestHelpersEntity(object):
assert entity.generate_entity_id( assert entity.generate_entity_id(
fmt, 'overwrite hidden true', fmt, 'overwrite hidden true',
hass=self.hass) == 'test.overwrite_hidden_true_2' hass=self.hass) == 'test.overwrite_hidden_true_2'
def test_update_calls_async_update_if_available(self):
"""Test async update getting called."""
async_update = []
class AsyncEntity(entity.Entity):
hass = self.hass
entity_id = 'sensor.test'
@asyncio.coroutine
def async_update(self):
async_update.append([1])
ent = AsyncEntity()
ent.update()
assert len(async_update) == 1