Add support for remove services / Reload script support (#6441)

* Add support for remove services / Reload script support

* Reload support for scripts

* Add more unittest for services

* Add unittest for script reload

* Address paulus comments
This commit is contained in:
Pascal Vizeli 2017-03-08 07:51:34 +01:00 committed by Paulus Schoutsen
parent e7f442d66b
commit c937a7bcb0
7 changed files with 197 additions and 35 deletions

View File

@ -15,7 +15,7 @@ from homeassistant.setup import async_prepare_setup_platform
from homeassistant import config as conf_util
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
SERVICE_TOGGLE)
SERVICE_TOGGLE, SERVICE_RELOAD)
from homeassistant.components import logbook
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import extract_domain_configs, script, condition
@ -51,7 +51,6 @@ DEFAULT_INITIAL_STATE = True
ATTR_LAST_TRIGGERED = 'last_triggered'
ATTR_VARIABLES = 'variables'
SERVICE_TRIGGER = 'trigger'
SERVICE_RELOAD = 'reload'
_LOGGER = logging.getLogger(__name__)

View File

@ -14,7 +14,7 @@ from homeassistant import config as conf_util, core as ha
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_ICON, CONF_NAME, STATE_CLOSED, STATE_HOME,
STATE_NOT_HOME, STATE_OFF, STATE_ON, STATE_OPEN, STATE_LOCKED,
STATE_UNLOCKED, STATE_UNKNOWN, ATTR_ASSUMED_STATE)
STATE_UNLOCKED, STATE_UNKNOWN, ATTR_ASSUMED_STATE, SERVICE_RELOAD)
from homeassistant.core import callback
from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent
@ -42,7 +42,6 @@ SET_VISIBILITY_SERVICE_SCHEMA = vol.Schema({
vol.Required(ATTR_VISIBLE): cv.boolean
})
SERVICE_RELOAD = 'reload'
RELOAD_SERVICE_SCHEMA = vol.Schema({})
_LOGGER = logging.getLogger(__name__)
@ -395,17 +394,16 @@ class Group(Entity):
self._state = STATE_UNKNOWN
self._async_update_group_state()
@asyncio.coroutine
def async_remove(self):
"""Remove group from HASS.
This method must be run in the event loop.
This method must be run in the event loop and returns a coroutine.
"""
if self._async_unsub_state_changed:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
yield from super().async_remove()
return super().async_remove()
@asyncio.coroutine
def _async_state_changed_listener(self, entity_id, old_state, new_state):

View File

@ -14,7 +14,7 @@ import voluptuous as vol
from homeassistant.const import (
ATTR_ENTITY_ID, SERVICE_TURN_OFF, SERVICE_TURN_ON,
SERVICE_TOGGLE, STATE_ON, CONF_ALIAS)
SERVICE_TOGGLE, SERVICE_RELOAD, STATE_ON, CONF_ALIAS)
from homeassistant.core import split_entity_id
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
@ -49,6 +49,7 @@ SCRIPT_TURN_ONOFF_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(ATTR_VARIABLES): dict,
})
RELOAD_SERVICE_SCHEMA = vol.Schema({})
def is_on(hass, entity_id):
@ -56,6 +57,11 @@ def is_on(hass, entity_id):
return hass.states.is_state(entity_id, STATE_ON)
def reload(hass):
"""Reload script component."""
hass.services.call(DOMAIN, SERVICE_RELOAD)
def turn_on(hass, entity_id, variables=None):
"""Turn script on."""
_, object_id = split_entity_id(entity_id)
@ -76,29 +82,19 @@ def toggle(hass, entity_id):
@asyncio.coroutine
def async_setup(hass, config):
"""Load the scripts from the configuration."""
component = EntityComponent(_LOGGER, DOMAIN, hass,
group_name=GROUP_NAME_ALL_SCRIPTS)
component = EntityComponent(
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
yield from _async_process_config(hass, config, component)
@asyncio.coroutine
def service_handler(service):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.entities.get(entity_id)
if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id)
def reload_service(service):
"""Call a service to reload scripts."""
conf = yield from component.async_prepare_reload()
if conf is None:
return
yield from script.async_turn_on(variables=service.data)
scripts = []
for object_id, cfg in config[DOMAIN].items():
alias = cfg.get(CONF_ALIAS, object_id)
script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE])
scripts.append(script)
hass.services.async_register(DOMAIN, object_id, service_handler,
schema=SCRIPT_SERVICE_SCHEMA)
yield from component.async_add_entities(scripts)
yield from _async_process_config(hass, conf, component)
@asyncio.coroutine
def turn_on_service(service):
@ -123,6 +119,8 @@ def async_setup(hass, config):
for script in component.async_extract_from_service(service):
yield from script.async_toggle()
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
schema=RELOAD_SERVICE_SCHEMA)
hass.services.async_register(DOMAIN, SERVICE_TURN_ON, turn_on_service,
schema=SCRIPT_TURN_ONOFF_SCHEMA)
hass.services.async_register(DOMAIN, SERVICE_TURN_OFF, turn_off_service,
@ -133,6 +131,31 @@ def async_setup(hass, config):
return True
@asyncio.coroutine
def _async_process_config(hass, config, component):
"""Process group configuration."""
@asyncio.coroutine
def service_handler(service):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.entities.get(entity_id)
if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id)
return
yield from script.async_turn_on(variables=service.data)
scripts = []
for object_id, cfg in config[DOMAIN].items():
alias = cfg.get(CONF_ALIAS, object_id)
script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE])
scripts.append(script)
hass.services.async_register(
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
yield from component.async_add_entities(scripts)
class ScriptEntity(ToggleEntity):
"""Representation of a script entity."""
@ -177,3 +200,16 @@ class ScriptEntity(ToggleEntity):
def async_turn_off(self, **kwargs):
"""Turn script off."""
self.script.async_stop()
def async_remove(self):
"""Remove script from HASS.
This method must be run in the event loop and returns a coroutine.
"""
if self.script.is_running:
self.script.async_stop()
# remove service
self.hass.services.async_remove(DOMAIN, self.object_id)
return super().async_remove()

View File

@ -166,6 +166,7 @@ EVENT_SERVICE_EXECUTED = 'service_executed'
EVENT_PLATFORM_DISCOVERED = 'platform_discovered'
EVENT_COMPONENT_LOADED = 'component_loaded'
EVENT_SERVICE_REGISTERED = 'service_registered'
EVENT_SERVICE_REMOVED = 'service_removed'
# #### STATES ####
STATE_ON = 'on'
@ -305,6 +306,7 @@ SERVICE_HOMEASSISTANT_RESTART = 'restart'
SERVICE_TURN_ON = 'turn_on'
SERVICE_TURN_OFF = 'turn_off'
SERVICE_TOGGLE = 'toggle'
SERVICE_RELOAD = 'reload'
SERVICE_VOLUME_UP = 'volume_up'
SERVICE_VOLUME_DOWN = 'volume_down'

View File

@ -26,7 +26,8 @@ from homeassistant.const import (
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, __version__)
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED, __version__)
from homeassistant.exceptions import (
HomeAssistantError, InvalidEntityFormatError, ShuttingDown)
from homeassistant.util.async import (
@ -864,6 +865,32 @@ class ServiceRegistry(object):
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)
def remove(self, domain, service):
"""Remove a registered service from service handler."""
run_callback_threadsafe(
self._hass.loop, self.async_remove, domain, service).result()
@callback
def async_remove(self, domain, service):
"""Remove a registered service from service handler.
This method must be run in the event loop.
"""
domain = domain.lower()
service = service.lower()
if service not in self._services.get(domain, {}):
_LOGGER.warning(
"Unable to remove unknown service %s/%s.", domain, service)
return
self._services[domain].pop(service)
self._hass.bus.async_fire(
EVENT_SERVICE_REMOVED,
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)
def call(self, domain, service, service_data=None, blocking=False):
"""
Call a service.

View File

@ -1,6 +1,7 @@
"""The tests for the Script component."""
# pylint: disable=protected-access
import unittest
from unittest.mock import patch
from homeassistant.core import callback
from homeassistant.setup import setup_component
@ -172,3 +173,38 @@ class TestScriptComponent(unittest.TestCase):
assert len(calls) == 2
assert calls[-1].data['hello'] == 'universe'
def test_reload_service(self):
"""Verify that the turn_on service."""
assert setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
'delay': {
'seconds': 5
}
}]
}
}
})
assert self.hass.states.get(ENTITY_ID) is not None
assert self.hass.services.has_service(script.DOMAIN, 'test')
with patch('homeassistant.config.load_yaml_config_file', return_value={
'script': {
'test2': {
'sequence': [{
'delay': {
'seconds': 5
}
}]
}}}):
script.reload(self.hass)
self.hass.block_till_done()
assert self.hass.states.get(ENTITY_ID) is None
assert not self.hass.services.has_service(script.DOMAIN, 'test')
assert self.hass.states.get("script.test2") is not None
assert self.hass.services.has_service(script.DOMAIN, 'test2')

View File

@ -16,7 +16,8 @@ from homeassistant.util.unit_system import (METRIC_SYSTEM)
from homeassistant.const import (
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_START)
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_START,
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
from tests.common import get_test_home_assistant
@ -619,6 +620,15 @@ class TestServiceRegistry(unittest.TestCase):
self.services.register("Test_Domain", "TEST_SERVICE", mock_service)
self.calls_register = []
@ha.callback
def mock_event_register(event):
"""Mock register event."""
self.calls_register.append(event)
self.hass.bus.listen(EVENT_SERVICE_REGISTERED, mock_event_register)
# pylint: disable=invalid-name
def tearDown(self):
"""Stop down stuff we started."""
@ -649,8 +659,13 @@ class TestServiceRegistry(unittest.TestCase):
"""Service handler."""
calls.append(call)
self.services.register("test_domain", "register_calls",
service_handler)
self.services.register(
"test_domain", "register_calls", service_handler)
self.hass.block_till_done()
assert len(self.calls_register) == 1
assert self.calls_register[-1].data['domain'] == 'test_domain'
assert self.calls_register[-1].data['service'] == 'register_calls'
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
@ -675,8 +690,14 @@ class TestServiceRegistry(unittest.TestCase):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.services.register(
'test_domain', 'register_calls', service_handler)
self.hass.block_till_done()
assert len(self.calls_register) == 1
assert self.calls_register[-1].data['domain'] == 'test_domain'
assert self.calls_register[-1].data['service'] == 'register_calls'
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
@ -691,13 +712,56 @@ class TestServiceRegistry(unittest.TestCase):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.services.register(
'test_domain', 'register_calls', service_handler)
self.hass.block_till_done()
assert len(self.calls_register) == 1
assert self.calls_register[-1].data['domain'] == 'test_domain'
assert self.calls_register[-1].data['service'] == 'register_calls'
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
self.assertEqual(1, len(calls))
def test_remove_service(self):
"""Test remove service."""
calls_remove = []
@ha.callback
def mock_event_remove(event):
"""Mock register event."""
calls_remove.append(event)
self.hass.bus.listen(EVENT_SERVICE_REMOVED, mock_event_remove)
assert self.services.has_service('test_Domain', 'test_Service')
self.services.remove('test_Domain', 'test_Service')
self.hass.block_till_done()
assert not self.services.has_service('test_Domain', 'test_Service')
assert len(calls_remove) == 1
assert calls_remove[-1].data['domain'] == 'test_domain'
assert calls_remove[-1].data['service'] == 'test_service'
def test_remove_service_that_not_exists(self):
"""Test remove service that not exists."""
calls_remove = []
@ha.callback
def mock_event_remove(event):
"""Mock register event."""
calls_remove.append(event)
self.hass.bus.listen(EVENT_SERVICE_REMOVED, mock_event_remove)
assert not self.services.has_service('test_xxx', 'test_yyy')
self.services.remove('test_xxx', 'test_yyy')
self.hass.block_till_done()
assert len(calls_remove) == 0
class TestConfig(unittest.TestCase):
"""Test configuration methods."""