From ba6a81c565f162c50166bf776b21b6a2271d7eef Mon Sep 17 00:00:00 2001 From: ehendrix23 Date: Thu, 29 Sep 2022 11:26:28 -0600 Subject: [PATCH] Resolve traceback error when using variables in template triggers (#77287) Co-authored-by: Erik --- homeassistant/helpers/trigger.py | 63 ++++++++++++++++++++++++------- tests/helpers/test_trigger.py | 64 +++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 15 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 9fde56ec7aa..4cb724a6435 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -2,10 +2,10 @@ from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Callable, Coroutine import functools import logging -from typing import TYPE_CHECKING, Any, Protocol, TypedDict +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast import voluptuous as vol @@ -16,7 +16,13 @@ from homeassistant.const import ( CONF_PLATFORM, CONF_VARIABLES, ) -from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback +from homeassistant.core import ( + CALLBACK_TYPE, + Context, + HomeAssistant, + callback, + is_callback, +) from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -101,20 +107,51 @@ async def async_validate_trigger_config( def _trigger_action_wrapper( hass: HomeAssistant, action: Callable, conf: ConfigType ) -> Callable: - """Wrap trigger action with extra vars if configured.""" + """Wrap trigger action with extra vars if configured. + + If action is a coroutine function, a coroutine function will be returned. + If action is a callback, a callback will be returned. + """ if CONF_VARIABLES not in conf: return action - @functools.wraps(action) - async def with_vars( - run_variables: dict[str, Any], context: Context | None = None - ) -> None: - """Wrap action with extra vars.""" - trigger_variables = conf[CONF_VARIABLES] - run_variables.update(trigger_variables.async_render(hass, run_variables)) - await action(run_variables, context) + # Check for partials to properly determine if coroutine function + check_func = action + while isinstance(check_func, functools.partial): + check_func = check_func.func - return with_vars + wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]] + if asyncio.iscoroutinefunction(check_func): + async_action = cast(Callable[..., Coroutine[Any, Any, None]], action) + + @functools.wraps(async_action) + async def async_with_vars( + run_variables: dict[str, Any], context: Context | None = None + ) -> None: + """Wrap action with extra vars.""" + trigger_variables = conf[CONF_VARIABLES] + run_variables.update(trigger_variables.async_render(hass, run_variables)) + await action(run_variables, context) + + wrapper_func = async_with_vars + + else: + + @functools.wraps(action) + async def with_vars( + run_variables: dict[str, Any], context: Context | None = None + ) -> None: + """Wrap action with extra vars.""" + trigger_variables = conf[CONF_VARIABLES] + run_variables.update(trigger_variables.async_render(hass, run_variables)) + action(run_variables, context) + + if is_callback(check_func): + with_vars = callback(with_vars) + + wrapper_func = with_vars + + return wrapper_func async def async_initialize_triggers( diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 7cee307f3ec..9cd3b0956ce 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -1,12 +1,13 @@ """The tests for the trigger helper.""" -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch import pytest import voluptuous as vol -from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers.trigger import ( _async_get_trigger_platform, + async_initialize_triggers, async_validate_trigger_config, ) from homeassistant.setup import async_setup_component @@ -137,3 +138,62 @@ async def test_trigger_alias( "Automation trigger 'My event' triggered by event 'trigger_event'" in caplog.text ) + + +async def test_async_initialize_triggers( + hass: HomeAssistant, calls: list[ServiceCall], caplog: pytest.LogCaptureFixture +) -> None: + """Test async_initialize_triggers with different action types.""" + + log_cb = MagicMock() + + action_calls = [] + + trigger_config = await async_validate_trigger_config( + hass, + [ + { + "platform": "event", + "event_type": ["trigger_event"], + "variables": { + "name": "Paulus", + "via_event": "{{ trigger.event.event_type }}", + }, + } + ], + ) + + async def async_action(*args): + action_calls.append([*args]) + + @callback + def cb_action(*args): + action_calls.append([*args]) + + def non_cb_action(*args): + action_calls.append([*args]) + + for action in (async_action, cb_action, non_cb_action): + action_calls = [] + + unsub = await async_initialize_triggers( + hass, + trigger_config, + action, + "test", + "", + log_cb, + ) + await hass.async_block_till_done() + + hass.bus.async_fire("trigger_event") + await hass.async_block_till_done() + await hass.async_block_till_done() + + assert len(action_calls) == 1 + assert action_calls[0][0]["name"] == "Paulus" + assert action_calls[0][0]["via_event"] == "trigger_event" + log_cb.assert_called_once_with(ANY, "Initialized trigger") + + log_cb.reset_mock() + unsub()