mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Resolve traceback error when using variables in template triggers (#77287)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
ee32e0eb3f
commit
ba6a81c565
@ -2,10 +2,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -16,7 +16,13 @@ from homeassistant.const import (
|
||||
CONF_PLATFORM,
|
||||
CONF_VARIABLES,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Context,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
is_callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
|
||||
@ -101,20 +107,51 @@ async def async_validate_trigger_config(
|
||||
def _trigger_action_wrapper(
|
||||
hass: HomeAssistant, action: Callable, conf: ConfigType
|
||||
) -> Callable:
|
||||
"""Wrap trigger action with extra vars if configured."""
|
||||
"""Wrap trigger action with extra vars if configured.
|
||||
|
||||
If action is a coroutine function, a coroutine function will be returned.
|
||||
If action is a callback, a callback will be returned.
|
||||
"""
|
||||
if CONF_VARIABLES not in conf:
|
||||
return action
|
||||
|
||||
@functools.wraps(action)
|
||||
async def with_vars(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
await action(run_variables, context)
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = action
|
||||
while isinstance(check_func, functools.partial):
|
||||
check_func = check_func.func
|
||||
|
||||
return with_vars
|
||||
wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]]
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
async_action = cast(Callable[..., Coroutine[Any, Any, None]], action)
|
||||
|
||||
@functools.wraps(async_action)
|
||||
async def async_with_vars(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
await action(run_variables, context)
|
||||
|
||||
wrapper_func = async_with_vars
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(action)
|
||||
async def with_vars(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
action(run_variables, context)
|
||||
|
||||
if is_callback(check_func):
|
||||
with_vars = callback(with_vars)
|
||||
|
||||
wrapper_func = with_vars
|
||||
|
||||
return wrapper_func
|
||||
|
||||
|
||||
async def async_initialize_triggers(
|
||||
|
@ -1,12 +1,13 @@
|
||||
"""The tests for the trigger helper."""
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from unittest.mock import ANY, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.helpers.trigger import (
|
||||
_async_get_trigger_platform,
|
||||
async_initialize_triggers,
|
||||
async_validate_trigger_config,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
@ -137,3 +138,62 @@ async def test_trigger_alias(
|
||||
"Automation trigger 'My event' triggered by event 'trigger_event'"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_async_initialize_triggers(
|
||||
hass: HomeAssistant, calls: list[ServiceCall], caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test async_initialize_triggers with different action types."""
|
||||
|
||||
log_cb = MagicMock()
|
||||
|
||||
action_calls = []
|
||||
|
||||
trigger_config = await async_validate_trigger_config(
|
||||
hass,
|
||||
[
|
||||
{
|
||||
"platform": "event",
|
||||
"event_type": ["trigger_event"],
|
||||
"variables": {
|
||||
"name": "Paulus",
|
||||
"via_event": "{{ trigger.event.event_type }}",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
async def async_action(*args):
|
||||
action_calls.append([*args])
|
||||
|
||||
@callback
|
||||
def cb_action(*args):
|
||||
action_calls.append([*args])
|
||||
|
||||
def non_cb_action(*args):
|
||||
action_calls.append([*args])
|
||||
|
||||
for action in (async_action, cb_action, non_cb_action):
|
||||
action_calls = []
|
||||
|
||||
unsub = await async_initialize_triggers(
|
||||
hass,
|
||||
trigger_config,
|
||||
action,
|
||||
"test",
|
||||
"",
|
||||
log_cb,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.bus.async_fire("trigger_event")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(action_calls) == 1
|
||||
assert action_calls[0][0]["name"] == "Paulus"
|
||||
assert action_calls[0][0]["via_event"] == "trigger_event"
|
||||
log_cb.assert_called_once_with(ANY, "Initialized trigger")
|
||||
|
||||
log_cb.reset_mock()
|
||||
unsub()
|
||||
|
Loading…
x
Reference in New Issue
Block a user