From 4acb12168979b77785b0b7c9e8b19bf5f8acdcbf Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 12:22:19 -0700 Subject: [PATCH 01/11] Allow variables in service.call_from_config --- homeassistant/helpers/config_validation.py | 8 +-- homeassistant/helpers/service.py | 76 ++++++++++------------ tests/helpers/test_service.py | 22 +++++++ 3 files changed, 60 insertions(+), 46 deletions(-) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 923d071231a..51684e5f1cd 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -231,8 +231,8 @@ EVENT_SCHEMA = vol.Schema({ SERVICE_SCHEMA = vol.All(vol.Schema({ vol.Exclusive('service', 'service name'): service, - vol.Exclusive('service_template', 'service name'): string, - vol.Exclusive('data', 'service data'): dict, - vol.Exclusive('data_template', 'service data'): {match_all: template}, - 'entity_id': entity_ids, + vol.Exclusive('service_template', 'service name'): template, + vol.Optional('data'): dict, + vol.Optional('data_template'): {match_all: template}, + vol.Optional('entity_id'): entity_ids, }), has_at_least_one_key('service', 'service_template')) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 8f366352532..8b89d856c50 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -2,9 +2,13 @@ import functools import logging +import voluptuous as vol + from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.exceptions import TemplateError from homeassistant.helpers import template from homeassistant.loader import get_component +import homeassistant.helpers.config_validation as cv HASS = None @@ -28,47 +32,38 @@ def service(domain, service_name): return register_service_decorator -def call_from_config(hass, config, blocking=False): +def call_from_config(hass, config, blocking=False, variables=None): """Call a service based on a config hash.""" - validation_error = validate_service_call(config) - if validation_error: - _LOGGER.error(validation_error) - return - - domain_service = ( - config[CONF_SERVICE] - if CONF_SERVICE in config - else template.render(hass, config[CONF_SERVICE_TEMPLATE])) - try: - domain, service_name = domain_service.split('.', 1) - except ValueError: - _LOGGER.error('Invalid service specified: %s', domain_service) + config = cv.SERVICE_SCHEMA(config) + except vol.Invalid as ex: + _LOGGER.error("Invalid config for calling service: %s", ex) return - service_data = config.get(CONF_SERVICE_DATA) - - if service_data is None: - service_data = {} - elif isinstance(service_data, dict): - service_data = dict(service_data) + if CONF_SERVICE in config: + domain_service = config[CONF_SERVICE] else: - _LOGGER.error("%s should be a dictionary", CONF_SERVICE_DATA) - service_data = {} + try: + domain_service = template.render( + hass, config[CONF_SERVICE_TEMPLATE], variables) + domain_service = cv.service(domain_service) + except TemplateError as ex: + _LOGGER.error('Error rendering service name template: %s', ex) + return + except vol.Invalid as ex: + _LOGGER.error('Template rendered invalid service: %s', + domain_service) + return - service_data_template = config.get(CONF_SERVICE_DATA_TEMPLATE) - if service_data_template and isinstance(service_data_template, dict): - for key, value in service_data_template.items(): - service_data[key] = template.render(hass, value) - elif service_data_template: - _LOGGER.error("%s should be a dictionary", CONF_SERVICE_DATA) + domain, service_name = domain_service.split('.', 1) + service_data = dict(config.get(CONF_SERVICE_DATA, {})) - entity_id = config.get(CONF_SERVICE_ENTITY_ID) - if isinstance(entity_id, str): - service_data[ATTR_ENTITY_ID] = [ent.strip() for ent in - entity_id.split(",")] - elif entity_id is not None: - service_data[ATTR_ENTITY_ID] = entity_id + if CONF_SERVICE_DATA_TEMPLATE in config: + for key, value in config[CONF_SERVICE_DATA_TEMPLATE].items(): + service_data[key] = template.render(hass, value, variables) + + if CONF_SERVICE_ENTITY_ID in config: + service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] hass.services.call(domain, service_name, service_data, blocking) @@ -98,11 +93,8 @@ def validate_service_call(config): Helper method to validate that a configuration is a valid service call. Returns None if validation succeeds, else an error description """ - if not isinstance(config, dict): - return 'Invalid configuration {}'.format(config) - if CONF_SERVICE not in config and CONF_SERVICE_TEMPLATE not in config: - return 'Missing key {} or {}: {}'.format( - CONF_SERVICE, - CONF_SERVICE_TEMPLATE, - config) - return None + try: + cv.SERVICE_SCHEMA(config) + return None + except vol.Invalid as ex: + return str(ex) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index be224b51ff0..59ba1781ab2 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import patch +import homeassistant.components # noqa - to prevent circular import from homeassistant import core as ha, loader from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID from homeassistant.helpers import service @@ -53,6 +54,27 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual('goodbye', runs[0].data['hello']) + def test_passing_variables_to_templates(self): + config = { + 'service_template': '{{ var_service }}', + 'entity_id': 'hello.world', + 'data_template': { + 'hello': '{{ var_data }}', + }, + } + runs = [] + + decor = service.service('test_domain', 'test_service') + decor(lambda x, y: runs.append(y)) + + service.call_from_config(self.hass, config, variables={ + 'var_service': 'test_domain.test_service', + 'var_data': 'goodbye', + }) + self.hass.pool.block_till_done() + + self.assertEqual('goodbye', runs[0].data['hello']) + def test_split_entity_string(self): """Test splitting of entity string.""" service.call_from_config(self.hass, { From c4913a87e42df6057842dc2071149770d10ff370 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 12:27:23 -0700 Subject: [PATCH 02/11] Alexa: Expose intent variables to service calls --- homeassistant/components/alexa.py | 2 +- tests/components/test_alexa.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/alexa.py b/homeassistant/components/alexa.py index 806d6874a8d..bb9e1816a68 100644 --- a/homeassistant/components/alexa.py +++ b/homeassistant/components/alexa.py @@ -91,7 +91,7 @@ def _handle_alexa(handler, path_match, data): card['content']) if action is not None: - call_from_config(handler.server.hass, action, True) + call_from_config(handler.server.hass, action, True, response.variables) handler.write_json(response.as_dict()) diff --git a/tests/components/test_alexa.py b/tests/components/test_alexa.py index b004ab642ed..03fa5c2d33c 100644 --- a/tests/components/test_alexa.py +++ b/tests/components/test_alexa.py @@ -71,8 +71,8 @@ def setUpModule(): # pylint: disable=invalid-name }, 'action': { 'service': 'test.alexa', - 'data': { - 'hello': 1 + 'data_template': { + 'hello': '{{ ZodiacSign }}' }, 'entity_id': 'switch.test', } @@ -278,6 +278,12 @@ class TestAlexa(unittest.TestCase): 'timestamp': '2015-05-13T12:34:56Z', 'intent': { 'name': 'CallServiceIntent', + 'slots': { + 'ZodiacSign': { + 'name': 'ZodiacSign', + 'value': 'virgo', + } + } } } } @@ -289,7 +295,7 @@ class TestAlexa(unittest.TestCase): self.assertEqual('test', call.domain) self.assertEqual('alexa', call.service) self.assertEqual(['switch.test'], call.data.get('entity_id')) - self.assertEqual(1, call.data.get('hello')) + self.assertEqual('virgo', call.data.get('hello')) def test_session_ended_request(self): """Test the request for ending the session.""" From 4e568f8b99e45d1642887d4770708ea28a33aa18 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 13:59:42 -0700 Subject: [PATCH 03/11] Automation: Add trigger context and expose to action --- .../components/automation/__init__.py | 20 ++++----- homeassistant/components/automation/event.py | 7 ++- homeassistant/components/automation/mqtt.py | 9 +++- .../components/automation/numeric_state.py | 33 +++++++++++--- homeassistant/components/automation/state.py | 43 ++++++++++++------- homeassistant/components/automation/sun.py | 16 +++++-- .../components/automation/template.py | 23 +++++++--- homeassistant/components/automation/time.py | 9 +++- homeassistant/components/automation/zone.py | 23 +++++++--- homeassistant/helpers/event.py | 7 ++- tests/components/automation/test_init.py | 7 ++- tests/components/automation/test_mqtt.py | 10 ++++- .../automation/test_numeric_state.py | 17 +++++++- tests/components/automation/test_state.py | 14 +++++- tests/components/automation/test_sun.py | 6 +++ tests/components/automation/test_template.py | 13 +++++- tests/components/automation/test_time.py | 7 ++- tests/components/automation/test_zone.py | 10 +++++ tests/helpers/test_event.py | 24 ++++++++++- tests/helpers/test_service.py | 3 +- 20 files changed, 232 insertions(+), 69 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 7e13ae3ed75..8cbaf35a5c4 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -122,12 +122,11 @@ def _setup_automation(hass, config_block, name, config): def _get_action(hass, config, name): """Return an action based on a configuration.""" - def action(): + def action(variables=None): """Action to be executed.""" _LOGGER.info('Executing %s', name) logbook.log_entry(hass, name, 'has been triggered', DOMAIN) - - call_from_config(hass, config) + call_from_config(hass, config, variables=variables) return action @@ -159,24 +158,21 @@ def _process_if(hass, config, p_config, action): checks.append(check) if cond_type == CONDITION_TYPE_AND: - def if_action(): + def if_action(variables=None): """AND all conditions.""" - if all(check() for check in checks): - action() + if all(check(variables) for check in checks): + action(variables) else: - def if_action(): + def if_action(variables=None): """OR all conditions.""" - if any(check() for check in checks): - action() + if any(check(variables) for check in checks): + action(variables) return if_action def _process_trigger(hass, config, trigger_configs, name, action): """Setup the triggers.""" - if isinstance(trigger_configs, dict): - trigger_configs = [trigger_configs] - for conf in trigger_configs: platform = _resolve_platform(METHOD_TRIGGER, hass, config, conf.get(CONF_PLATFORM)) diff --git a/homeassistant/components/automation/event.py b/homeassistant/components/automation/event.py index 80dd6c29f6b..46b5b4ef10d 100644 --- a/homeassistant/components/automation/event.py +++ b/homeassistant/components/automation/event.py @@ -26,7 +26,12 @@ def trigger(hass, config, action): """Listen for events and calls the action when data matches.""" if not event_data or all(val == event.data.get(key) for key, val in event_data.items()): - action() + action({ + 'trigger': { + 'platform': 'event', + 'event': event, + }, + }) hass.bus.listen(event_type, handle_event) return True diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index db0c1be7c2a..e4a6b221e04 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -30,7 +30,14 @@ def trigger(hass, config, action): def mqtt_automation_listener(msg_topic, msg_payload, qos): """Listen for MQTT messages.""" if payload is None or payload == msg_payload: - action() + action({ + 'trigger': { + 'platform': 'mqtt', + 'topic': msg_topic, + 'payload': msg_payload, + 'qos': qos, + } + }) mqtt.subscribe(hass, topic, mqtt_automation_listener) diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 74f5b3ba805..6ed2add0b25 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -18,12 +18,15 @@ CONF_ABOVE = "above" _LOGGER = logging.getLogger(__name__) -def _renderer(hass, value_template, state): +def _renderer(hass, value_template, state, variables=None): """Render the state value.""" if value_template is None: return state.state - return template.render(hass, value_template, {'state': state}) + variables = dict(variables or {}) + variables['state'] = state + + return template.render(hass, value_template, variables) def trigger(hass, config, action): @@ -50,9 +53,27 @@ def trigger(hass, config, action): def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" # Fire action if we go from outside range into range - if _in_range(above, below, renderer(to_s)) and \ - (from_s is None or not _in_range(above, below, renderer(from_s))): - action() + if to_s is None: + return + + variables = { + 'trigger': { + 'platform': 'numeric_state', + 'entity_id': entity_id, + 'below': below, + 'above': above, + } + } + to_s_value = renderer(to_s, variables) + from_s_value = None if from_s is None else renderer(from_s, variables) + if _in_range(above, below, to_s_value) and \ + (from_s is None or not _in_range(above, below, from_s_value)): + variables['trigger']['from_state'] = from_s + variables['trigger']['from_value'] = from_s_value + variables['trigger']['to_state'] = to_s + variables['trigger']['to_value'] = to_s_value + + action(variables) track_state_change( hass, entity_id, state_automation_listener) @@ -80,7 +101,7 @@ def if_action(hass, config): renderer = partial(_renderer, hass, value_template) - def if_numeric_state(): + def if_numeric_state(variables): """Test numeric state condition.""" state = hass.states.get(entity_id) return state is not None and _in_range(above, below, renderer(state)) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 742e6195949..802debbe63e 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -73,29 +73,42 @@ def trigger(hass, config, action): def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" + def call_action(): + """Call action with right context.""" + action({ + 'trigger': { + 'platform': 'state', + 'entity_id': entity, + 'from_state': from_s, + 'to_state': to_s, + 'for': time_delta, + } + }) + + if time_delta is None: + call_action() + return + def state_for_listener(now): """Fire on state changes after a delay and calls action.""" hass.bus.remove_listener( - EVENT_STATE_CHANGED, for_state_listener) - action() + EVENT_STATE_CHANGED, attached_state_for_cancel_listener) + call_action() def state_for_cancel_listener(entity, inner_from_s, inner_to_s): """Fire on changes and cancel for listener if changed.""" if inner_to_s == to_s: return - hass.bus.remove_listener(EVENT_TIME_CHANGED, for_time_listener) - hass.bus.remove_listener( - EVENT_STATE_CHANGED, for_state_listener) + hass.bus.remove_listener(EVENT_TIME_CHANGED, + attached_state_for_listener) + hass.bus.remove_listener(EVENT_STATE_CHANGED, + attached_state_for_cancel_listener) - if time_delta is not None: - target_tm = dt_util.utcnow() + time_delta - for_time_listener = track_point_in_time( - hass, state_for_listener, target_tm) - for_state_listener = track_state_change( - hass, entity_id, state_for_cancel_listener, - MATCH_ALL, MATCH_ALL) - else: - action() + attached_state_for_listener = track_point_in_time( + hass, state_for_listener, dt_util.utcnow() + time_delta) + + attached_state_for_cancel_listener = track_state_change( + hass, entity_id, state_for_cancel_listener) track_state_change( hass, entity_id, state_automation_listener, from_state, to_state) @@ -109,7 +122,7 @@ def if_action(hass, config): state = config.get(CONF_STATE) time_delta = get_time_config(config) - def if_state(): + def if_state(variables): """Test if condition.""" is_state = hass.states.is_state(entity_id, state) return (time_delta is None and is_state or diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index 2a564a3b588..c9db88a83c2 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -55,11 +55,21 @@ def trigger(hass, config, action): event = config.get(CONF_EVENT) offset = config.get(CONF_OFFSET) + def call_action(): + """Call action with right context.""" + action({ + 'trigger': { + 'platform': 'sun', + 'event': event, + 'offset': offset, + }, + }) + # Do something to call action if event == EVENT_SUNRISE: - track_sunrise(hass, action, offset) + track_sunrise(hass, call_action, offset) else: - track_sunset(hass, action, offset) + track_sunset(hass, call_action, offset) return True @@ -97,7 +107,7 @@ def if_action(hass, config): """Return time after sunset.""" return sun.next_setting(hass) + after_offset - def time_if(): + def time_if(variables): """Validate time based if-condition.""" now = dt_util.now() before = before_func() diff --git a/homeassistant/components/automation/template.py b/homeassistant/components/automation/template.py index 02e8f30d209..66c20518c7e 100644 --- a/homeassistant/components/automation/template.py +++ b/homeassistant/components/automation/template.py @@ -9,9 +9,10 @@ import logging import voluptuous as vol from homeassistant.const import ( - CONF_VALUE_TEMPLATE, EVENT_STATE_CHANGED, CONF_PLATFORM) + CONF_VALUE_TEMPLATE, CONF_PLATFORM, MATCH_ALL) from homeassistant.exceptions import TemplateError from homeassistant.helpers import template +from homeassistant.helpers.event import track_state_change import homeassistant.helpers.config_validation as cv @@ -30,7 +31,7 @@ def trigger(hass, config, action): # Local variable to keep track of if the action has already been triggered already_triggered = False - def event_listener(event): + def state_changed_listener(entity_id, from_s, to_s): """Listen for state changes and calls action.""" nonlocal already_triggered template_result = _check_template(hass, value_template) @@ -38,11 +39,18 @@ def trigger(hass, config, action): # Check to see if template returns true if template_result and not already_triggered: already_triggered = True - action() + action({ + 'trigger': { + 'platform': 'template', + 'entity_id': entity_id, + 'from_state': from_s, + 'to_state': to_s, + }, + }) elif not template_result: already_triggered = False - hass.bus.listen(EVENT_STATE_CHANGED, event_listener) + track_state_change(hass, MATCH_ALL, state_changed_listener) return True @@ -50,13 +58,14 @@ def if_action(hass, config): """Wrap action method with state based condition.""" value_template = config.get(CONF_VALUE_TEMPLATE) - return lambda: _check_template(hass, value_template) + return lambda variables: _check_template(hass, value_template, + variables=variables) -def _check_template(hass, value_template): +def _check_template(hass, value_template, variables=None): """Check if result of template is true.""" try: - value = template.render(hass, value_template, {}) + value = template.render(hass, value_template, variables) except TemplateError as ex: if ex.args and ex.args[0].startswith( "UndefinedError: 'None' has no attribute"): diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index 761ad73b826..879b0e113d9 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -41,7 +41,12 @@ def trigger(hass, config, action): def time_automation_listener(now): """Listen for time changes and calls action.""" - action() + action({ + 'trigger': { + 'platform': 'time', + 'now': now, + }, + }) track_time_change(hass, time_automation_listener, hour=hours, minute=minutes, second=seconds) @@ -73,7 +78,7 @@ def if_action(hass, config): _error_time(after, CONF_AFTER) return None - def time_if(): + def time_if(variables): """Validate time based if-condition.""" now = dt_util.now() if before is not None and now > now.replace(hour=before.hour, diff --git a/homeassistant/components/automation/zone.py b/homeassistant/components/automation/zone.py index 66ea3c2d7c7..fd798f45549 100644 --- a/homeassistant/components/automation/zone.py +++ b/homeassistant/components/automation/zone.py @@ -48,13 +48,22 @@ def trigger(hass, config, action): to_s.attributes.get(ATTR_LONGITUDE)): return - from_match = _in_zone(hass, zone_entity_id, from_s) if from_s else None - to_match = _in_zone(hass, zone_entity_id, to_s) + zone_state = hass.states.get(zone_entity_id) + from_match = _in_zone(hass, zone_state, from_s) if from_s else None + to_match = _in_zone(hass, zone_state, to_s) # pylint: disable=too-many-boolean-expressions if event == EVENT_ENTER and not from_match and to_match or \ event == EVENT_LEAVE and from_match and not to_match: - action() + action({ + 'trigger': { + 'platform': 'zone', + 'entity_id': entity, + 'from_state': from_s, + 'to_state': to_s, + 'zone': zone_state, + }, + }) track_state_change( hass, entity_id, zone_automation_listener, MATCH_ALL, MATCH_ALL) @@ -67,20 +76,20 @@ def if_action(hass, config): entity_id = config.get(CONF_ENTITY_ID) zone_entity_id = config.get(CONF_ZONE) - def if_in_zone(): + def if_in_zone(variables): """Test if condition.""" - return _in_zone(hass, zone_entity_id, hass.states.get(entity_id)) + zone_state = hass.states.get(zone_entity_id) + return _in_zone(hass, zone_state, hass.states.get(entity_id)) return if_in_zone -def _in_zone(hass, zone_entity_id, state): +def _in_zone(hass, zone_state, state): """Check if state is in zone.""" if not state or None in (state.attributes.get(ATTR_LATITUDE), state.attributes.get(ATTR_LONGITUDE)): return False - zone_state = hass.states.get(zone_entity_id) return zone_state and zone.in_zone( zone_state, state.attributes.get(ATTR_LATITUDE), state.attributes.get(ATTR_LONGITUDE), diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index b6a08cc59d0..50a7b290cc8 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -21,7 +21,9 @@ def track_state_change(hass, entity_ids, action, from_state=None, to_state = _process_match_param(to_state) # Ensure it is a lowercase list with entity ids we want to match on - if isinstance(entity_ids, str): + if entity_ids == MATCH_ALL: + pass + elif isinstance(entity_ids, str): entity_ids = (entity_ids.lower(),) else: entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) @@ -29,7 +31,8 @@ def track_state_change(hass, entity_ids, action, from_state=None, @ft.wraps(action) def state_change_listener(event): """The listener that listens for specific state changes.""" - if event.data['entity_id'] not in entity_ids: + if entity_ids != MATCH_ALL and \ + event.data['entity_id'] not in entity_ids: return if event.data['old_state'] is None: diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 355865e9e9c..f6d33c18071 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -51,7 +51,10 @@ class TestAutomation(unittest.TestCase): }, 'action': { 'service': 'test.automation', - 'data': {'some': 'data'} + 'data_template': { + 'some': '{{ trigger.platform }} - ' + '{{ trigger.event.event_type }}' + }, } } }) @@ -59,7 +62,7 @@ class TestAutomation(unittest.TestCase): self.hass.bus.fire('test_event') self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) - self.assertEqual('data', self.calls[0].data['some']) + self.assertEqual('event - test_event', self.calls[0].data['some']) def test_service_specify_entity_id(self): """Test service data.""" diff --git a/tests/components/automation/test_mqtt.py b/tests/components/automation/test_mqtt.py index 0fd2a9aef06..29d55b424f2 100644 --- a/tests/components/automation/test_mqtt.py +++ b/tests/components/automation/test_mqtt.py @@ -35,14 +35,20 @@ class TestAutomationMQTT(unittest.TestCase): 'topic': 'test-topic' }, 'action': { - 'service': 'test.automation' + 'service': 'test.automation', + 'data_template': { + 'some': '{{ trigger.platform }} - {{ trigger.topic }}' + ' - {{ trigger.payload }}' + }, } } }) - fire_mqtt_message(self.hass, 'test-topic', '') + fire_mqtt_message(self.hass, 'test-topic', 'test_payload') self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual('mqtt - test-topic - test_payload', + self.calls[0].data['some']) def test_if_fires_on_topic_and_payload_match(self): """Test if message is fired on topic and payload match.""" diff --git a/tests/components/automation/test_numeric_state.py b/tests/components/automation/test_numeric_state.py index ee29c0fb56f..37df19e38ed 100644 --- a/tests/components/automation/test_numeric_state.py +++ b/tests/components/automation/test_numeric_state.py @@ -437,15 +437,28 @@ class TestAutomationNumericState(unittest.TestCase): 'below': 10, }, 'action': { - 'service': 'test.automation' + 'service': 'test.automation', + 'data_template': { + 'some': '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', 'below', 'above', + 'from_state.state', 'from_value', + 'to_state.state', 'to_value')) + }, } } }) # 9 is below 10 - self.hass.states.set('test.entity', 'entity', + self.hass.states.set('test.entity', 'test state 1', + {'test_attribute': '1.2'}) + self.hass.pool.block_till_done() + self.hass.states.set('test.entity', 'test state 2', {'test_attribute': '0.9'}) self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'numeric_state - test.entity - 10 - None - test state 1 - 12.0 - ' + 'test state 2 - 9.0', + self.calls[0].data['some']) def test_not_fires_on_attr_change_with_attr_not_below_multiple_attr(self): """"Test if not fired changed attributes.""" diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index 2f688249834..4a6971124b6 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -31,6 +31,9 @@ class TestAutomationState(unittest.TestCase): def test_if_fires_on_entity_change(self): """Test for firing on entity change.""" + self.hass.states.set('test.entity', 'hello') + self.hass.pool.block_till_done() + assert _setup_component(self.hass, automation.DOMAIN, { automation.DOMAIN: { 'trigger': { @@ -38,7 +41,13 @@ class TestAutomationState(unittest.TestCase): 'entity_id': 'test.entity', }, 'action': { - 'service': 'test.automation' + 'service': 'test.automation', + 'data_template': { + 'some': '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', + 'from_state.state', 'to_state.state', + 'for')) + }, } } }) @@ -46,6 +55,9 @@ class TestAutomationState(unittest.TestCase): self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'state - test.entity - hello - world - None', + self.calls[0].data['some']) def test_if_fires_on_entity_change_with_from_filter(self): """Test for firing on entity change with filter.""" diff --git a/tests/components/automation/test_sun.py b/tests/components/automation/test_sun.py index 738c171ce6c..1975dc8da44 100644 --- a/tests/components/automation/test_sun.py +++ b/tests/components/automation/test_sun.py @@ -105,6 +105,11 @@ class TestAutomationSun(unittest.TestCase): }, 'action': { 'service': 'test.automation', + 'data_template': { + 'some': + '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'event', 'offset')) + }, } } }) @@ -112,6 +117,7 @@ class TestAutomationSun(unittest.TestCase): fire_time_changed(self.hass, trigger_time) self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual('sun - sunset - 0:30:00', self.calls[0].data['some']) def test_sunrise_trigger_with_offset(self): """Test the runrise trigger with offset.""" diff --git a/tests/components/automation/test_template.py b/tests/components/automation/test_template.py index bb46c7a262a..a643b731492 100644 --- a/tests/components/automation/test_template.py +++ b/tests/components/automation/test_template.py @@ -1,4 +1,4 @@ -"""The tests fr the Template automation.""" +"""The tests for the Template automation.""" import unittest from homeassistant.bootstrap import _setup_component @@ -226,7 +226,13 @@ class TestAutomationTemplate(unittest.TestCase): {%- endif -%}''', }, 'action': { - 'service': 'test.automation' + 'service': 'test.automation', + 'data_template': { + 'some': + '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', 'from_state.state', + 'to_state.state')) + }, } } }) @@ -234,6 +240,9 @@ class TestAutomationTemplate(unittest.TestCase): self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'template - test.entity - hello - world', + self.calls[0].data['some']) def test_if_fires_on_no_change_with_template_advanced(self): """Test for firing on no change with template advanced.""" diff --git a/tests/components/automation/test_time.py b/tests/components/automation/test_time.py index 36f22a00148..0b19e9389e2 100644 --- a/tests/components/automation/test_time.py +++ b/tests/components/automation/test_time.py @@ -176,7 +176,11 @@ class TestAutomationTime(unittest.TestCase): 'after': '5:00:00', }, 'action': { - 'service': 'test.automation' + 'service': 'test.automation', + 'data_template': { + 'some': '{{ trigger.platform }} - ' + '{{ trigger.now.hour }}' + }, } } }) @@ -186,6 +190,7 @@ class TestAutomationTime(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual('time - 5', self.calls[0].data['some']) def test_if_not_working_if_no_values_in_conf_provided(self): """Test for failure if no configuration.""" diff --git a/tests/components/automation/test_zone.py b/tests/components/automation/test_zone.py index 87a22243760..24980b466bf 100644 --- a/tests/components/automation/test_zone.py +++ b/tests/components/automation/test_zone.py @@ -52,6 +52,13 @@ class TestAutomationZone(unittest.TestCase): }, 'action': { 'service': 'test.automation', + 'data_template': { + 'some': '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', + 'from_state.state', 'to_state.state', + 'zone.name')) + }, + } } }) @@ -63,6 +70,9 @@ class TestAutomationZone(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'zone - test.entity - hello - hello - test', + self.calls[0].data['some']) def test_if_not_fires_for_enter_on_zone_leave(self): """Test for not firing on zone leave.""" diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 6d3a9cbb6a0..5d9f8d28e20 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -7,6 +7,7 @@ from datetime import datetime, timedelta from astral import Astral import homeassistant.core as ha +from homeassistant.const import MATCH_ALL from homeassistant.helpers.event import ( track_point_in_utc_time, track_point_in_time, @@ -93,6 +94,7 @@ class TestEventHelpers(unittest.TestCase): # 2 lists to track how often our callbacks get called specific_runs = [] wildcard_runs = [] + wildercard_runs = [] track_state_change( self.hass, 'light.Bowl', lambda a, b, c: specific_runs.append(1), @@ -100,14 +102,18 @@ class TestEventHelpers(unittest.TestCase): track_state_change( self.hass, 'light.Bowl', - lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s)), - ha.MATCH_ALL, ha.MATCH_ALL) + lambda _, old_s, new_s: wildcard_runs.append((old_s, new_s))) + + track_state_change( + self.hass, MATCH_ALL, + lambda _, old_s, new_s: wildercard_runs.append((old_s, new_s))) # Adding state to state machine self.hass.states.set("light.Bowl", "on") self.hass.pool.block_till_done() self.assertEqual(0, len(specific_runs)) self.assertEqual(1, len(wildcard_runs)) + self.assertEqual(1, len(wildercard_runs)) self.assertIsNone(wildcard_runs[-1][0]) self.assertIsNotNone(wildcard_runs[-1][1]) @@ -116,31 +122,45 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(0, len(specific_runs)) self.assertEqual(1, len(wildcard_runs)) + self.assertEqual(1, len(wildercard_runs)) # State change off -> on self.hass.states.set('light.Bowl', 'off') self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(2, len(wildcard_runs)) + self.assertEqual(2, len(wildercard_runs)) # State change off -> off self.hass.states.set('light.Bowl', 'off', {"some_attr": 1}) self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(3, len(wildcard_runs)) + self.assertEqual(3, len(wildercard_runs)) # State change off -> on self.hass.states.set('light.Bowl', 'on') self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(4, len(wildcard_runs)) + self.assertEqual(4, len(wildercard_runs)) self.hass.states.remove('light.bowl') self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(5, len(wildcard_runs)) + self.assertEqual(5, len(wildercard_runs)) self.assertIsNotNone(wildcard_runs[-1][0]) self.assertIsNone(wildcard_runs[-1][1]) + self.assertIsNotNone(wildercard_runs[-1][0]) + self.assertIsNone(wildercard_runs[-1][1]) + + # Set state for different entity id + self.hass.states.set('switch.kitchen', 'on') + self.hass.pool.block_till_done() + self.assertEqual(1, len(specific_runs)) + self.assertEqual(5, len(wildcard_runs)) + self.assertEqual(6, len(wildercard_runs)) def test_track_sunrise(self): """Test track the sunrise.""" diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 59ba1781ab2..c863a46ad3b 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -2,7 +2,8 @@ import unittest from unittest.mock import patch -import homeassistant.components # noqa - to prevent circular import +# To prevent circular import when running just this file +import homeassistant.components # noqa from homeassistant import core as ha, loader from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID from homeassistant.helpers import service From f76d545a084e48f0cdaf59cd5937f9ae2eee691d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 15:52:20 -0700 Subject: [PATCH 04/11] Add script logic into helper. --- homeassistant/components/automation/state.py | 6 +- homeassistant/components/automation/sun.py | 6 +- homeassistant/components/script.py | 175 ++----------------- homeassistant/const.py | 1 + homeassistant/helpers/config_validation.py | 81 ++++++--- homeassistant/helpers/script.py | 125 +++++++++++++ tests/components/test_script.py | 47 ----- tests/helpers/test_config_validation.py | 11 +- tests/helpers/test_service.py | 5 +- 9 files changed, 219 insertions(+), 238 deletions(-) create mode 100644 homeassistant/helpers/script.py diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 802debbe63e..3183dab0803 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -92,7 +92,7 @@ def trigger(hass, config, action): def state_for_listener(now): """Fire on state changes after a delay and calls action.""" hass.bus.remove_listener( - EVENT_STATE_CHANGED, attached_state_for_cancel_listener) + EVENT_STATE_CHANGED, attached_state_for_cancel) call_action() def state_for_cancel_listener(entity, inner_from_s, inner_to_s): @@ -102,12 +102,12 @@ def trigger(hass, config, action): hass.bus.remove_listener(EVENT_TIME_CHANGED, attached_state_for_listener) hass.bus.remove_listener(EVENT_STATE_CHANGED, - attached_state_for_cancel_listener) + attached_state_for_cancel) attached_state_for_listener = track_point_in_time( hass, state_for_listener, dt_util.utcnow() + time_delta) - attached_state_for_cancel_listener = track_state_change( + attached_state_for_cancel = track_state_change( hass, entity_id, state_for_cancel_listener) track_state_change( diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index c9db88a83c2..7de43d7f5e3 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -35,7 +35,7 @@ _SUN_EVENT = vol.All(vol.Lower, vol.Any(EVENT_SUNRISE, EVENT_SUNSET)) TRIGGER_SCHEMA = vol.Schema({ vol.Required(CONF_PLATFORM): 'sun', vol.Required(CONF_EVENT): _SUN_EVENT, - vol.Required(CONF_OFFSET, default=timedelta(0)): cv.time_offset, + vol.Required(CONF_OFFSET, default=timedelta(0)): cv.time_period, }) IF_ACTION_SCHEMA = vol.All( @@ -43,8 +43,8 @@ IF_ACTION_SCHEMA = vol.All( vol.Required(CONF_PLATFORM): 'sun', CONF_BEFORE: _SUN_EVENT, CONF_AFTER: _SUN_EVENT, - vol.Required(CONF_BEFORE_OFFSET, default=timedelta(0)): cv.time_offset, - vol.Required(CONF_AFTER_OFFSET, default=timedelta(0)): cv.time_offset, + vol.Required(CONF_BEFORE_OFFSET, default=timedelta(0)): cv.time_period, + vol.Required(CONF_AFTER_OFFSET, default=timedelta(0)): cv.time_period, }), cv.has_at_least_one_key(CONF_BEFORE, CONF_AFTER), ) diff --git a/homeassistant/components/script.py b/homeassistant/components/script.py index c19e614f19d..3557179c6eb 100644 --- a/homeassistant/components/script.py +++ b/homeassistant/components/script.py @@ -8,101 +8,33 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/script/ """ import logging -import threading -from datetime import timedelta -from itertools import islice import voluptuous as vol -import homeassistant.util.dt as date_util from homeassistant.const import ( - ATTR_ENTITY_ID, EVENT_TIME_CHANGED, SERVICE_TURN_OFF, SERVICE_TURN_ON, - SERVICE_TOGGLE, STATE_ON) + ATTR_ENTITY_ID, SERVICE_TURN_OFF, SERVICE_TURN_ON, + SERVICE_TOGGLE, STATE_ON, CONF_ALIAS) from homeassistant.helpers.entity import ToggleEntity, split_entity_id from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.event import track_point_in_utc_time -from homeassistant.helpers.service import (call_from_config, - validate_service_call) import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.script import Script + DOMAIN = "script" ENTITY_ID_FORMAT = DOMAIN + '.{}' DEPENDENCIES = ["group"] -STATE_NOT_RUNNING = 'Not Running' - -CONF_ALIAS = "alias" -CONF_SERVICE = "service" -CONF_SERVICE_DATA = "data" CONF_SEQUENCE = "sequence" -CONF_EVENT = "event" -CONF_EVENT_DATA = "event_data" -CONF_DELAY = "delay" ATTR_LAST_ACTION = 'last_action' ATTR_CAN_CANCEL = 'can_cancel' _LOGGER = logging.getLogger(__name__) -_ALIAS_VALIDATOR = vol.Schema(cv.string) - - -def _alias_stripper(validator): - """Strip alias from object for validation.""" - def validate(value): - """Validate without alias value.""" - value = value.copy() - alias = value.pop(CONF_ALIAS, None) - - if alias is not None: - alias = _ALIAS_VALIDATOR(alias) - - value = validator(value) - - if alias is not None: - value[CONF_ALIAS] = alias - - return value - - return validate - - -_TIMESPEC = vol.Schema({ - 'days': cv.positive_int, - 'hours': cv.positive_int, - 'minutes': cv.positive_int, - 'seconds': cv.positive_int, - 'milliseconds': cv.positive_int, -}) -_TIMESPEC_REQ = cv.has_at_least_one_key( - 'days', 'hours', 'minutes', 'seconds', 'milliseconds', -) - -_DELAY_SCHEMA = vol.Any( - vol.Schema({ - vol.Required(CONF_DELAY): vol.All(_TIMESPEC.extend({ - vol.Optional(CONF_ALIAS): cv.string - }), _TIMESPEC_REQ) - }), - # Alternative format in case people forgot to indent after 'delay:' - vol.All(_TIMESPEC.extend({ - vol.Required(CONF_DELAY): None, - vol.Optional(CONF_ALIAS): cv.string, - }), _TIMESPEC_REQ) -) - -_EVENT_SCHEMA = cv.EVENT_SCHEMA.extend({ - CONF_ALIAS: cv.string, -}) _SCRIPT_ENTRY_SCHEMA = vol.Schema({ CONF_ALIAS: cv.string, - vol.Required(CONF_SEQUENCE): vol.All(vol.Length(min=1), [vol.Any( - _EVENT_SCHEMA, - _DELAY_SCHEMA, - # Can't extend SERVICE_SCHEMA because it is an vol.All - _alias_stripper(cv.SERVICE_SCHEMA), - )]), + vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, }) CONFIG_SCHEMA = vol.Schema({ @@ -152,7 +84,7 @@ def setup(hass, config): for object_id, cfg in config[DOMAIN].items(): alias = cfg.get(CONF_ALIAS, object_id) - script = Script(object_id, alias, cfg[CONF_SEQUENCE]) + script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE]) component.add_entities((script,)) hass.services.register(DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA) @@ -183,21 +115,14 @@ def setup(hass, config): return True -class Script(ToggleEntity): - """Representation of a script.""" +class ScriptEntity(ToggleEntity): + """Representation of a script entity.""" # pylint: disable=too-many-instance-attributes - def __init__(self, object_id, name, sequence): + def __init__(self, hass, object_id, name, sequence): """Initialize the script.""" self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self._name = name - self.sequence = sequence - self._lock = threading.Lock() - self._cur = -1 - self._last_action = None - self._listener = None - self._can_cancel = any(CONF_DELAY in action for action - in self.sequence) + self.script = Script(hass, sequence, name, self.update_ha_state) @property def should_poll(self): @@ -207,91 +132,27 @@ class Script(ToggleEntity): @property def name(self): """Return the name of the entity.""" - return self._name + return self.script.name @property def state_attributes(self): """Return the state attributes.""" attrs = {} - if self._can_cancel: - attrs[ATTR_CAN_CANCEL] = self._can_cancel - if self._last_action: - attrs[ATTR_LAST_ACTION] = self._last_action + if self.script.can_cancel: + attrs[ATTR_CAN_CANCEL] = self.script.can_cancel + if self.script.last_action: + attrs[ATTR_LAST_ACTION] = self.script.last_action return attrs @property def is_on(self): """Return true if script is on.""" - return self._cur != -1 + return self.script.is_running def turn_on(self, **kwargs): """Turn the entity on.""" - _LOGGER.info("Executing script %s", self._name) - with self._lock: - if self._cur == -1: - self._cur = 0 - - # Unregister callback if we were in a delay but turn on is called - # again. In that case we just continue execution. - self._remove_listener() - - for cur, action in islice(enumerate(self.sequence), self._cur, - None): - - if validate_service_call(action) is None: - self._call_service(action) - - elif CONF_EVENT in action: - self._fire_event(action) - - elif CONF_DELAY in action: - # Call ourselves in the future to continue work - def script_delay(now): - """Called after delay is done.""" - self._listener = None - self.turn_on() - - timespec = action[CONF_DELAY] or action.copy() - timespec.pop(CONF_DELAY, None) - delay = timedelta(**timespec) - self._listener = track_point_in_utc_time( - self.hass, script_delay, date_util.utcnow() + delay) - self._cur = cur + 1 - self.update_ha_state() - return - - self._cur = -1 - self._last_action = None - self.update_ha_state() + self.script.run() def turn_off(self, **kwargs): """Turn script off.""" - _LOGGER.info("Cancelled script %s", self._name) - with self._lock: - if self._cur == -1: - return - - self._cur = -1 - self.update_ha_state() - self._remove_listener() - - def _call_service(self, action): - """Call the service specified in the action.""" - self._last_action = action.get(CONF_ALIAS, 'call service') - _LOGGER.info("Executing script %s step %s", self._name, - self._last_action) - call_from_config(self.hass, action, True) - - def _fire_event(self, action): - """Fire an event.""" - self._last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) - _LOGGER.info("Executing script %s step %s", self._name, - self._last_action) - self.hass.bus.fire(action[CONF_EVENT], action.get(CONF_EVENT_DATA)) - - def _remove_listener(self): - """Remove point in time listener, if any.""" - if self._listener: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, - self._listener) - self._listener = None + self.script.stop() diff --git a/homeassistant/const.py b/homeassistant/const.py index 77e540cd76f..b2971ab59f6 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -13,6 +13,7 @@ MATCH_ALL = '*' DEVICE_DEFAULT_NAME = "Unnamed Device" # #### CONFIG #### +CONF_ALIAS = "alias" CONF_ICON = "icon" CONF_LATITUDE = "latitude" CONF_LONGITUDE = "longitude" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 51684e5f1cd..71e103f7dd3 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -6,7 +6,8 @@ import voluptuous as vol from homeassistant.loader import get_platform from homeassistant.const import ( - CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELSIUS, TEMP_FAHRENHEIT) + CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELSIUS, TEMP_FAHRENHEIT, + CONF_ALIAS) from homeassistant.helpers.entity import valid_entity_id import homeassistant.util.dt as dt_util from homeassistant.util import slugify @@ -23,6 +24,23 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180), msg='invalid longitude') +# Adapted from: +# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 +def has_at_least_one_key(*keys): + """Validator that at least one key exists.""" + def validate(obj): + """Test keys exist in dict.""" + if not isinstance(obj, dict): + raise vol.Invalid('expected dictionary') + + for k in obj.keys(): + if k in keys: + return obj + raise vol.Invalid('must contain one of {}.'.format(', '.join(keys))) + + return validate + + def boolean(value): """Validate and coerce a boolean value.""" if isinstance(value, str): @@ -72,10 +90,24 @@ def icon(value): raise vol.Invalid('Icons should start with prefix "mdi:"') -def time_offset(value): +time_period_dict = vol.All( + dict, vol.Schema({ + 'days': vol.Coerce(int), + 'hours': vol.Coerce(int), + 'minutes': vol.Coerce(int), + 'seconds': vol.Coerce(int), + 'milliseconds': vol.Coerce(int), + }), + has_at_least_one_key('days', 'hours', 'minutes', + 'seconds', 'milliseconds'), + lambda value: timedelta(**value)) + + +def time_period_str(value): """Validate and transform time offset.""" if not isinstance(value, str): - raise vol.Invalid('offset should be a string') + raise vol.Invalid( + 'offset {} should be format HH:MM or HH:MM:SS'.format(value)) negative_offset = False if value.startswith('-'): @@ -107,6 +139,9 @@ def time_offset(value): return offset +time_period = vol.Any(time_period_str, timedelta, time_period_dict) + + def match_all(value): """Validator that matches all values.""" return value @@ -125,6 +160,13 @@ def platform_validator(domain): return validator +def positive_timedelta(value): + """Validate timedelta is positive.""" + if value < timedelta(0): + raise vol.Invalid('Time period should be positive') + return value + + def service(value): """Validate service.""" # Services use same format as entities so we can use same helper. @@ -200,23 +242,6 @@ def key_dependency(key, dependency): return validator -# Adapted from: -# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 -def has_at_least_one_key(*keys): - """Validator that at least one key exists.""" - def validate(obj): - """Test keys exist in dict.""" - if not isinstance(obj, dict): - raise vol.Invalid('expected dictionary') - - for k in obj.keys(): - if k in keys: - return obj - raise vol.Invalid('must contain one of {}.'.format(', '.join(keys))) - - return validate - - # Schemas PLATFORM_SCHEMA = vol.Schema({ @@ -225,14 +250,28 @@ PLATFORM_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) EVENT_SCHEMA = vol.Schema({ + vol.Optional(CONF_ALIAS): string, vol.Required('event'): string, - 'event_data': dict + vol.Optional('event_data'): dict, }) SERVICE_SCHEMA = vol.All(vol.Schema({ + vol.Optional(CONF_ALIAS): string, vol.Exclusive('service', 'service name'): service, vol.Exclusive('service_template', 'service name'): template, vol.Optional('data'): dict, vol.Optional('data_template'): {match_all: template}, vol.Optional('entity_id'): entity_ids, }), has_at_least_one_key('service', 'service_template')) + +# ----- SCRIPT + +_DELAY_SCHEMA = vol.Schema({ + vol.Optional(CONF_ALIAS): string, + vol.Required("delay"): vol.All(time_period, positive_timedelta) +}) + +SCRIPT_SCHEMA = vol.All( + ensure_list, + [vol.Any(SERVICE_SCHEMA, _DELAY_SCHEMA, EVENT_SCHEMA)], +) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py new file mode 100644 index 00000000000..e4cf2f6756d --- /dev/null +++ b/homeassistant/helpers/script.py @@ -0,0 +1,125 @@ +"""Helpers to execute scripts.""" +import logging +import threading +from itertools import islice + +import homeassistant.util.dt as date_util +from homeassistant.const import EVENT_TIME_CHANGED +from homeassistant.helpers.event import track_point_in_utc_time +from homeassistant.helpers import service +import homeassistant.helpers.config_validation as cv + +_LOGGER = logging.getLogger(__name__) + +CONF_ALIAS = "alias" +CONF_SERVICE = "service" +CONF_SERVICE_DATA = "data" +CONF_SEQUENCE = "sequence" +CONF_EVENT = "event" +CONF_EVENT_DATA = "event_data" +CONF_DELAY = "delay" + + +def call_from_config(hass, config): + """Call a script based on a config entry.""" + Script(hass, config).run() + + +class Script(): + """Representation of a script.""" + + # pylint: disable=too-many-instance-attributes + def __init__(self, hass, sequence, name=None, change_listener=None): + """Initialize the script.""" + self.hass = hass + self.sequence = cv.SCRIPT_SCHEMA(sequence) + self.name = name + self._change_listener = change_listener + self._cur = -1 + self.last_action = None + self.can_cancel = any(CONF_DELAY in action for action + in self.sequence) + self._lock = threading.Lock() + self._delay_listener = None + + @property + def is_running(self): + """Return true if script is on.""" + return self._cur != -1 + + def run(self): + """Run script.""" + with self._lock: + if self._cur == -1: + self._log('Running script') + self._cur = 0 + + # Unregister callback if we were in a delay but turn on is called + # again. In that case we just continue execution. + self._remove_listener() + + for cur, action in islice(enumerate(self.sequence), self._cur, + None): + + if CONF_DELAY in action: + # Call ourselves in the future to continue work + def script_delay(now): + """Called after delay is done.""" + self._delay_listener = None + self.run() + + self._delay_listener = track_point_in_utc_time( + self.hass, script_delay, + date_util.utcnow() + action[CONF_DELAY]) + self._cur = cur + 1 + if self._change_listener: + self._change_listener() + return + + elif service.validate_service_call(action) is None: + self._call_service(action) + + elif CONF_EVENT in action: + self._fire_event(action) + + self._cur = -1 + self.last_action = None + if self._change_listener: + self._change_listener() + + def stop(self): + """Stop running script.""" + with self._lock: + if self._cur == -1: + return + + self._cur = -1 + self._remove_listener() + if self._change_listener: + self._change_listener() + + def _call_service(self, action): + """Call the service specified in the action.""" + self.last_action = action.get(CONF_ALIAS, 'call service') + self._log("Executing step %s", self.last_action) + service.call_from_config(self.hass, action, True) + + def _fire_event(self, action): + """Fire an event.""" + self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) + self._log("Executing step %s", self.last_action) + self.hass.bus.fire(action[CONF_EVENT], action.get(CONF_EVENT_DATA)) + + def _remove_listener(self): + """Remove point in time listener, if any.""" + if self._delay_listener: + self.hass.bus.remove_listener(EVENT_TIME_CHANGED, + self._delay_listener) + self._delay_listener = None + + def _log(self, msg, *substitutes): + """Logger helper.""" + if self.name is not None: + msg = "Script {}: {}".format(self.name, msg, *substitutes) + + _LOGGER.info(msg) diff --git a/tests/components/test_script.py b/tests/components/test_script.py index 4f912dc77a0..f8b99533c18 100644 --- a/tests/components/test_script.py +++ b/tests/components/test_script.py @@ -34,13 +34,6 @@ class TestScript(unittest.TestCase): 'sequence': [{'event': 'bla'}] } }, - { - 'test': { - 'sequence': { - 'event': 'test_event' - } - } - }, { 'test': { 'sequence': { @@ -49,7 +42,6 @@ class TestScript(unittest.TestCase): } } }, - ): assert not _setup_component(self.hass, 'script', { 'script': value @@ -206,45 +198,6 @@ class TestScript(unittest.TestCase): self.assertEqual(2, len(calls)) - def test_alt_delay(self): - """Test alternative delay config format.""" - event = 'test_event' - calls = [] - - def record_event(event): - """Add recorded event to set.""" - calls.append(event) - - self.hass.bus.listen(event, record_event) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'event': event, - }, { - 'delay': None, - 'seconds': 5 - }, { - 'event': event, - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(1, len(calls)) - - future = dt_util.utcnow() + timedelta(seconds=5) - fire_time_changed(self.hass, future) - self.hass.pool.block_till_done() - - self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(2, len(calls)) - def test_cancel_while_delay(self): """Test the cancelling while the delay is present.""" event = 'test_event' diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 3f4789eca4f..b73dc6d6f94 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -145,18 +145,19 @@ def test_icon(): schema('mdi:work') -def test_time_offset(): - """Test time_offset validation.""" - schema = vol.Schema(cv.time_offset) +def test_time_period(): + """Test time_period validation.""" + schema = vol.Schema(cv.time_period) for value in ( - None, '', 1234, 'hello:world', '12:', '12:34:56:78' + None, '', 1234, 'hello:world', '12:', '12:34:56:78', + {}, {'wrong_key': -10} ): with pytest.raises(vol.MultipleInvalid): schema(value) for value in ( - '8:20', '23:59', '-8:20', '-23:59:59', '-48:00' + '8:20', '23:59', '-8:20', '-23:59:59', '-48:00', {'minutes': 5} ): schema(value) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index c863a46ad3b..11ace1ab5d8 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -37,7 +37,7 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual(1, len(runs)) def test_template_service_call(self): - """ Test service call with tempating. """ + """Test service call with tempating.""" config = { 'service_template': '{{ \'test_domain.test_service\' }}', 'entity_id': 'hello.world', @@ -56,6 +56,7 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual('goodbye', runs[0].data['hello']) def test_passing_variables_to_templates(self): + """Test passing variables to templates.""" config = { 'service_template': '{{ var_service }}', 'entity_id': 'hello.world', @@ -141,7 +142,7 @@ class TestServiceHelpers(unittest.TestCase): service.extract_entity_ids(self.hass, call)) def test_validate_service_call(self): - """Test is_valid_service_call method""" + """Test is_valid_service_call method.""" self.assertNotEqual( service.validate_service_call( {}), From 09a771a026d12e5c6d32878c17a809f38ec2a41c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 21:29:20 -0400 Subject: [PATCH 05/11] Move script component tests to script helper tests --- tests/components/test_script.py | 199 +------------------------------- tests/helpers/test_script.py | 175 ++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 197 deletions(-) create mode 100644 tests/helpers/test_script.py diff --git a/tests/components/test_script.py b/tests/components/test_script.py index f8b99533c18..90165ada294 100644 --- a/tests/components/test_script.py +++ b/tests/components/test_script.py @@ -1,19 +1,17 @@ """The tests for the Script component.""" # pylint: disable=too-many-public-methods,protected-access -from datetime import timedelta import unittest from homeassistant.bootstrap import _setup_component from homeassistant.components import script -import homeassistant.util.dt as dt_util -from tests.common import fire_time_changed, get_test_home_assistant +from tests.common import get_test_home_assistant ENTITY_ID = 'script.test' -class TestScript(unittest.TestCase): +class TestScriptComponent(unittest.TestCase): """Test the Script component.""" def setUp(self): # pylint: disable=invalid-name @@ -49,199 +47,6 @@ class TestScript(unittest.TestCase): self.assertEqual(0, len(self.hass.states.entity_ids('script'))) - def test_firing_event(self): - """Test the firing of events.""" - event = 'test_event' - calls = [] - - def record_event(event): - """Add recorded event to set.""" - calls.append(event) - - self.hass.bus.listen(event, record_event) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'alias': 'Test Script', - 'sequence': [{ - 'event': event, - 'event_data': { - 'hello': 'world' - } - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertEqual(1, len(calls)) - self.assertEqual('world', calls[0].data.get('hello')) - self.assertIsNone( - self.hass.states.get(ENTITY_ID).attributes.get('can_cancel')) - - def test_calling_service(self): - """Test the calling of a service.""" - calls = [] - - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - self.hass.services.register('test', 'script', record_call) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'service': 'test.script', - 'data': { - 'hello': 'world' - } - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertEqual(1, len(calls)) - self.assertEqual('world', calls[0].data.get('hello')) - - def test_calling_service_template(self): - """Test the calling of a service.""" - calls = [] - - def record_call(service): - """Add recorded event to set.""" - calls.append(service) - - self.hass.services.register('test', 'script', record_call) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'service_template': """ - {% if True %} - test.script - {% else %} - test.not_script - {% endif %}""", - 'data_template': { - 'hello': """ - {% if True %} - world - {% else %} - Not world - {% endif %} - """ - } - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertEqual(1, len(calls)) - self.assertEqual('world', calls[0].data.get('hello')) - - def test_delay(self): - """Test the delay.""" - event = 'test_event' - calls = [] - - def record_event(event): - """Add recorded event to set.""" - calls.append(event) - - self.hass.bus.listen(event, record_event) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'event': event - }, { - 'delay': { - 'seconds': 5 - } - }, { - 'event': event, - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - self.assertTrue( - self.hass.states.get(ENTITY_ID).attributes.get('can_cancel')) - - self.assertEqual( - event, - self.hass.states.get(ENTITY_ID).attributes.get('last_action')) - self.assertEqual(1, len(calls)) - - future = dt_util.utcnow() + timedelta(seconds=5) - fire_time_changed(self.hass, future) - self.hass.pool.block_till_done() - - self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - - self.assertEqual(2, len(calls)) - - def test_cancel_while_delay(self): - """Test the cancelling while the delay is present.""" - event = 'test_event' - calls = [] - - def record_event(event): - """Add recorded event to set.""" - calls.append(event) - - self.hass.bus.listen(event, record_event) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'delay': { - 'seconds': 5 - } - }, { - 'event': event, - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - - self.assertEqual(0, len(calls)) - - script.turn_off(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - - future = dt_util.utcnow() + timedelta(seconds=5) - fire_time_changed(self.hass, future) - self.hass.pool.block_till_done() - - self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - - self.assertEqual(0, len(calls)) - def test_turn_on_service(self): """Verify that the turn_on service.""" event = 'test_event' diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py new file mode 100644 index 00000000000..90e438047bb --- /dev/null +++ b/tests/helpers/test_script.py @@ -0,0 +1,175 @@ +"""The tests for the Script component.""" +# pylint: disable=too-many-public-methods,protected-access +from datetime import timedelta +import unittest + +from homeassistant.bootstrap import _setup_component +import homeassistant.util.dt as dt_util +from homeassistant.helpers import script + +from tests.common import fire_time_changed, get_test_home_assistant + + +ENTITY_ID = 'script.test' + + +class TestScriptHelper(unittest.TestCase): + """Test the Script component.""" + + def setUp(self): # pylint: disable=invalid-name + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + + def tearDown(self): # pylint: disable=invalid-name + """Stop down everything that was started.""" + self.hass.stop() + + def test_firing_event(self): + """Test the firing of events.""" + event = 'test_event' + calls = [] + + def record_event(event): + """Add recorded event to set.""" + calls.append(event) + + self.hass.bus.listen(event, record_event) + + script_obj = script.Script(self.hass, { + 'event': event, + 'event_data': { + 'hello': 'world' + } + }) + + script_obj.run() + + self.hass.pool.block_till_done() + + assert len(calls) == 1 + assert calls[0].data.get('hello') == 'world' + assert not script_obj.can_cancel + + def test_calling_service(self): + """Test the calling of a service.""" + calls = [] + + def record_call(service): + """Add recorded event to set.""" + calls.append(service) + + self.hass.services.register('test', 'script', record_call) + + script_obj = script.Script(self.hass, { + 'service': 'test.script', + 'data': { + 'hello': 'world' + } + }) + + script_obj.run() + self.hass.pool.block_till_done() + + assert len(calls) == 1 + assert calls[0].data.get('hello') == 'world' + + def test_calling_service_template(self): + """Test the calling of a service.""" + calls = [] + + def record_call(service): + """Add recorded event to set.""" + calls.append(service) + + self.hass.services.register('test', 'script', record_call) + + script_obj = script.Script(self.hass, { + 'service_template': """ + {% if True %} + test.script + {% else %} + test.not_script + {% endif %}""", + 'data_template': { + 'hello': """ + {% if True %} + world + {% else %} + Not world + {% endif %} + """ + } + }) + + script_obj.run() + + self.hass.pool.block_till_done() + + assert len(calls) == 1 + assert calls[0].data.get('hello') == 'world' + + def test_delay(self): + """Test the delay.""" + event = 'test_event' + events = [] + + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + self.hass.bus.listen(event, record_event) + + script_obj = script.Script(self.hass, [ + {'event': event}, + {'delay': {'seconds': 5}}, + {'event': event}]) + + script_obj.run() + + self.hass.pool.block_till_done() + + assert script_obj.is_running + assert script_obj.can_cancel + assert script_obj.last_action == event + assert len(events) == 1 + + future = dt_util.utcnow() + timedelta(seconds=5) + fire_time_changed(self.hass, future) + self.hass.pool.block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 + + def test_cancel_while_delay(self): + """Test the cancelling while the delay is present.""" + event = 'test_event' + events = [] + + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + self.hass.bus.listen(event, record_event) + + script_obj = script.Script(self.hass, [ + {'delay': {'seconds': 5}}, + {'event': event}]) + + script_obj.run() + + self.hass.pool.block_till_done() + + assert script_obj.is_running + assert len(events) == 0 + + script_obj.stop() + + assert not script_obj.is_running + + # Make sure the script is really stopped. + future = dt_util.utcnow() + timedelta(seconds=5) + fire_time_changed(self.hass, future) + self.hass.pool.block_till_done() + + assert not script_obj.is_running + assert len(events) == 0 From 26863284b6245e9552ac9a7839cd279cc04ad393 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 21:42:20 -0400 Subject: [PATCH 06/11] Script helper: support variables --- homeassistant/helpers/script.py | 10 ++++---- tests/helpers/test_script.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index e4cf2f6756d..0025ddd61df 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -47,7 +47,7 @@ class Script(): """Return true if script is on.""" return self._cur != -1 - def run(self): + def run(self, variables=None): """Run script.""" with self._lock: if self._cur == -1: @@ -66,7 +66,7 @@ class Script(): def script_delay(now): """Called after delay is done.""" self._delay_listener = None - self.run() + self.run(variables) self._delay_listener = track_point_in_utc_time( self.hass, script_delay, @@ -77,7 +77,7 @@ class Script(): return elif service.validate_service_call(action) is None: - self._call_service(action) + self._call_service(action, variables) elif CONF_EVENT in action: self._fire_event(action) @@ -98,11 +98,11 @@ class Script(): if self._change_listener: self._change_listener() - def _call_service(self, action): + def _call_service(self, action, variables): """Call the service specified in the action.""" self.last_action = action.get(CONF_ALIAS, 'call service') self._log("Executing step %s", self.last_action) - service.call_from_config(self.hass, action, True) + service.call_from_config(self.hass, action, True, variables) def _fire_event(self, action): """Fire an event.""" diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 90e438047bb..492b62906df 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -173,3 +173,47 @@ class TestScriptHelper(unittest.TestCase): assert not script_obj.is_running assert len(events) == 0 + + def test_passing_variables_to_script(self): + """Test if we can pass variables to script.""" + calls = [] + + def record_call(service): + """Add recorded event to set.""" + calls.append(service) + + self.hass.services.register('test', 'script', record_call) + + script_obj = script.Script(self.hass, [ + { + 'service': 'test.script', + 'data_template': { + 'hello': '{{ greeting }}', + }, + }, + {'delay': {'seconds': 5}}, + { + 'service': 'test.script', + 'data_template': { + 'hello': '{{ greeting2 }}', + }, + }]) + + script_obj.run({ + 'greeting': 'world', + 'greeting2': 'universe', + }) + + self.hass.pool.block_till_done() + + assert script_obj.is_running + assert len(calls) == 1 + assert calls[-1].data['hello'] == 'world' + + future = dt_util.utcnow() + timedelta(seconds=5) + fire_time_changed(self.hass, future) + self.hass.pool.block_till_done() + + assert not script_obj.is_running + assert len(calls) == 2 + assert calls[-1].data['hello'] == 'universe' From b8e4db9161ca8d058e0d91b759c830f151374235 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 22:21:11 -0400 Subject: [PATCH 07/11] Script entities to allow passing in variables --- homeassistant/components/script.py | 17 ++++----- tests/components/test_script.py | 57 +++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/script.py b/homeassistant/components/script.py index 3557179c6eb..5f1e63f5d00 100644 --- a/homeassistant/components/script.py +++ b/homeassistant/components/script.py @@ -26,12 +26,12 @@ DEPENDENCIES = ["group"] CONF_SEQUENCE = "sequence" +ATTR_VARIABLES = 'variables' ATTR_LAST_ACTION = 'last_action' ATTR_CAN_CANCEL = 'can_cancel' _LOGGER = logging.getLogger(__name__) - _SCRIPT_ENTRY_SCHEMA = vol.Schema({ CONF_ALIAS: cv.string, vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, @@ -41,9 +41,10 @@ CONFIG_SCHEMA = vol.Schema({ vol.Required(DOMAIN): {cv.slug: _SCRIPT_ENTRY_SCHEMA} }, extra=vol.ALLOW_EXTRA) -SCRIPT_SERVICE_SCHEMA = vol.Schema({}) +SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) SCRIPT_TURN_ONOFF_SCHEMA = vol.Schema({ vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, + vol.Optional(ATTR_VARIABLES): dict, }) @@ -52,11 +53,11 @@ def is_on(hass, entity_id): return hass.states.is_state(entity_id, STATE_ON) -def turn_on(hass, entity_id): +def turn_on(hass, entity_id, variables=None): """Turn script on.""" _, object_id = split_entity_id(entity_id) - hass.services.call(DOMAIN, object_id) + hass.services.call(DOMAIN, object_id, variables) def turn_off(hass, entity_id): @@ -80,7 +81,7 @@ def setup(hass, config): if script.is_on: _LOGGER.warning("Script %s already running.", entity_id) return - script.turn_on() + script.turn_on(variables=service.data) for object_id, cfg in config[DOMAIN].items(): alias = cfg.get(CONF_ALIAS, object_id) @@ -92,9 +93,9 @@ def setup(hass, config): def turn_on_service(service): """Call a service to turn script on.""" # We could turn on script directly here, but we only want to offer - # one way to do it. Otherwise no easy way to call invocations. + # one way to do it. Otherwise no easy way to detect invocations. for script in component.extract_from_service(service): - turn_on(hass, script.entity_id) + turn_on(hass, script.entity_id, service.data.get(ATTR_VARIABLES)) def turn_off_service(service): """Cancel a script.""" @@ -151,7 +152,7 @@ class ScriptEntity(ToggleEntity): def turn_on(self, **kwargs): """Turn the entity on.""" - self.script.run() + self.script.run(kwargs.get(ATTR_VARIABLES)) def turn_off(self, **kwargs): """Turn script off.""" diff --git a/tests/components/test_script.py b/tests/components/test_script.py index 90165ada294..30cf69d7922 100644 --- a/tests/components/test_script.py +++ b/tests/components/test_script.py @@ -50,11 +50,11 @@ class TestScriptComponent(unittest.TestCase): def test_turn_on_service(self): """Verify that the turn_on service.""" event = 'test_event' - calls = [] + events = [] def record_event(event): """Add recorded event to set.""" - calls.append(event) + events.append(event) self.hass.bus.listen(event, record_event) @@ -75,21 +75,21 @@ class TestScriptComponent(unittest.TestCase): script.turn_on(self.hass, ENTITY_ID) self.hass.pool.block_till_done() self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(0, len(calls)) + self.assertEqual(0, len(events)) # Calling turn_on a second time should not advance the script script.turn_on(self.hass, ENTITY_ID) self.hass.pool.block_till_done() - self.assertEqual(0, len(calls)) + self.assertEqual(0, len(events)) def test_toggle_service(self): """Test the toggling of a service.""" event = 'test_event' - calls = [] + events = [] def record_event(event): """Add recorded event to set.""" - calls.append(event) + events.append(event) self.hass.bus.listen(event, record_event) @@ -110,9 +110,50 @@ class TestScriptComponent(unittest.TestCase): script.toggle(self.hass, ENTITY_ID) self.hass.pool.block_till_done() self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(0, len(calls)) + self.assertEqual(0, len(events)) script.toggle(self.hass, ENTITY_ID) self.hass.pool.block_till_done() self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(0, len(calls)) + self.assertEqual(0, len(events)) + + def test_passing_variables(self): + """Test different ways of passing in variables.""" + calls = [] + + def record_call(service): + """Add recorded event to set.""" + calls.append(service) + + self.hass.services.register('test', 'script', record_call) + + assert _setup_component(self.hass, 'script', { + 'script': { + 'test': { + 'sequence': { + 'service': 'test.script', + 'data_template': { + 'hello': '{{ greeting }}', + }, + }, + }, + }, + }) + + script.turn_on(self.hass, ENTITY_ID, { + 'greeting': 'world' + }) + + self.hass.pool.block_till_done() + + assert len(calls) == 1 + assert calls[-1].data['hello'] == 'world' + + self.hass.services.call('script', 'test', { + 'greeting': 'universe', + }) + + self.hass.pool.block_till_done() + + assert len(calls) == 2 + assert calls[-1].data['hello'] == 'universe' From 612a017bc6138b4b0b292111fb2e538889cd85a9 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 21 Apr 2016 22:36:14 -0400 Subject: [PATCH 08/11] Automation: Allow embedding script definition --- .../components/automation/__init__.py | 9 ++++--- tests/components/automation/test_init.py | 26 +++++++++++++++++++ tests/helpers/test_script.py | 1 - 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 8cbaf35a5c4..3ba5596fb4d 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -11,8 +11,7 @@ import voluptuous as vol from homeassistant.bootstrap import prepare_setup_platform from homeassistant.const import CONF_PLATFORM from homeassistant.components import logbook -from homeassistant.helpers import extract_domain_configs -from homeassistant.helpers.service import call_from_config +from homeassistant.helpers import extract_domain_configs, script from homeassistant.loader import get_platform import homeassistant.helpers.config_validation as cv @@ -88,7 +87,7 @@ PLATFORM_SCHEMA = vol.Schema({ vol.Required(CONF_CONDITION_TYPE, default=DEFAULT_CONDITION_TYPE): vol.All(vol.Lower, vol.Any(CONDITION_TYPE_AND, CONDITION_TYPE_OR)), CONF_CONDITION: _CONDITION_SCHEMA, - vol.Required(CONF_ACTION): cv.SERVICE_SCHEMA, + vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, }) @@ -122,11 +121,13 @@ def _setup_automation(hass, config_block, name, config): def _get_action(hass, config, name): """Return an action based on a configuration.""" + script_obj = script.Script(hass, config, name) + def action(variables=None): """Action to be executed.""" _LOGGER.info('Executing %s', name) logbook.log_entry(hass, name, 'has been triggered', DOMAIN) - call_from_config(hass, config, variables=variables) + script_obj.run(variables) return action diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index f6d33c18071..8e06f524d0e 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -316,3 +316,29 @@ class TestAutomation(unittest.TestCase): self.hass.bus.fire('test_event_2') self.hass.pool.block_till_done() self.assertEqual(2, len(self.calls)) + + def test_automation_calling_two_actions(self): + """Test if we can call two actions from automation definition.""" + self.assertTrue(_setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + + 'action': [{ + 'service': 'test.automation', + 'data': {'position': 0}, + }, { + 'service': 'test.automation', + 'data': {'position': 1}, + }], + } + })) + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 2 + assert self.calls[0].data['position'] == 0 + assert self.calls[1].data['position'] == 1 diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 492b62906df..47af833223f 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -3,7 +3,6 @@ from datetime import timedelta import unittest -from homeassistant.bootstrap import _setup_component import homeassistant.util.dt as dt_util from homeassistant.helpers import script From 4a5411a957629b885c856b6f2e18981449a208f9 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 22 Apr 2016 05:30:30 -0400 Subject: [PATCH 09/11] Allow calling scripts from Alexa --- homeassistant/components/alexa.py | 6 +++--- homeassistant/helpers/script.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/alexa.py b/homeassistant/components/alexa.py index bb9e1816a68..d6178b4744c 100644 --- a/homeassistant/components/alexa.py +++ b/homeassistant/components/alexa.py @@ -8,8 +8,7 @@ import enum import logging from homeassistant.const import HTTP_OK, HTTP_UNPROCESSABLE_ENTITY -from homeassistant.helpers.service import call_from_config -from homeassistant.helpers import template +from homeassistant.helpers import template, script DOMAIN = 'alexa' DEPENDENCIES = ['http'] @@ -91,7 +90,8 @@ def _handle_alexa(handler, path_match, data): card['content']) if action is not None: - call_from_config(handler.server.hass, action, True, response.variables) + script.call_from_config(handler.server.hass, action, + response.variables) handler.write_json(response.as_dict()) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 0025ddd61df..6c938fd4032 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -20,9 +20,9 @@ CONF_EVENT_DATA = "event_data" CONF_DELAY = "delay" -def call_from_config(hass, config): +def call_from_config(hass, config, variables=None): """Call a script based on a config entry.""" - Script(hass, config).run() + Script(hass, config).run(variables) class Script(): From 533799656e19ae0f6c94c63ad91cb107daf607a4 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 23 Apr 2016 07:10:57 +0200 Subject: [PATCH 10/11] Cache script object for Alexa --- homeassistant/components/alexa.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/alexa.py b/homeassistant/components/alexa.py index d6178b4744c..080d7bd1097 100644 --- a/homeassistant/components/alexa.py +++ b/homeassistant/components/alexa.py @@ -26,7 +26,14 @@ CONF_ACTION = 'action' def setup(hass, config): """Activate Alexa component.""" - _CONFIG.update(config[DOMAIN].get(CONF_INTENTS, {})) + intents = config[DOMAIN].get(CONF_INTENTS, {}) + + for name, intent in intents.items(): + if CONF_ACTION in intent: + intent[CONF_ACTION] = script.Script(hass, intent[CONF_ACTION], + "Alexa intent {}".format(name)) + + _CONFIG.update(intents) hass.http.register_path('POST', API_ENDPOINT, _handle_alexa, True) @@ -90,8 +97,7 @@ def _handle_alexa(handler, path_match, data): card['content']) if action is not None: - script.call_from_config(handler.server.hass, action, - response.variables) + action.run(response.variables) handler.write_json(response.as_dict()) From 14bd630c1d87f3fb444d8298eb8c3373d589eb9d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 23 Apr 2016 07:11:21 +0200 Subject: [PATCH 11/11] Service/Script cleanup --- homeassistant/helpers/config_validation.py | 6 ++--- homeassistant/helpers/script.py | 9 ++++---- homeassistant/helpers/service.py | 27 +++++++--------------- tests/helpers/test_service.py | 18 --------------- 4 files changed, 15 insertions(+), 45 deletions(-) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 71e103f7dd3..3a8e6179c86 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -264,14 +264,12 @@ SERVICE_SCHEMA = vol.All(vol.Schema({ vol.Optional('entity_id'): entity_ids, }), has_at_least_one_key('service', 'service_template')) -# ----- SCRIPT - -_DELAY_SCHEMA = vol.Schema({ +_SCRIPT_DELAY_SCHEMA = vol.Schema({ vol.Optional(CONF_ALIAS): string, vol.Required("delay"): vol.All(time_period, positive_timedelta) }) SCRIPT_SCHEMA = vol.All( ensure_list, - [vol.Any(SERVICE_SCHEMA, _DELAY_SCHEMA, EVENT_SCHEMA)], + [vol.Any(SERVICE_SCHEMA, _SCRIPT_DELAY_SCHEMA, EVENT_SCHEMA)], ) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 6c938fd4032..05e49c6e9ce 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -76,12 +76,12 @@ class Script(): self._change_listener() return - elif service.validate_service_call(action) is None: - self._call_service(action, variables) - elif CONF_EVENT in action: self._fire_event(action) + else: + self._call_service(action, variables) + self._cur = -1 self.last_action = None if self._change_listener: @@ -102,7 +102,8 @@ class Script(): """Call the service specified in the action.""" self.last_action = action.get(CONF_ALIAS, 'call service') self._log("Executing step %s", self.last_action) - service.call_from_config(self.hass, action, True, variables) + service.call_from_config(self.hass, action, True, variables, + validate_config=False) def _fire_event(self, action): """Fire an event.""" diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 8b89d856c50..95dce9516de 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -32,13 +32,15 @@ def service(domain, service_name): return register_service_decorator -def call_from_config(hass, config, blocking=False, variables=None): +def call_from_config(hass, config, blocking=False, variables=None, + validate_config=True): """Call a service based on a config hash.""" - try: - config = cv.SERVICE_SCHEMA(config) - except vol.Invalid as ex: - _LOGGER.error("Invalid config for calling service: %s", ex) - return + if validate_config: + try: + config = cv.SERVICE_SCHEMA(config) + except vol.Invalid as ex: + _LOGGER.error("Invalid config for calling service: %s", ex) + return if CONF_SERVICE in config: domain_service = config[CONF_SERVICE] @@ -85,16 +87,3 @@ def extract_entity_ids(hass, service_call): return group.expand_entity_ids(hass, [service_ent_id]) return [ent_id for ent_id in group.expand_entity_ids(hass, service_ent_id)] - - -def validate_service_call(config): - """Validate service call configuration. - - Helper method to validate that a configuration is a valid service call. - Returns None if validation succeeds, else an error description - """ - try: - cv.SERVICE_SCHEMA(config) - return None - except vol.Invalid as ex: - return str(ex) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 11ace1ab5d8..5372b6a77d4 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -140,21 +140,3 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual(['light.ceiling', 'light.kitchen'], service.extract_entity_ids(self.hass, call)) - - def test_validate_service_call(self): - """Test is_valid_service_call method.""" - self.assertNotEqual( - service.validate_service_call( - {}), - None - ) - self.assertEqual( - service.validate_service_call( - {'service': 'test_domain.test_service'}), - None - ) - self.assertEqual( - service.validate_service_call( - {'service_template': 'test_domain.{{ \'test_service\' }}'}), - None - )