mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 13:47:35 +00:00
Don't reload other automations when saving an automation (#80254)
* Only reload modified automation * Correct check for existing automation * Add tests * Remove the new service, improve ReloadServiceHelper * Revert unneeded changes * Update tests * Address review comments * Improve test coverage * Address review comments * Tweak reloader code + add a targetted test * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Explain the tests + add more variations * Fix copy-paste mistake in test * Rephrase explanation of expected test outcome --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
679752ceb8
commit
7cd0fe3c5f
@ -331,17 +331,25 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
await async_get_blueprints(hass).async_reset_cache()
|
await async_get_blueprints(hass).async_reset_cache()
|
||||||
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
|
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
|
||||||
return
|
return
|
||||||
await _async_process_config(hass, conf, component)
|
if automation_id := service_call.data.get(CONF_ID):
|
||||||
|
await _async_process_single_config(hass, conf, component, automation_id)
|
||||||
|
else:
|
||||||
|
await _async_process_config(hass, conf, component)
|
||||||
hass.bus.async_fire(EVENT_AUTOMATION_RELOADED, context=service_call.context)
|
hass.bus.async_fire(EVENT_AUTOMATION_RELOADED, context=service_call.context)
|
||||||
|
|
||||||
reload_helper = ReloadServiceHelper(reload_service_handler)
|
def reload_targets(service_call: ServiceCall) -> set[str | None]:
|
||||||
|
if automation_id := service_call.data.get(CONF_ID):
|
||||||
|
return {automation_id}
|
||||||
|
return {automation.unique_id for automation in component.entities}
|
||||||
|
|
||||||
|
reload_helper = ReloadServiceHelper(reload_service_handler, reload_targets)
|
||||||
|
|
||||||
async_register_admin_service(
|
async_register_admin_service(
|
||||||
hass,
|
hass,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
reload_helper.execute_service,
|
reload_helper.execute_service,
|
||||||
schema=vol.Schema({}),
|
schema=vol.Schema({vol.Optional(CONF_ID): str}),
|
||||||
)
|
)
|
||||||
|
|
||||||
websocket_api.async_register_command(hass, websocket_config)
|
websocket_api.async_register_command(hass, websocket_config)
|
||||||
@ -859,6 +867,7 @@ class AutomationEntityConfig:
|
|||||||
async def _prepare_automation_config(
|
async def _prepare_automation_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
|
wanted_automation_id: str | None,
|
||||||
) -> list[AutomationEntityConfig]:
|
) -> list[AutomationEntityConfig]:
|
||||||
"""Parse configuration and prepare automation entity configuration."""
|
"""Parse configuration and prepare automation entity configuration."""
|
||||||
automation_configs: list[AutomationEntityConfig] = []
|
automation_configs: list[AutomationEntityConfig] = []
|
||||||
@ -866,6 +875,10 @@ async def _prepare_automation_config(
|
|||||||
conf: list[ConfigType] = config[DOMAIN]
|
conf: list[ConfigType] = config[DOMAIN]
|
||||||
|
|
||||||
for list_no, config_block in enumerate(conf):
|
for list_no, config_block in enumerate(conf):
|
||||||
|
automation_id: str | None = config_block.get(CONF_ID)
|
||||||
|
if wanted_automation_id is not None and automation_id != wanted_automation_id:
|
||||||
|
continue
|
||||||
|
|
||||||
raw_config = cast(AutomationConfig, config_block).raw_config
|
raw_config = cast(AutomationConfig, config_block).raw_config
|
||||||
raw_blueprint_inputs = cast(AutomationConfig, config_block).raw_blueprint_inputs
|
raw_blueprint_inputs = cast(AutomationConfig, config_block).raw_blueprint_inputs
|
||||||
validation_failed = cast(AutomationConfig, config_block).validation_failed
|
validation_failed = cast(AutomationConfig, config_block).validation_failed
|
||||||
@ -1025,7 +1038,7 @@ async def _async_process_config(
|
|||||||
|
|
||||||
return automation_matches, config_matches
|
return automation_matches, config_matches
|
||||||
|
|
||||||
automation_configs = await _prepare_automation_config(hass, config)
|
automation_configs = await _prepare_automation_config(hass, config, None)
|
||||||
automations: list[BaseAutomationEntity] = list(component.entities)
|
automations: list[BaseAutomationEntity] = list(component.entities)
|
||||||
|
|
||||||
# Find automations and configurations which have matches
|
# Find automations and configurations which have matches
|
||||||
@ -1049,6 +1062,41 @@ async def _async_process_config(
|
|||||||
await component.async_add_entities(entities)
|
await component.async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
|
def _automation_matches_config(
|
||||||
|
automation: BaseAutomationEntity | None, config: AutomationEntityConfig | None
|
||||||
|
) -> bool:
|
||||||
|
"""Return False if an automation's config has been changed."""
|
||||||
|
if not automation:
|
||||||
|
return False
|
||||||
|
if not config:
|
||||||
|
return False
|
||||||
|
name = _automation_name(config)
|
||||||
|
return automation.name == name and automation.raw_config == config.raw_config
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_process_single_config(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: dict[str, Any],
|
||||||
|
component: EntityComponent[BaseAutomationEntity],
|
||||||
|
automation_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Process config and add a single automation."""
|
||||||
|
|
||||||
|
automation_configs = await _prepare_automation_config(hass, config, automation_id)
|
||||||
|
automation = next(
|
||||||
|
(x for x in component.entities if x.unique_id == automation_id), None
|
||||||
|
)
|
||||||
|
automation_config = automation_configs[0] if automation_configs else None
|
||||||
|
|
||||||
|
if _automation_matches_config(automation, automation_config):
|
||||||
|
return
|
||||||
|
|
||||||
|
if automation:
|
||||||
|
await automation.async_remove()
|
||||||
|
entities = await _create_automation_entities(hass, automation_configs)
|
||||||
|
await component.async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
async def _async_process_if(
|
async def _async_process_if(
|
||||||
hass: HomeAssistant, name: str, config: dict[str, Any]
|
hass: HomeAssistant, name: str, config: dict[str, Any]
|
||||||
) -> IfAction | None:
|
) -> IfAction | None:
|
||||||
|
@ -26,7 +26,9 @@ def async_setup(hass: HomeAssistant) -> bool:
|
|||||||
async def hook(action: str, config_key: str) -> None:
|
async def hook(action: str, config_key: str) -> None:
|
||||||
"""post_write_hook for Config View that reloads automations."""
|
"""post_write_hook for Config View that reloads automations."""
|
||||||
if action != ACTION_DELETE:
|
if action != ACTION_DELETE:
|
||||||
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
await hass.services.async_call(
|
||||||
|
DOMAIN, SERVICE_RELOAD, {CONF_ID: config_key}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
ent_reg = er.async_get(hass)
|
ent_reg = er.async_get(hass)
|
||||||
|
@ -77,6 +77,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||||
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _base_components() -> dict[str, ModuleType]:
|
def _base_components() -> dict[str, ModuleType]:
|
||||||
@ -1154,40 +1156,67 @@ def verify_domain_control(
|
|||||||
|
|
||||||
|
|
||||||
class ReloadServiceHelper:
|
class ReloadServiceHelper:
|
||||||
"""Helper for reload services to minimize unnecessary reloads."""
|
"""Helper for reload services.
|
||||||
|
|
||||||
def __init__(self, service_func: Callable[[ServiceCall], Awaitable]) -> None:
|
The helper has the following purposes:
|
||||||
|
- Make sure reloads do not happen in parallel
|
||||||
|
- Avoid redundant reloads of the same target
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
service_func: Callable[[ServiceCall], Awaitable],
|
||||||
|
reload_targets_func: Callable[[ServiceCall], set[_T]],
|
||||||
|
) -> None:
|
||||||
"""Initialize ReloadServiceHelper."""
|
"""Initialize ReloadServiceHelper."""
|
||||||
self._service_func = service_func
|
self._service_func = service_func
|
||||||
self._service_running = False
|
self._service_running = False
|
||||||
self._service_condition = asyncio.Condition()
|
self._service_condition = asyncio.Condition()
|
||||||
|
self._pending_reload_targets: set[_T] = set()
|
||||||
|
self._reload_targets_func = reload_targets_func
|
||||||
|
|
||||||
async def execute_service(self, service_call: ServiceCall) -> None:
|
async def execute_service(self, service_call: ServiceCall) -> None:
|
||||||
"""Execute the service.
|
"""Execute the service.
|
||||||
|
|
||||||
If a previous reload task if currently in progress, wait for it to finish first.
|
If a previous reload task is currently in progress, wait for it to finish first.
|
||||||
Once the previous reload task has finished, one of the waiting tasks will be
|
Once the previous reload task has finished, one of the waiting tasks will be
|
||||||
assigned to execute the reload, the others will wait for the reload to finish.
|
assigned to execute the reload of the targets it is assigned to reload. The
|
||||||
|
other tasks will wait if they should reload the same target, otherwise they
|
||||||
|
will wait for the next round.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
do_reload = False
|
do_reload = False
|
||||||
|
reload_targets = None
|
||||||
async with self._service_condition:
|
async with self._service_condition:
|
||||||
if self._service_running:
|
if self._service_running:
|
||||||
# A previous reload task is already in progress, wait for it to finish
|
# A previous reload task is already in progress, wait for it to finish,
|
||||||
|
# because that task may be reloading a stale version of the resource.
|
||||||
await self._service_condition.wait()
|
await self._service_condition.wait()
|
||||||
|
|
||||||
async with self._service_condition:
|
while True:
|
||||||
if not self._service_running:
|
async with self._service_condition:
|
||||||
# This task will do the reload
|
# Once we've passed this point, we assume the version of the resource is
|
||||||
self._service_running = True
|
# the one our task was assigned to reload, or a newer one. Regardless of
|
||||||
do_reload = True
|
# which, our task is happy as long as the target is reloaded at least
|
||||||
else:
|
# once.
|
||||||
# Another task will perform the reload, wait for it to finish
|
if reload_targets is None:
|
||||||
|
reload_targets = self._reload_targets_func(service_call)
|
||||||
|
self._pending_reload_targets |= reload_targets
|
||||||
|
if not self._service_running:
|
||||||
|
# This task will do a reload
|
||||||
|
self._service_running = True
|
||||||
|
do_reload = True
|
||||||
|
break
|
||||||
|
# Another task will perform a reload, wait for it to finish
|
||||||
await self._service_condition.wait()
|
await self._service_condition.wait()
|
||||||
|
# Check if the reload this task is waiting for has been completed
|
||||||
|
if reload_targets.isdisjoint(self._pending_reload_targets):
|
||||||
|
break
|
||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
# Reload, then notify other tasks
|
# Reload, then notify other tasks
|
||||||
await self._service_func(service_call)
|
await self._service_func(service_call)
|
||||||
async with self._service_condition:
|
async with self._service_condition:
|
||||||
self._service_running = False
|
self._service_running = False
|
||||||
|
self._pending_reload_targets -= reload_targets
|
||||||
self._service_condition.notify_all()
|
self._service_condition.notify_all()
|
||||||
|
@ -21,6 +21,7 @@ from homeassistant.config_entries import ConfigEntryState
|
|||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID,
|
ATTR_ENTITY_ID,
|
||||||
ATTR_NAME,
|
ATTR_NAME,
|
||||||
|
CONF_ID,
|
||||||
EVENT_HOMEASSISTANT_STARTED,
|
EVENT_HOMEASSISTANT_STARTED,
|
||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
SERVICE_TOGGLE,
|
SERVICE_TOGGLE,
|
||||||
@ -692,7 +693,9 @@ async def test_reload_config_handles_load_fails(hass: HomeAssistant, calls) -> N
|
|||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("service", ["turn_off_stop", "turn_off_no_stop", "reload"])
|
@pytest.mark.parametrize(
|
||||||
|
"service", ["turn_off_stop", "turn_off_no_stop", "reload", "reload_single"]
|
||||||
|
)
|
||||||
async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
||||||
"""Test that turning off / reloading stops any running actions as appropriate."""
|
"""Test that turning off / reloading stops any running actions as appropriate."""
|
||||||
entity_id = "automation.hello"
|
entity_id = "automation.hello"
|
||||||
@ -700,6 +703,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||||||
|
|
||||||
config = {
|
config = {
|
||||||
automation.DOMAIN: {
|
automation.DOMAIN: {
|
||||||
|
"id": "sun",
|
||||||
"alias": "hello",
|
"alias": "hello",
|
||||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
"action": [
|
"action": [
|
||||||
@ -737,7 +741,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||||||
{ATTR_ENTITY_ID: entity_id, automation.CONF_STOP_ACTIONS: False},
|
{ATTR_ENTITY_ID: entity_id, automation.CONF_STOP_ACTIONS: False},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
else:
|
elif service == "reload":
|
||||||
config[automation.DOMAIN]["alias"] = "goodbye"
|
config[automation.DOMAIN]["alias"] = "goodbye"
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.config.load_yaml_config_file",
|
"homeassistant.config.load_yaml_config_file",
|
||||||
@ -747,6 +751,19 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
automation.DOMAIN, SERVICE_RELOAD, blocking=True
|
automation.DOMAIN, SERVICE_RELOAD, blocking=True
|
||||||
)
|
)
|
||||||
|
else: # service == "reload_single"
|
||||||
|
config[automation.DOMAIN]["alias"] = "goodbye"
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config.load_yaml_config_file",
|
||||||
|
autospec=True,
|
||||||
|
return_value=config,
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
hass.states.async_set(test_entity, "goodbye")
|
hass.states.async_set(test_entity, "goodbye")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
@ -801,6 +818,238 @@ async def test_reload_unchanged_does_not_stop(
|
|||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_single_unchanged_does_not_stop(
|
||||||
|
hass: HomeAssistant, calls
|
||||||
|
) -> None:
|
||||||
|
"""Test that reloading stops any running actions as appropriate."""
|
||||||
|
test_entity = "test.entity"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
automation.DOMAIN: {
|
||||||
|
"id": "sun",
|
||||||
|
"alias": "hello",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
|
"action": [
|
||||||
|
{"event": "running"},
|
||||||
|
{"wait_template": "{{ is_state('test.entity', 'goodbye') }}"},
|
||||||
|
{"service": "test.automation"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert await async_setup_component(hass, automation.DOMAIN, config)
|
||||||
|
|
||||||
|
running = asyncio.Event()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def running_cb(event):
|
||||||
|
running.set()
|
||||||
|
|
||||||
|
hass.bus.async_listen_once("running", running_cb)
|
||||||
|
hass.states.async_set(test_entity, "hello")
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await running.wait()
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config.load_yaml_config_file",
|
||||||
|
autospec=True,
|
||||||
|
return_value=config,
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.states.async_set(test_entity, "goodbye")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_single_add_automation(hass: HomeAssistant, calls) -> None:
|
||||||
|
"""Test that reloading a single automation."""
|
||||||
|
config1 = {automation.DOMAIN: {}}
|
||||||
|
config2 = {
|
||||||
|
automation.DOMAIN: {
|
||||||
|
"id": "sun",
|
||||||
|
"alias": "hello",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config.load_yaml_config_file",
|
||||||
|
autospec=True,
|
||||||
|
return_value=config2,
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_single_parallel_calls(hass: HomeAssistant, calls) -> None:
|
||||||
|
"""Test reloading single automations in parallel."""
|
||||||
|
config1 = {automation.DOMAIN: {}}
|
||||||
|
config2 = {
|
||||||
|
automation.DOMAIN: [
|
||||||
|
{
|
||||||
|
"id": "sun",
|
||||||
|
"alias": "hello",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event_sun"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "moon",
|
||||||
|
"alias": "goodbye",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event_moon"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "mars",
|
||||||
|
"alias": "goodbye",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event_mars"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "venus",
|
||||||
|
"alias": "goodbye",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event_venus"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
# Trigger multiple reload service calls, each automation is reloaded twice.
|
||||||
|
# This tests the logic in the `ReloadServiceHelper` which avoids redundant
|
||||||
|
# reloads of the same target automation.
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config.load_yaml_config_file",
|
||||||
|
autospec=True,
|
||||||
|
return_value=config2,
|
||||||
|
):
|
||||||
|
tasks = [
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "moon"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "mars"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "venus"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "moon"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "mars"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "venus"},
|
||||||
|
blocking=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Sanity check to ensure all automations are correctly setup
|
||||||
|
hass.bus.async_fire("test_event_sun")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
hass.bus.async_fire("test_event_moon")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
hass.bus.async_fire("test_event_mars")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 3
|
||||||
|
hass.bus.async_fire("test_event_venus")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 4
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_single_remove_automation(hass: HomeAssistant, calls) -> None:
|
||||||
|
"""Test that reloading a single automation."""
|
||||||
|
config1 = {
|
||||||
|
automation.DOMAIN: {
|
||||||
|
"id": "sun",
|
||||||
|
"alias": "hello",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
|
"action": [{"service": "test.automation"}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config2 = {automation.DOMAIN: {}}
|
||||||
|
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config.load_yaml_config_file",
|
||||||
|
autospec=True,
|
||||||
|
return_value=config2,
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_RELOAD,
|
||||||
|
{CONF_ID: "sun"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_reload_moved_automation_without_alias(
|
async def test_reload_moved_automation_without_alias(
|
||||||
hass: HomeAssistant, calls
|
hass: HomeAssistant, calls
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -10,7 +10,7 @@ import pytest
|
|||||||
from homeassistant.bootstrap import async_setup_component
|
from homeassistant.bootstrap import async_setup_component
|
||||||
from homeassistant.components import config
|
from homeassistant.components import config
|
||||||
from homeassistant.components.config import automation
|
from homeassistant.components.config import automation
|
||||||
from homeassistant.const import STATE_ON, STATE_UNAVAILABLE
|
from homeassistant.const import STATE_ON
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.util import yaml
|
from homeassistant.util import yaml
|
||||||
@ -82,10 +82,8 @@ async def test_update_automation_config(
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||||
"automation.automation_0",
|
|
||||||
"automation.automation_1",
|
"automation.automation_1",
|
||||||
]
|
]
|
||||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
|
||||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
@ -260,10 +258,8 @@ async def test_update_remove_key_automation_config(
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||||
"automation.automation_0",
|
|
||||||
"automation.automation_1",
|
"automation.automation_1",
|
||||||
]
|
]
|
||||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
|
||||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
@ -305,10 +301,8 @@ async def test_bad_formatted_automations(
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||||
"automation.automation_0",
|
|
||||||
"automation.automation_1",
|
"automation.automation_1",
|
||||||
]
|
]
|
||||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
|
||||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
|
@ -1852,3 +1852,139 @@ async def test_async_extract_config_entry_ids(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
|
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_service_helper(hass: HomeAssistant) -> None:
|
||||||
|
"""Test the reload service helper."""
|
||||||
|
|
||||||
|
active_reload_calls = 0
|
||||||
|
reloaded = []
|
||||||
|
|
||||||
|
async def reload_service_handler(service_call: ServiceCall) -> None:
|
||||||
|
"""Remove all automations and load new ones from config."""
|
||||||
|
nonlocal active_reload_calls
|
||||||
|
# Assert the reload helper prevents parallel reloads
|
||||||
|
assert not active_reload_calls
|
||||||
|
active_reload_calls += 1
|
||||||
|
if not (target := service_call.data.get("target")):
|
||||||
|
reloaded.append("all")
|
||||||
|
else:
|
||||||
|
reloaded.append(target)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
active_reload_calls -= 1
|
||||||
|
|
||||||
|
def reload_targets(service_call: ServiceCall) -> set[str | None]:
|
||||||
|
if target_id := service_call.data.get("target"):
|
||||||
|
return {target_id}
|
||||||
|
return {"target1", "target2", "target3", "target4"}
|
||||||
|
|
||||||
|
# Test redundant reload of single targets
|
||||||
|
reloader = service.ReloadServiceHelper(reload_service_handler, reload_targets)
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (target1)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
|
||||||
|
# while the first task is reloaded, note that target1 can't be deduplicated
|
||||||
|
# because it's already being reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(
|
||||||
|
["target1", "target2", "target3", "target4", "target1"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test redundant reload of multiple targets + single target
|
||||||
|
reloaded.clear()
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (target1)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
# These reload tasks will be deduplicated to (target2, target3, target4, all)
|
||||||
|
# while the first task is reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test")),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
|
||||||
|
|
||||||
|
# Test redundant reload of multiple targets + single target
|
||||||
|
reloaded.clear()
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (all)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test")),
|
||||||
|
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
|
||||||
|
# while the first task is reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
|
||||||
|
|
||||||
|
# Test redundant reload of single targets
|
||||||
|
reloaded.clear()
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (target1)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
|
||||||
|
# while the first task is reloaded, note that target1 can't be deduplicated
|
||||||
|
# because it's already being reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(
|
||||||
|
["target1", "target2", "target3", "target4", "target1"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test redundant reload of multiple targets + single target
|
||||||
|
reloaded.clear()
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (target1)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
# These reload tasks will be deduplicated to (target2, target3, target4, all)
|
||||||
|
# while the first task is reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test")),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test")),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
|
||||||
|
|
||||||
|
# Test redundant reload of multiple targets + single target
|
||||||
|
reloaded.clear()
|
||||||
|
tasks = [
|
||||||
|
# This reload task will start executing first, (all)
|
||||||
|
reloader.execute_service(ServiceCall("test", "test")),
|
||||||
|
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
|
||||||
|
# while the first task is reloaded.
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||||
|
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user