Script entities to allow passing in variables

This commit is contained in:
Paulus Schoutsen 2016-04-21 22:21:11 -04:00
parent 26863284b6
commit b8e4db9161
2 changed files with 58 additions and 16 deletions

View File

@ -26,12 +26,12 @@ DEPENDENCIES = ["group"]
CONF_SEQUENCE = "sequence" CONF_SEQUENCE = "sequence"
ATTR_VARIABLES = 'variables'
ATTR_LAST_ACTION = 'last_action' ATTR_LAST_ACTION = 'last_action'
ATTR_CAN_CANCEL = 'can_cancel' ATTR_CAN_CANCEL = 'can_cancel'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SCRIPT_ENTRY_SCHEMA = vol.Schema({ _SCRIPT_ENTRY_SCHEMA = vol.Schema({
CONF_ALIAS: cv.string, CONF_ALIAS: cv.string,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
@ -41,9 +41,10 @@ CONFIG_SCHEMA = vol.Schema({
vol.Required(DOMAIN): {cv.slug: _SCRIPT_ENTRY_SCHEMA} vol.Required(DOMAIN): {cv.slug: _SCRIPT_ENTRY_SCHEMA}
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
SCRIPT_SERVICE_SCHEMA = vol.Schema({}) SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
SCRIPT_TURN_ONOFF_SCHEMA = vol.Schema({ SCRIPT_TURN_ONOFF_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(ATTR_VARIABLES): dict,
}) })
@ -52,11 +53,11 @@ def is_on(hass, entity_id):
return hass.states.is_state(entity_id, STATE_ON) return hass.states.is_state(entity_id, STATE_ON)
def turn_on(hass, entity_id): def turn_on(hass, entity_id, variables=None):
"""Turn script on.""" """Turn script on."""
_, object_id = split_entity_id(entity_id) _, object_id = split_entity_id(entity_id)
hass.services.call(DOMAIN, object_id) hass.services.call(DOMAIN, object_id, variables)
def turn_off(hass, entity_id): def turn_off(hass, entity_id):
@ -80,7 +81,7 @@ def setup(hass, config):
if script.is_on: if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id) _LOGGER.warning("Script %s already running.", entity_id)
return return
script.turn_on() script.turn_on(variables=service.data)
for object_id, cfg in config[DOMAIN].items(): for object_id, cfg in config[DOMAIN].items():
alias = cfg.get(CONF_ALIAS, object_id) alias = cfg.get(CONF_ALIAS, object_id)
@ -92,9 +93,9 @@ def setup(hass, config):
def turn_on_service(service): def turn_on_service(service):
"""Call a service to turn script on.""" """Call a service to turn script on."""
# We could turn on script directly here, but we only want to offer # We could turn on script directly here, but we only want to offer
# one way to do it. Otherwise no easy way to call invocations. # one way to do it. Otherwise no easy way to detect invocations.
for script in component.extract_from_service(service): for script in component.extract_from_service(service):
turn_on(hass, script.entity_id) turn_on(hass, script.entity_id, service.data.get(ATTR_VARIABLES))
def turn_off_service(service): def turn_off_service(service):
"""Cancel a script.""" """Cancel a script."""
@ -151,7 +152,7 @@ class ScriptEntity(ToggleEntity):
def turn_on(self, **kwargs): def turn_on(self, **kwargs):
"""Turn the entity on.""" """Turn the entity on."""
self.script.run() self.script.run(kwargs.get(ATTR_VARIABLES))
def turn_off(self, **kwargs): def turn_off(self, **kwargs):
"""Turn script off.""" """Turn script off."""

View File

@ -50,11 +50,11 @@ class TestScriptComponent(unittest.TestCase):
def test_turn_on_service(self): def test_turn_on_service(self):
"""Verify that the turn_on service.""" """Verify that the turn_on service."""
event = 'test_event' event = 'test_event'
calls = [] events = []
def record_event(event): def record_event(event):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(event) events.append(event)
self.hass.bus.listen(event, record_event) self.hass.bus.listen(event, record_event)
@ -75,21 +75,21 @@ class TestScriptComponent(unittest.TestCase):
script.turn_on(self.hass, ENTITY_ID) script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertTrue(script.is_on(self.hass, ENTITY_ID)) self.assertTrue(script.is_on(self.hass, ENTITY_ID))
self.assertEqual(0, len(calls)) self.assertEqual(0, len(events))
# Calling turn_on a second time should not advance the script # Calling turn_on a second time should not advance the script
script.turn_on(self.hass, ENTITY_ID) script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertEqual(0, len(calls)) self.assertEqual(0, len(events))
def test_toggle_service(self): def test_toggle_service(self):
"""Test the toggling of a service.""" """Test the toggling of a service."""
event = 'test_event' event = 'test_event'
calls = [] events = []
def record_event(event): def record_event(event):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(event) events.append(event)
self.hass.bus.listen(event, record_event) self.hass.bus.listen(event, record_event)
@ -110,9 +110,50 @@ class TestScriptComponent(unittest.TestCase):
script.toggle(self.hass, ENTITY_ID) script.toggle(self.hass, ENTITY_ID)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertTrue(script.is_on(self.hass, ENTITY_ID)) self.assertTrue(script.is_on(self.hass, ENTITY_ID))
self.assertEqual(0, len(calls)) self.assertEqual(0, len(events))
script.toggle(self.hass, ENTITY_ID) script.toggle(self.hass, ENTITY_ID)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertFalse(script.is_on(self.hass, ENTITY_ID)) self.assertFalse(script.is_on(self.hass, ENTITY_ID))
self.assertEqual(0, len(calls)) self.assertEqual(0, len(events))
def test_passing_variables(self):
"""Test different ways of passing in variables."""
calls = []
def record_call(service):
"""Add recorded event to set."""
calls.append(service)
self.hass.services.register('test', 'script', record_call)
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': {
'service': 'test.script',
'data_template': {
'hello': '{{ greeting }}',
},
},
},
},
})
script.turn_on(self.hass, ENTITY_ID, {
'greeting': 'world'
})
self.hass.pool.block_till_done()
assert len(calls) == 1
assert calls[-1].data['hello'] == 'world'
self.hass.services.call('script', 'test', {
'greeting': 'universe',
})
self.hass.pool.block_till_done()
assert len(calls) == 2
assert calls[-1].data['hello'] == 'universe'