diff --git a/homeassistant/components/device_automation/condition.py b/homeassistant/components/device_automation/condition.py index 3856458c3dd..f819668f090 100644 --- a/homeassistant/components/device_automation/condition.py +++ b/homeassistant/components/device_automation/condition.py @@ -8,6 +8,7 @@ import voluptuous as vol from homeassistant.const import CONF_DOMAIN from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.condition import ConditionProtocol, trace_condition_function from homeassistant.helpers.typing import ConfigType from . import DeviceAutomationType, async_get_device_automation_platform @@ -17,24 +18,13 @@ if TYPE_CHECKING: from homeassistant.helpers import condition -class DeviceAutomationConditionProtocol(Protocol): +class DeviceAutomationConditionProtocol(ConditionProtocol, Protocol): """Define the format of device_condition modules. - Each module must define either CONDITION_SCHEMA or async_validate_condition_config. + Each module must define either CONDITION_SCHEMA or async_validate_condition_config + from ConditionProtocol. """ - CONDITION_SCHEMA: vol.Schema - - async def async_validate_condition_config( - self, hass: HomeAssistant, config: ConfigType - ) -> ConfigType: - """Validate config.""" - - def async_condition_from_config( - self, hass: HomeAssistant, config: ConfigType - ) -> condition.ConditionCheckerType: - """Evaluate state based on configuration.""" - async def async_get_condition_capabilities( self, hass: HomeAssistant, config: ConfigType ) -> dict[str, vol.Schema]: @@ -62,4 +52,4 @@ async def async_condition_from_config( platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) - return platform.async_condition_from_config(hass, config) + return trace_condition_function(platform.async_condition_from_config(hass, config)) diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 7513e2b0087..0029a9c906b 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -7,15 +7,13 @@ from collections.abc import Callable, Container, Generator from contextlib import contextmanager from datetime import datetime, time as dt_time, timedelta import functools as ft -import logging import re import sys -from typing import Any, cast +from typing import Any, Protocol, cast import voluptuous as vol from homeassistant.components import zone as zone_cmp -from homeassistant.components.device_automation import condition as device_condition from homeassistant.components.sensor import SensorDeviceClass from homeassistant.const import ( ATTR_DEVICE_CLASS, @@ -55,6 +53,7 @@ from homeassistant.exceptions import ( HomeAssistantError, TemplateError, ) +from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as dt_util @@ -77,12 +76,44 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config" FROM_CONFIG_FORMAT = "{}_from_config" VALIDATE_CONFIG_FORMAT = "{}_validate_config" -_LOGGER = logging.getLogger(__name__) +_PLATFORM_ALIASES = { + "and": None, + "device": "device_automation", + "not": None, + "numeric_state": None, + "or": None, + "state": None, + "sun": None, + "template": None, + "time": None, + "trigger": None, + "zone": None, +} INPUT_ENTITY_ID = re.compile( r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(? ConfigType: + """Validate config.""" + + def async_condition_from_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConditionCheckerType: + """Evaluate state based on configuration.""" + + ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None] @@ -152,6 +183,27 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke return wrapper +async def _async_get_condition_platform( + hass: HomeAssistant, config: ConfigType +) -> ConditionProtocol | None: + platform = config[CONF_CONDITION] + platform = _PLATFORM_ALIASES.get(platform, platform) + if platform is None: + return None + try: + integration = await async_get_integration(hass, platform) + except IntegrationNotFound: + raise HomeAssistantError( + f'Invalid condition "{platform}" specified {config}' + ) from None + try: + return integration.get_platform("condition") + except ImportError: + raise HomeAssistantError( + f"Integration '{platform}' does not provide condition support" + ) from None + + async def async_from_config( hass: HomeAssistant, config: ConfigType, @@ -160,15 +212,18 @@ async def async_from_config( Should be run on the event loop. """ - condition = config.get(CONF_CONDITION) - for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): - factory = getattr(sys.modules[__name__], fmt.format(condition), None) + factory: Any = None + platform = await _async_get_condition_platform(hass, config) - if factory: - break + if platform is None: + condition = config.get(CONF_CONDITION) + for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): + factory = getattr(sys.modules[__name__], fmt.format(condition), None) - if factory is None: - raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}') + if factory: + break + else: + factory = platform.async_condition_from_config # Check if condition is not enabled if not config.get(CONF_ENABLED, True): @@ -928,14 +983,6 @@ def zone_from_config(config: ConfigType) -> ConditionCheckerType: return if_in_zone -async def async_device_from_config( - hass: HomeAssistant, config: ConfigType -) -> ConditionCheckerType: - """Test a device condition.""" - checker = await device_condition.async_condition_from_config(hass, config) - return trace_condition_function(checker) - - async def async_trigger_from_config( hass: HomeAssistant, config: ConfigType ) -> ConditionCheckerType: @@ -991,10 +1038,10 @@ async def async_validate_condition_config( config["conditions"] = conditions return config - if condition == "device": - return await device_condition.async_validate_condition_config(hass, config) - - if condition in ("numeric_state", "state"): + platform = await _async_get_condition_platform(hass, config) + if platform is not None and hasattr(platform, "async_validate_condition_config"): + return await platform.async_validate_condition_config(hass, config) + if platform is None and condition in ("numeric_state", "state"): validator = cast( Callable[[HomeAssistant, ConfigType], ConfigType], getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)), diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5328a1d38ed..0521bc722cd 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -2406,13 +2406,7 @@ async def test_repeat_var_in_condition(hass: HomeAssistant, condition) -> None: script_obj = script.Script( hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain" ) - - with mock.patch( - "homeassistant.helpers.condition._LOGGER.error", - side_effect=AssertionError("Template Error"), - ): - await script_obj.async_run(context=Context()) - + await script_obj.async_run(context=Context()) assert len(events) == 2 if condition == "while": @@ -2545,13 +2539,7 @@ async def test_repeat_nested( ] ) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") - - with mock.patch( - "homeassistant.helpers.condition._LOGGER.error", - side_effect=AssertionError("Template Error"), - ): - await script_obj.async_run(variables, Context()) - + await script_obj.async_run(variables, Context()) assert len(events) == 10 assert events[0].data == first_last assert events[-1].data == first_last