mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Add support for remove services / Reload script support (#6441)
* Add support for remove services / Reload script support * Reload support for scripts * Add more unittest for services * Add unittest for script reload * Address paulus comments
This commit is contained in:
parent
e7f442d66b
commit
c937a7bcb0
@ -15,7 +15,7 @@ from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
|
||||
SERVICE_TOGGLE)
|
||||
SERVICE_TOGGLE, SERVICE_RELOAD)
|
||||
from homeassistant.components import logbook
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import extract_domain_configs, script, condition
|
||||
@ -51,7 +51,6 @@ DEFAULT_INITIAL_STATE = True
|
||||
ATTR_LAST_TRIGGERED = 'last_triggered'
|
||||
ATTR_VARIABLES = 'variables'
|
||||
SERVICE_TRIGGER = 'trigger'
|
||||
SERVICE_RELOAD = 'reload'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -14,7 +14,7 @@ from homeassistant import config as conf_util, core as ha
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, CONF_ICON, CONF_NAME, STATE_CLOSED, STATE_HOME,
|
||||
STATE_NOT_HOME, STATE_OFF, STATE_ON, STATE_OPEN, STATE_LOCKED,
|
||||
STATE_UNLOCKED, STATE_UNKNOWN, ATTR_ASSUMED_STATE)
|
||||
STATE_UNLOCKED, STATE_UNKNOWN, ATTR_ASSUMED_STATE, SERVICE_RELOAD)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity import Entity, async_generate_entity_id
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
@ -42,7 +42,6 @@ SET_VISIBILITY_SERVICE_SCHEMA = vol.Schema({
|
||||
vol.Required(ATTR_VISIBLE): cv.boolean
|
||||
})
|
||||
|
||||
SERVICE_RELOAD = 'reload'
|
||||
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -395,17 +394,16 @@ class Group(Entity):
|
||||
self._state = STATE_UNKNOWN
|
||||
self._async_update_group_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove(self):
|
||||
"""Remove group from HASS.
|
||||
|
||||
This method must be run in the event loop.
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
if self._async_unsub_state_changed:
|
||||
self._async_unsub_state_changed()
|
||||
self._async_unsub_state_changed = None
|
||||
|
||||
yield from super().async_remove()
|
||||
return super().async_remove()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_state_changed_listener(self, entity_id, old_state, new_state):
|
||||
|
@ -14,7 +14,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, SERVICE_TURN_OFF, SERVICE_TURN_ON,
|
||||
SERVICE_TOGGLE, STATE_ON, CONF_ALIAS)
|
||||
SERVICE_TOGGLE, SERVICE_RELOAD, STATE_ON, CONF_ALIAS)
|
||||
from homeassistant.core import split_entity_id
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
@ -49,6 +49,7 @@ SCRIPT_TURN_ONOFF_SCHEMA = vol.Schema({
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
|
||||
vol.Optional(ATTR_VARIABLES): dict,
|
||||
})
|
||||
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
||||
|
||||
|
||||
def is_on(hass, entity_id):
|
||||
@ -56,6 +57,11 @@ def is_on(hass, entity_id):
|
||||
return hass.states.is_state(entity_id, STATE_ON)
|
||||
|
||||
|
||||
def reload(hass):
|
||||
"""Reload script component."""
|
||||
hass.services.call(DOMAIN, SERVICE_RELOAD)
|
||||
|
||||
|
||||
def turn_on(hass, entity_id, variables=None):
|
||||
"""Turn script on."""
|
||||
_, object_id = split_entity_id(entity_id)
|
||||
@ -76,29 +82,19 @@ def toggle(hass, entity_id):
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Load the scripts from the configuration."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass,
|
||||
group_name=GROUP_NAME_ALL_SCRIPTS)
|
||||
component = EntityComponent(
|
||||
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
|
||||
|
||||
yield from _async_process_config(hass, config, component)
|
||||
|
||||
@asyncio.coroutine
|
||||
def service_handler(service):
|
||||
"""Execute a service call to script.<script name>."""
|
||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||
script = component.entities.get(entity_id)
|
||||
if script.is_on:
|
||||
_LOGGER.warning("Script %s already running.", entity_id)
|
||||
def reload_service(service):
|
||||
"""Call a service to reload scripts."""
|
||||
conf = yield from component.async_prepare_reload()
|
||||
if conf is None:
|
||||
return
|
||||
yield from script.async_turn_on(variables=service.data)
|
||||
|
||||
scripts = []
|
||||
|
||||
for object_id, cfg in config[DOMAIN].items():
|
||||
alias = cfg.get(CONF_ALIAS, object_id)
|
||||
script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE])
|
||||
scripts.append(script)
|
||||
hass.services.async_register(DOMAIN, object_id, service_handler,
|
||||
schema=SCRIPT_SERVICE_SCHEMA)
|
||||
|
||||
yield from component.async_add_entities(scripts)
|
||||
yield from _async_process_config(hass, conf, component)
|
||||
|
||||
@asyncio.coroutine
|
||||
def turn_on_service(service):
|
||||
@ -123,6 +119,8 @@ def async_setup(hass, config):
|
||||
for script in component.async_extract_from_service(service):
|
||||
yield from script.async_toggle()
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
|
||||
schema=RELOAD_SERVICE_SCHEMA)
|
||||
hass.services.async_register(DOMAIN, SERVICE_TURN_ON, turn_on_service,
|
||||
schema=SCRIPT_TURN_ONOFF_SCHEMA)
|
||||
hass.services.async_register(DOMAIN, SERVICE_TURN_OFF, turn_off_service,
|
||||
@ -133,6 +131,31 @@ def async_setup(hass, config):
|
||||
return True
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_process_config(hass, config, component):
|
||||
"""Process group configuration."""
|
||||
@asyncio.coroutine
|
||||
def service_handler(service):
|
||||
"""Execute a service call to script.<script name>."""
|
||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||
script = component.entities.get(entity_id)
|
||||
if script.is_on:
|
||||
_LOGGER.warning("Script %s already running.", entity_id)
|
||||
return
|
||||
yield from script.async_turn_on(variables=service.data)
|
||||
|
||||
scripts = []
|
||||
|
||||
for object_id, cfg in config[DOMAIN].items():
|
||||
alias = cfg.get(CONF_ALIAS, object_id)
|
||||
script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE])
|
||||
scripts.append(script)
|
||||
hass.services.async_register(
|
||||
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
|
||||
|
||||
yield from component.async_add_entities(scripts)
|
||||
|
||||
|
||||
class ScriptEntity(ToggleEntity):
|
||||
"""Representation of a script entity."""
|
||||
|
||||
@ -177,3 +200,16 @@ class ScriptEntity(ToggleEntity):
|
||||
def async_turn_off(self, **kwargs):
|
||||
"""Turn script off."""
|
||||
self.script.async_stop()
|
||||
|
||||
def async_remove(self):
|
||||
"""Remove script from HASS.
|
||||
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
if self.script.is_running:
|
||||
self.script.async_stop()
|
||||
|
||||
# remove service
|
||||
self.hass.services.async_remove(DOMAIN, self.object_id)
|
||||
|
||||
return super().async_remove()
|
||||
|
@ -166,6 +166,7 @@ EVENT_SERVICE_EXECUTED = 'service_executed'
|
||||
EVENT_PLATFORM_DISCOVERED = 'platform_discovered'
|
||||
EVENT_COMPONENT_LOADED = 'component_loaded'
|
||||
EVENT_SERVICE_REGISTERED = 'service_registered'
|
||||
EVENT_SERVICE_REMOVED = 'service_removed'
|
||||
|
||||
# #### STATES ####
|
||||
STATE_ON = 'on'
|
||||
@ -305,6 +306,7 @@ SERVICE_HOMEASSISTANT_RESTART = 'restart'
|
||||
SERVICE_TURN_ON = 'turn_on'
|
||||
SERVICE_TURN_OFF = 'turn_off'
|
||||
SERVICE_TOGGLE = 'toggle'
|
||||
SERVICE_RELOAD = 'reload'
|
||||
|
||||
SERVICE_VOLUME_UP = 'volume_up'
|
||||
SERVICE_VOLUME_DOWN = 'volume_down'
|
||||
|
@ -26,7 +26,8 @@ from homeassistant.const import (
|
||||
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, __version__)
|
||||
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REMOVED, __version__)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError, InvalidEntityFormatError, ShuttingDown)
|
||||
from homeassistant.util.async import (
|
||||
@ -864,6 +865,32 @@ class ServiceRegistry(object):
|
||||
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
||||
)
|
||||
|
||||
def remove(self, domain, service):
|
||||
"""Remove a registered service from service handler."""
|
||||
run_callback_threadsafe(
|
||||
self._hass.loop, self.async_remove, domain, service).result()
|
||||
|
||||
@callback
|
||||
def async_remove(self, domain, service):
|
||||
"""Remove a registered service from service handler.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
domain = domain.lower()
|
||||
service = service.lower()
|
||||
|
||||
if service not in self._services.get(domain, {}):
|
||||
_LOGGER.warning(
|
||||
"Unable to remove unknown service %s/%s.", domain, service)
|
||||
return
|
||||
|
||||
self._services[domain].pop(service)
|
||||
|
||||
self._hass.bus.async_fire(
|
||||
EVENT_SERVICE_REMOVED,
|
||||
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
||||
)
|
||||
|
||||
def call(self, domain, service, service_data=None, blocking=False):
|
||||
"""
|
||||
Call a service.
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""The tests for the Script component."""
|
||||
# pylint: disable=protected-access
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.setup import setup_component
|
||||
@ -172,3 +173,38 @@ class TestScriptComponent(unittest.TestCase):
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[-1].data['hello'] == 'universe'
|
||||
|
||||
def test_reload_service(self):
|
||||
"""Verify that the turn_on service."""
|
||||
assert setup_component(self.hass, 'script', {
|
||||
'script': {
|
||||
'test': {
|
||||
'sequence': [{
|
||||
'delay': {
|
||||
'seconds': 5
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
assert self.hass.states.get(ENTITY_ID) is not None
|
||||
assert self.hass.services.has_service(script.DOMAIN, 'test')
|
||||
|
||||
with patch('homeassistant.config.load_yaml_config_file', return_value={
|
||||
'script': {
|
||||
'test2': {
|
||||
'sequence': [{
|
||||
'delay': {
|
||||
'seconds': 5
|
||||
}
|
||||
}]
|
||||
}}}):
|
||||
script.reload(self.hass)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert self.hass.states.get(ENTITY_ID) is None
|
||||
assert not self.hass.services.has_service(script.DOMAIN, 'test')
|
||||
|
||||
assert self.hass.states.get("script.test2") is not None
|
||||
assert self.hass.services.has_service(script.DOMAIN, 'test2')
|
||||
|
@ -16,7 +16,8 @@ from homeassistant.util.unit_system import (METRIC_SYSTEM)
|
||||
from homeassistant.const import (
|
||||
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
|
||||
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_START)
|
||||
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_START,
|
||||
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
@ -619,6 +620,15 @@ class TestServiceRegistry(unittest.TestCase):
|
||||
|
||||
self.services.register("Test_Domain", "TEST_SERVICE", mock_service)
|
||||
|
||||
self.calls_register = []
|
||||
|
||||
@ha.callback
|
||||
def mock_event_register(event):
|
||||
"""Mock register event."""
|
||||
self.calls_register.append(event)
|
||||
|
||||
self.hass.bus.listen(EVENT_SERVICE_REGISTERED, mock_event_register)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def tearDown(self):
|
||||
"""Stop down stuff we started."""
|
||||
@ -649,8 +659,13 @@ class TestServiceRegistry(unittest.TestCase):
|
||||
"""Service handler."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register("test_domain", "register_calls",
|
||||
service_handler)
|
||||
self.services.register(
|
||||
"test_domain", "register_calls", service_handler)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.calls_register) == 1
|
||||
assert self.calls_register[-1].data['domain'] == 'test_domain'
|
||||
assert self.calls_register[-1].data['service'] == 'register_calls'
|
||||
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
@ -675,8 +690,14 @@ class TestServiceRegistry(unittest.TestCase):
|
||||
"""Service handler coroutine."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register('test_domain', 'register_calls',
|
||||
service_handler)
|
||||
self.services.register(
|
||||
'test_domain', 'register_calls', service_handler)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.calls_register) == 1
|
||||
assert self.calls_register[-1].data['domain'] == 'test_domain'
|
||||
assert self.calls_register[-1].data['service'] == 'register_calls'
|
||||
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
self.hass.block_till_done()
|
||||
@ -691,13 +712,56 @@ class TestServiceRegistry(unittest.TestCase):
|
||||
"""Service handler coroutine."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register('test_domain', 'register_calls',
|
||||
service_handler)
|
||||
self.services.register(
|
||||
'test_domain', 'register_calls', service_handler)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.calls_register) == 1
|
||||
assert self.calls_register[-1].data['domain'] == 'test_domain'
|
||||
assert self.calls_register[-1].data['service'] == 'register_calls'
|
||||
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(calls))
|
||||
|
||||
def test_remove_service(self):
|
||||
"""Test remove service."""
|
||||
calls_remove = []
|
||||
|
||||
@ha.callback
|
||||
def mock_event_remove(event):
|
||||
"""Mock register event."""
|
||||
calls_remove.append(event)
|
||||
|
||||
self.hass.bus.listen(EVENT_SERVICE_REMOVED, mock_event_remove)
|
||||
|
||||
assert self.services.has_service('test_Domain', 'test_Service')
|
||||
|
||||
self.services.remove('test_Domain', 'test_Service')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not self.services.has_service('test_Domain', 'test_Service')
|
||||
assert len(calls_remove) == 1
|
||||
assert calls_remove[-1].data['domain'] == 'test_domain'
|
||||
assert calls_remove[-1].data['service'] == 'test_service'
|
||||
|
||||
def test_remove_service_that_not_exists(self):
|
||||
"""Test remove service that not exists."""
|
||||
calls_remove = []
|
||||
|
||||
@ha.callback
|
||||
def mock_event_remove(event):
|
||||
"""Mock register event."""
|
||||
calls_remove.append(event)
|
||||
|
||||
self.hass.bus.listen(EVENT_SERVICE_REMOVED, mock_event_remove)
|
||||
|
||||
assert not self.services.has_service('test_xxx', 'test_yyy')
|
||||
self.services.remove('test_xxx', 'test_yyy')
|
||||
self.hass.block_till_done()
|
||||
assert len(calls_remove) == 0
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Test configuration methods."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user