From e9813b219e240fcbd67ff6b6e101333ebfe5a35a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 4 Sep 2016 17:15:52 +0200 Subject: [PATCH] Allow reloading automation without restarting HA (#3002) --- homeassistant/bootstrap.py | 133 ++++++++++-------- .../components/automation/__init__.py | 103 ++++++++++---- .../components/automation/services.yaml | 34 +++++ homeassistant/helpers/entity.py | 4 + homeassistant/helpers/entity_component.py | 60 +++++--- tests/components/automation/test_init.py | 130 +++++++++++++++++ tests/test_bootstrap.py | 8 +- 7 files changed, 365 insertions(+), 107 deletions(-) create mode 100644 homeassistant/components/automation/services.yaml diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 4b526c40b38..3e8ed6ad77f 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -90,67 +90,12 @@ def _setup_component(hass: core.HomeAssistant, domain: str, config) -> bool: domain, domain) return False + config = prepare_setup_component(hass, config, domain) + + if config is None: + return False + component = loader.get_component(domain) - missing_deps = [dep for dep in getattr(component, 'DEPENDENCIES', []) - if dep not in hass.config.components] - - if missing_deps: - _LOGGER.error( - 'Not initializing %s because not all dependencies loaded: %s', - domain, ", ".join(missing_deps)) - return False - - if hasattr(component, 'CONFIG_SCHEMA'): - try: - config = component.CONFIG_SCHEMA(config) - except vol.MultipleInvalid as ex: - log_exception(ex, domain, config) - return False - - elif hasattr(component, 'PLATFORM_SCHEMA'): - platforms = [] - for p_name, p_config in config_per_platform(config, domain): - # Validate component specific platform schema - try: - p_validated = component.PLATFORM_SCHEMA(p_config) - except vol.MultipleInvalid as ex: - log_exception(ex, domain, p_config) - return False - - # Not all platform components follow same pattern for platforms - # So if p_name is None we are not going to validate platform - # (the automation component is one of them) - if p_name is None: - platforms.append(p_validated) - continue - - platform = prepare_setup_platform(hass, config, domain, - p_name) - - if platform is None: - return False - - # Validate platform specific schema - if hasattr(platform, 'PLATFORM_SCHEMA'): - try: - p_validated = platform.PLATFORM_SCHEMA(p_validated) - except vol.MultipleInvalid as ex: - log_exception(ex, '{}.{}'.format(domain, p_name), - p_validated) - return False - - platforms.append(p_validated) - - # Create a copy of the configuration with all config for current - # component removed and add validated config back in. - filter_keys = extract_domain_configs(config, domain) - config = {key: value for key, value in config.items() - if key not in filter_keys} - config[domain] = platforms - - if not _handle_requirements(hass, component, domain): - return False - _CURRENT_SETUP.append(domain) try: @@ -182,6 +127,74 @@ def _setup_component(hass: core.HomeAssistant, domain: str, config) -> bool: return True +def prepare_setup_component(hass: core.HomeAssistant, config: dict, + domain: str): + """Prepare setup of a component and return processed config.""" + # pylint: disable=too-many-return-statements + component = loader.get_component(domain) + missing_deps = [dep for dep in getattr(component, 'DEPENDENCIES', []) + if dep not in hass.config.components] + + if missing_deps: + _LOGGER.error( + 'Not initializing %s because not all dependencies loaded: %s', + domain, ", ".join(missing_deps)) + return None + + if hasattr(component, 'CONFIG_SCHEMA'): + try: + config = component.CONFIG_SCHEMA(config) + except vol.MultipleInvalid as ex: + log_exception(ex, domain, config) + return None + + elif hasattr(component, 'PLATFORM_SCHEMA'): + platforms = [] + for p_name, p_config in config_per_platform(config, domain): + # Validate component specific platform schema + try: + p_validated = component.PLATFORM_SCHEMA(p_config) + except vol.MultipleInvalid as ex: + log_exception(ex, domain, p_config) + return None + + # Not all platform components follow same pattern for platforms + # So if p_name is None we are not going to validate platform + # (the automation component is one of them) + if p_name is None: + platforms.append(p_validated) + continue + + platform = prepare_setup_platform(hass, config, domain, + p_name) + + if platform is None: + return None + + # Validate platform specific schema + if hasattr(platform, 'PLATFORM_SCHEMA'): + try: + p_validated = platform.PLATFORM_SCHEMA(p_validated) + except vol.MultipleInvalid as ex: + log_exception(ex, '{}.{}'.format(domain, p_name), + p_validated) + return None + + platforms.append(p_validated) + + # Create a copy of the configuration with all config for current + # component removed and add validated config back in. + filter_keys = extract_domain_configs(config, domain) + config = {key: value for key, value in config.items() + if key not in filter_keys} + config[domain] = platforms + + if not _handle_requirements(hass, component, domain): + return None + + return config + + def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str, platform_name: str) -> Optional[ModuleType]: """Load a platform and makes sure dependencies are setup.""" diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 6f5396afa15..40715bca502 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -6,10 +6,13 @@ https://home-assistant.io/components/automation/ """ from functools import partial import logging +import os import voluptuous as vol -from homeassistant.bootstrap import prepare_setup_platform +from homeassistant.bootstrap import ( + prepare_setup_platform, prepare_setup_component) +from homeassistant import config as conf_util from homeassistant.const import ( ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE) @@ -46,6 +49,7 @@ METHOD_IF_ACTION = 'if_action' ATTR_LAST_TRIGGERED = 'last_triggered' ATTR_VARIABLES = 'variables' SERVICE_TRIGGER = 'trigger' +SERVICE_RELOAD = 'reload' _LOGGER = logging.getLogger(__name__) @@ -112,6 +116,8 @@ TRIGGER_SERVICE_SCHEMA = vol.Schema({ vol.Optional(ATTR_VARIABLES, default={}): dict, }) +RELOAD_SERVICE_SCHEMA = vol.Schema({}) + def is_on(hass, entity_id=None): """ @@ -148,40 +154,23 @@ def trigger(hass, entity_id=None): hass.services.call(DOMAIN, SERVICE_TRIGGER, data) +def reload(hass): + """Reload the automation from config.""" + hass.services.call(DOMAIN, SERVICE_RELOAD) + + def setup(hass, config): """Setup the automation.""" - # pylint: disable=too-many-locals component = EntityComponent(_LOGGER, DOMAIN, hass) - success = False - for config_key in extract_domain_configs(config, DOMAIN): - conf = config[config_key] - - for list_no, config_block in enumerate(conf): - name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key, - list_no) - - action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) - - if CONF_CONDITION in config_block: - cond_func = _process_if(hass, config, config_block) - - if cond_func is None: - continue - else: - def cond_func(variables): - """Condition will always pass.""" - return True - - attach_triggers = partial(_process_trigger, hass, config, - config_block.get(CONF_TRIGGER, []), name) - entity = AutomationEntity(name, attach_triggers, cond_func, action) - component.add_entities((entity,)) - success = True + success = _process_config(hass, config, component) if not success: return False + descriptions = conf_util.load_yaml_config_file( + os.path.join(os.path.dirname(__file__), 'services.yaml')) + def trigger_service_handler(service_call): """Handle automation triggers.""" for entity in component.extract_from_service(service_call): @@ -192,11 +181,34 @@ def setup(hass, config): for entity in component.extract_from_service(service_call): getattr(entity, service_call.service)() + def reload_service_handler(service_call): + """Remove all automations and load new ones from config.""" + try: + path = conf_util.find_config_file(hass.config.config_dir) + conf = conf_util.load_yaml_config_file(path) + except HomeAssistantError as err: + _LOGGER.error(err) + return + + conf = prepare_setup_component(hass, conf, DOMAIN) + + if conf is None: + return + + component.reset() + _process_config(hass, conf, component) + hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler, + descriptions.get(SERVICE_TRIGGER), schema=TRIGGER_SERVICE_SCHEMA) + hass.services.register(DOMAIN, SERVICE_RELOAD, reload_service_handler, + descriptions.get(SERVICE_RELOAD), + schema=RELOAD_SERVICE_SCHEMA) + for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE): hass.services.register(DOMAIN, service, service_handler, + descriptions.get(service), schema=SERVICE_SCHEMA) return True @@ -263,6 +275,43 @@ class AutomationEntity(ToggleEntity): self._last_triggered = utcnow() self.update_ha_state() + def remove(self): + """Remove automation from HASS.""" + self.turn_off() + super().remove() + + +def _process_config(hass, config, component): + """Process config and add automations.""" + success = False + + for config_key in extract_domain_configs(config, DOMAIN): + conf = config[config_key] + + for list_no, config_block in enumerate(conf): + name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key, + list_no) + + action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) + + if CONF_CONDITION in config_block: + cond_func = _process_if(hass, config, config_block) + + if cond_func is None: + continue + else: + def cond_func(variables): + """Condition will always pass.""" + return True + + attach_triggers = partial(_process_trigger, hass, config, + config_block.get(CONF_TRIGGER, []), name) + entity = AutomationEntity(name, attach_triggers, cond_func, action) + component.add_entities((entity,)) + success = True + + return success + def _get_action(hass, config, name): """Return an action based on a configuration.""" diff --git a/homeassistant/components/automation/services.yaml b/homeassistant/components/automation/services.yaml new file mode 100644 index 00000000000..ee22b671eca --- /dev/null +++ b/homeassistant/components/automation/services.yaml @@ -0,0 +1,34 @@ +turn_on: + description: Enable an automation. + + fields: + entity_id: + description: Name of the automation to turn on. + example: 'automation.notify_home' + +turn_off: + description: Disable an automation. + + fields: + entity_id: + description: Name of the automation to turn off. + example: 'automation.notify_home' + +toggle: + description: Toggle an automation. + + fields: + entity_id: + description: Name of the automation to toggle on/off. + example: 'automation.notify_home' + +trigger: + description: Trigger the action of an automation. + + fields: + entity_id: + description: Name of the automation to trigger. + example: 'automation.notify_home' + +reload: + description: Reload the automation configuration. diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 61cda43d431..0b4768b809d 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -195,6 +195,10 @@ class Entity(object): return self.hass.states.set( self.entity_id, state, attr, self.force_update) + def remove(self) -> None: + """Remove entitiy from HASS.""" + self.hass.states.remove(self.entity_id) + def _attr_setter(self, name, typ, attr, attrs): """Helper method to populate attributes based on properties.""" if attr in attrs: diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 898a445c788..e853d20df89 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -32,13 +32,14 @@ class EntityComponent(object): self.entities = {} self.group = None - self.is_polling = False self.config = None self.lock = Lock() - self.add_entities = EntityPlatform(self, self.scan_interval, - None).add_entities + self._platforms = { + 'core': EntityPlatform(self, self.scan_interval, None), + } + self.add_entities = self._platforms['core'].add_entities def setup(self, config): """Set up a full entity component. @@ -85,17 +86,22 @@ class EntityComponent(object): return # Config > Platform > Component - scan_interval = platform_config.get( - CONF_SCAN_INTERVAL, - getattr(platform, 'SCAN_INTERVAL', self.scan_interval)) + scan_interval = (platform_config.get(CONF_SCAN_INTERVAL) or + getattr(platform, 'SCAN_INTERVAL', None) or + self.scan_interval) entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE) + key = (platform_type, scan_interval, entity_namespace) + + if key not in self._platforms: + self._platforms[key] = EntityPlatform(self, scan_interval, + entity_namespace) + entity_platform = self._platforms[key] + try: - platform.setup_platform( - self.hass, platform_config, - EntityPlatform(self, scan_interval, - entity_namespace).add_entities, - discovery_info) + platform.setup_platform(self.hass, platform_config, + entity_platform.add_entities, + discovery_info) self.hass.config.components.append( '{}.{}'.format(self.domain, platform_type)) @@ -135,6 +141,22 @@ class EntityComponent(object): if self.group is not None: self.group.update_tracked_entity_ids(self.entities.keys()) + def reset(self): + """Remove entities and reset the entity component to initial values.""" + with self.lock: + for platform in self._platforms.values(): + platform.reset() + + self._platforms = { + 'core': self._platforms['core'] + } + self.entities = {} + self.config = None + + if self.group is not None: + self.group.stop() + self.group = None + class EntityPlatform(object): """Keep track of entities for a single platform.""" @@ -146,7 +168,7 @@ class EntityPlatform(object): self.scan_interval = scan_interval self.entity_namespace = entity_namespace self.platform_entities = [] - self.is_polling = False + self._unsub_polling = None def add_entities(self, new_entities): """Add entities for a single platform.""" @@ -157,17 +179,23 @@ class EntityPlatform(object): self.component.update_group() - if self.is_polling or \ + if self._unsub_polling is not None or \ not any(entity.should_poll for entity in self.platform_entities): return - self.is_polling = True - - track_utc_time_change( + self._unsub_polling = track_utc_time_change( self.component.hass, self._update_entity_states, second=range(0, 60, self.scan_interval)) + def reset(self): + """Remove all entities and reset data.""" + for entity in self.platform_entities: + entity.remove() + if self._unsub_polling is not None: + self._unsub_polling() + self._unsub_polling = None + def _update_entity_states(self, now): """Update the states of all the polling entities.""" with self.component.lock: diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 77727ca56b5..f244bb3a23b 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -5,6 +5,7 @@ from unittest.mock import patch from homeassistant.bootstrap import _setup_component import homeassistant.components.automation as automation from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.exceptions import HomeAssistantError import homeassistant.util.dt as dt_util from tests.common import get_test_home_assistant @@ -414,3 +415,132 @@ class TestAutomation(unittest.TestCase): automation.turn_on(self.hass, entity_id) self.hass.pool.block_till_done() assert automation.is_on(self.hass, entity_id) + + @patch('homeassistant.config.load_yaml_config_file', return_value={ + automation.DOMAIN: { + 'alias': 'bye', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event2', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + def test_reload_config_service(self, mock_load_yaml): + """Test the reload config service.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + assert self.hass.states.get('automation.bye') is None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is None + assert self.hass.states.get('automation.bye') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 1 + + self.hass.bus.fire('test_event2') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 + assert self.calls[1].data.get('event') == 'test_event2' + + @patch('homeassistant.config.load_yaml_config_file', return_value={ + automation.DOMAIN: 'not valid', + }) + def test_reload_config_when_invalid_config(self, mock_load_yaml): + """Test the reload config service handling invalid config.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 + + def test_reload_config_handles_load_fails(self): + """Test the reload config service.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + with patch('homeassistant.config.load_yaml_config_file', + side_effect=HomeAssistantError('bla')): + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index f9abe764866..0ed70ecef77 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -211,19 +211,19 @@ class TestBootstrap: deps = ['non_existing'] loader.set_component('comp', MockModule('comp', dependencies=deps)) - assert not bootstrap._setup_component(self.hass, 'comp', None) + assert not bootstrap._setup_component(self.hass, 'comp', {}) assert 'comp' not in self.hass.config.components self.hass.config.components.append('non_existing') - assert bootstrap._setup_component(self.hass, 'comp', None) + assert bootstrap._setup_component(self.hass, 'comp', {}) def test_component_failing_setup(self): """Test component that fails setup.""" loader.set_component( 'comp', MockModule('comp', setup=lambda hass, config: False)) - assert not bootstrap._setup_component(self.hass, 'comp', None) + assert not bootstrap._setup_component(self.hass, 'comp', {}) assert 'comp' not in self.hass.config.components def test_component_exception_setup(self): @@ -234,7 +234,7 @@ class TestBootstrap: loader.set_component('comp', MockModule('comp', setup=exception_setup)) - assert not bootstrap._setup_component(self.hass, 'comp', None) + assert not bootstrap._setup_component(self.hass, 'comp', {}) assert 'comp' not in self.hass.config.components def test_home_assistant_core_config_validation(self):