diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 3037d7cc3a7..9b6f9de1945 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -1,7 +1,9 @@ """Allow to set up simple automation rules via the config file.""" from __future__ import annotations +import asyncio from collections.abc import Callable, Mapping +from dataclasses import dataclass import logging from typing import Any, Protocol, cast @@ -274,7 +276,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def reload_service_handler(service_call: ServiceCall) -> None: """Remove all automations and load new ones from config.""" - if (conf := await component.async_prepare_reload()) is None: + if (conf := await component.async_prepare_reload(skip_reset=True)) is None: return async_get_blueprints(hass).async_reset_cache() await _async_process_config(hass, conf, component) @@ -660,20 +662,27 @@ class AutomationEntity(ToggleEntity, RestoreEntity): ) -async def _async_process_config( - hass: HomeAssistant, - config: dict[str, Any], - component: EntityComponent[AutomationEntity], -) -> bool: - """Process config and add automations. +@dataclass +class AutomationEntityConfig: + """Container for prepared automation entity configuration.""" - Returns if blueprints were used. - """ - entities: list[AutomationEntity] = [] + config_block: ConfigType + config_key: str + list_no: int + raw_blueprint_inputs: ConfigType | None + raw_config: ConfigType | None + + +async def _prepare_automation_config( + hass: HomeAssistant, + config: ConfigType, +) -> tuple[bool, list[AutomationEntityConfig]]: + """Parse configuration and prepare automation entity configuration.""" + automation_configs: list[AutomationEntityConfig] = [] blueprints_used = False for config_key in extract_domain_configs(config, DOMAIN): - conf: list[dict[str, Any] | blueprint.BlueprintInputs] = config[config_key] + conf: list[ConfigType | blueprint.BlueprintInputs] = config[config_key] for list_no, config_block in enumerate(conf): raw_blueprint_inputs = None @@ -700,62 +709,154 @@ async def _async_process_config( else: raw_config = cast(AutomationConfig, config_block).raw_config - automation_id: str | None = config_block.get(CONF_ID) - name: str = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" - - initial_state: bool | None = config_block.get(CONF_INITIAL_STATE) - - action_script = Script( - hass, - config_block[CONF_ACTION], - name, - DOMAIN, - running_description="automation actions", - script_mode=config_block[CONF_MODE], - max_runs=config_block[CONF_MAX], - max_exceeded=config_block[CONF_MAX_EXCEEDED], - logger=LOGGER, - # We don't pass variables here - # Automation will already render them to use them in the condition - # and so will pass them on to the script. - ) - - if CONF_CONDITION in config_block: - cond_func = await _async_process_if(hass, name, config, config_block) - - if cond_func is None: - continue - else: - cond_func = None - - # Add trigger variables to variables - variables = None - if CONF_TRIGGER_VARIABLES in config_block: - variables = ScriptVariables( - dict(config_block[CONF_TRIGGER_VARIABLES].as_dict()) + automation_configs.append( + AutomationEntityConfig( + config_block, config_key, list_no, raw_blueprint_inputs, raw_config ) - if CONF_VARIABLES in config_block: - if variables: - variables.variables.update(config_block[CONF_VARIABLES].as_dict()) - else: - variables = config_block[CONF_VARIABLES] - - entity = AutomationEntity( - automation_id, - name, - config_block[CONF_TRIGGER], - cond_func, - action_script, - initial_state, - variables, - config_block.get(CONF_TRIGGER_VARIABLES), - raw_config, - raw_blueprint_inputs, - config_block[CONF_TRACE], ) - entities.append(entity) + return (blueprints_used, automation_configs) + +def _automation_name(automation_config: AutomationEntityConfig) -> str: + """Return the configured name of an automation.""" + config_block = automation_config.config_block + config_key = automation_config.config_key + list_no = automation_config.list_no + return config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" + + +async def _create_automation_entities( + hass: HomeAssistant, automation_configs: list[AutomationEntityConfig] +) -> list[AutomationEntity]: + """Create automation entities from prepared configuration.""" + entities: list[AutomationEntity] = [] + + for automation_config in automation_configs: + config_block = automation_config.config_block + + automation_id: str | None = config_block.get(CONF_ID) + name = _automation_name(automation_config) + + initial_state: bool | None = config_block.get(CONF_INITIAL_STATE) + + action_script = Script( + hass, + config_block[CONF_ACTION], + name, + DOMAIN, + running_description="automation actions", + script_mode=config_block[CONF_MODE], + max_runs=config_block[CONF_MAX], + max_exceeded=config_block[CONF_MAX_EXCEEDED], + logger=LOGGER, + # We don't pass variables here + # Automation will already render them to use them in the condition + # and so will pass them on to the script. + ) + + if CONF_CONDITION in config_block: + cond_func = await _async_process_if(hass, name, config_block) + + if cond_func is None: + continue + else: + cond_func = None + + # Add trigger variables to variables + variables = None + if CONF_TRIGGER_VARIABLES in config_block: + variables = ScriptVariables( + dict(config_block[CONF_TRIGGER_VARIABLES].as_dict()) + ) + if CONF_VARIABLES in config_block: + if variables: + variables.variables.update(config_block[CONF_VARIABLES].as_dict()) + else: + variables = config_block[CONF_VARIABLES] + + entity = AutomationEntity( + automation_id, + name, + config_block[CONF_TRIGGER], + cond_func, + action_script, + initial_state, + variables, + config_block.get(CONF_TRIGGER_VARIABLES), + automation_config.raw_config, + automation_config.raw_blueprint_inputs, + config_block[CONF_TRACE], + ) + entities.append(entity) + + return entities + + +async def _async_process_config( + hass: HomeAssistant, + config: dict[str, Any], + component: EntityComponent[AutomationEntity], +) -> bool: + """Process config and add automations. + + Returns if blueprints were used. + """ + + def automation_matches_config( + automation: AutomationEntity, config: AutomationEntityConfig + ) -> bool: + name = _automation_name(config) + return automation.name == name and automation.raw_config == config.raw_config + + def find_matches( + automations: list[AutomationEntity], + automation_configs: list[AutomationEntityConfig], + ) -> tuple[set[int], set[int]]: + """Find matches between a list of automation entities and a list of configurations. + + An automation or configuration is only allowed to match at most once to handle + the case of multiple automations with identical configuration. + + Returns a tuple of sets of indices: ({automation_matches}, {config_matches}) + """ + automation_matches: set[int] = set() + config_matches: set[int] = set() + + for automation_idx, automation in enumerate(automations): + for config_idx, config in enumerate(automation_configs): + if config_idx in config_matches: + # Only allow an automation config to match at most once + continue + if automation_matches_config(automation, config): + automation_matches.add(automation_idx) + config_matches.add(config_idx) + # Only allow an automation to match at most once + break + + return automation_matches, config_matches + + blueprints_used, automation_configs = await _prepare_automation_config(hass, config) + automations: list[AutomationEntity] = list(component.entities) + + # Find automations and configurations which have matches + automation_matches, config_matches = find_matches(automations, automation_configs) + + # Remove automations which have changed config or no longer exist + tasks = [ + automation.async_remove() + for idx, automation in enumerate(automations) + if idx not in automation_matches + ] + await asyncio.gather(*tasks) + + # Create automations which have changed config or have been added + updated_automation_configs = [ + config + for idx, config in enumerate(automation_configs) + if idx not in config_matches + ] + entities = await _create_automation_entities(hass, updated_automation_configs) if entities: await component.async_add_entities(entities) @@ -763,10 +864,10 @@ async def _async_process_config( async def _async_process_if( - hass: HomeAssistant, name: str, config: dict[str, Any], p_config: dict[str, Any] + hass: HomeAssistant, name: str, config: dict[str, Any] ) -> IfAction | None: """Process if checks.""" - if_configs = p_config[CONF_CONDITION] + if_configs = config[CONF_CONDITION] checks: list[condition.ConditionCheckerType] = [] for if_config in if_configs: diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 3cdf0c3a477..f129d829d4c 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -15,6 +15,7 @@ from homeassistant.components.automation import ( EVENT_AUTOMATION_RELOADED, EVENT_AUTOMATION_TRIGGERED, SERVICE_TRIGGER, + AutomationEntity, ) from homeassistant.const import ( ATTR_ENTITY_ID, @@ -720,6 +721,7 @@ async def test_automation_stops(hass, calls, service): blocking=True, ) else: + config[automation.DOMAIN]["alias"] = "goodbye" with patch( "homeassistant.config.load_yaml_config_file", autospec=True, @@ -735,6 +737,271 @@ async def test_automation_stops(hass, calls, service): assert len(calls) == (1 if service == "turn_off_no_stop" else 0) +async def test_reload_unchanged_does_not_stop(hass, calls): + """Test that turning off / reloading stops any running actions as appropriate.""" + test_entity = "test.entity" + + config = { + automation.DOMAIN: { + "alias": "hello", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + {"event": "running"}, + {"wait_template": "{{ is_state('test.entity', 'goodbye') }}"}, + {"service": "test.automation"}, + ], + } + } + assert await async_setup_component(hass, automation.DOMAIN, config) + + running = asyncio.Event() + + @callback + def running_cb(event): + running.set() + + hass.bus.async_listen_once("running", running_cb) + hass.states.async_set(test_entity, "hello") + + hass.bus.async_fire("test_event") + await running.wait() + + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call(automation.DOMAIN, SERVICE_RELOAD, blocking=True) + + hass.states.async_set(test_entity, "goodbye") + await hass.async_block_till_done() + + assert len(calls) == 1 + + +async def test_reload_moved_automation_without_alias(hass, calls): + """Test that changing the order of automations without alias triggers reload.""" + with patch( + "homeassistant.components.automation.AutomationEntity", wraps=AutomationEntity + ) as automation_entity_init: + config = { + automation.DOMAIN: [ + { + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "test.automation"}], + }, + { + "alias": "automation_with_alias", + "trigger": {"platform": "event", "event_type": "test_event2"}, + "action": [{"service": "test.automation"}], + }, + ] + } + assert await async_setup_component(hass, automation.DOMAIN, config) + assert automation_entity_init.call_count == 2 + automation_entity_init.reset_mock() + + assert hass.states.get("automation.automation_0") + assert not hass.states.get("automation.automation_1") + assert hass.states.get("automation.automation_with_alias") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Reverse the order of the automations + config[automation.DOMAIN].reverse() + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call( + automation.DOMAIN, SERVICE_RELOAD, blocking=True + ) + + assert automation_entity_init.call_count == 1 + automation_entity_init.reset_mock() + + assert not hass.states.get("automation.automation_0") + assert hass.states.get("automation.automation_1") + assert hass.states.get("automation.automation_with_alias") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 2 + + +async def test_reload_identical_automations_without_id(hass, calls): + """Test reloading of identical automations without id.""" + with patch( + "homeassistant.components.automation.AutomationEntity", wraps=AutomationEntity + ) as automation_entity_init: + config = { + automation.DOMAIN: [ + { + "alias": "dolly", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "test.automation"}], + }, + { + "alias": "dolly", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "test.automation"}], + }, + { + "alias": "dolly", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "test.automation"}], + }, + ] + } + assert await async_setup_component(hass, automation.DOMAIN, config) + assert automation_entity_init.call_count == 3 + automation_entity_init.reset_mock() + + assert hass.states.get("automation.dolly") + assert hass.states.get("automation.dolly_2") + assert hass.states.get("automation.dolly_3") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 3 + + # Reload the automations without any change + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call( + automation.DOMAIN, SERVICE_RELOAD, blocking=True + ) + + assert automation_entity_init.call_count == 0 + automation_entity_init.reset_mock() + + assert hass.states.get("automation.dolly") + assert hass.states.get("automation.dolly_2") + assert hass.states.get("automation.dolly_3") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 6 + + # Remove two clones + del config[automation.DOMAIN][-1] + del config[automation.DOMAIN][-1] + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call( + automation.DOMAIN, SERVICE_RELOAD, blocking=True + ) + + assert automation_entity_init.call_count == 0 + automation_entity_init.reset_mock() + + assert hass.states.get("automation.dolly") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 7 + + # Add two clones + config[automation.DOMAIN].append(config[automation.DOMAIN][-1]) + config[automation.DOMAIN].append(config[automation.DOMAIN][-1]) + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call( + automation.DOMAIN, SERVICE_RELOAD, blocking=True + ) + + assert automation_entity_init.call_count == 2 + automation_entity_init.reset_mock() + + assert hass.states.get("automation.dolly") + assert hass.states.get("automation.dolly_2") + assert hass.states.get("automation.dolly_3") + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 10 + + +@pytest.mark.parametrize( + "automation_config", + ( + { + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "test.automation"}], + }, + # An automation using templates + { + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"service": "{{ 'test.automation' }}"}], + }, + # An automation using blueprint + { + "use_blueprint": { + "path": "test_event_service.yaml", + "input": { + "trigger_event": "test_event", + "service_to_call": "test.automation", + "a_number": 5, + }, + } + }, + # An automation using blueprint with templated input + { + "use_blueprint": { + "path": "test_event_service.yaml", + "input": { + "trigger_event": "{{ 'test_event' }}", + "service_to_call": "{{ 'test.automation' }}", + "a_number": 5, + }, + } + }, + ), +) +async def test_reload_unchanged_automation(hass, calls, automation_config): + """Test an unmodified automation is not reloaded.""" + with patch( + "homeassistant.components.automation.AutomationEntity", wraps=AutomationEntity + ) as automation_entity_init: + config = {automation.DOMAIN: [automation_config]} + assert await async_setup_component(hass, automation.DOMAIN, config) + assert automation_entity_init.call_count == 1 + automation_entity_init.reset_mock() + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Reload the automations without any change + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call( + automation.DOMAIN, SERVICE_RELOAD, blocking=True + ) + + assert automation_entity_init.call_count == 0 + automation_entity_init.reset_mock() + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 2 + + async def test_automation_restore_state(hass): """Ensure states are restored on startup.""" time = dt_util.utcnow()