Add support for condition platforms to provide multiple conditions (#147376)

This commit is contained in:
Erik Montnemery 2025-06-25 19:10:30 +02:00 committed by GitHub
parent 2b5f5f641d
commit 26e3caea9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 189 additions and 64 deletions

View File

@ -10,6 +10,7 @@ from homeassistant.const import CONF_DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType, ConditionCheckerType,
trace_condition_function, trace_condition_function,
) )
@ -51,20 +52,38 @@ class DeviceAutomationConditionProtocol(Protocol):
"""List conditions.""" """List conditions."""
async def async_validate_condition_config( class DeviceCondition(Condition):
hass: HomeAssistant, config: ConfigType """Device condition."""
) -> ConfigType:
"""Validate device condition config.""" def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
return await async_validate_device_automation_config( """Initialize condition."""
hass, config, cv.DEVICE_CONDITION_SCHEMA, DeviceAutomationType.CONDITION self._config = config
) self._hass = hass
@classmethod
async def async_validate_condition_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate device condition config."""
return await async_validate_device_automation_config(
hass, config, cv.DEVICE_CONDITION_SCHEMA, DeviceAutomationType.CONDITION
)
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(
platform.async_condition_from_config(self._hass, self._config)
)
async def async_condition_from_config( CONDITIONS: dict[str, type[Condition]] = {
hass: HomeAssistant, config: ConfigType "device": DeviceCondition,
) -> condition.ConditionCheckerType: }
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
) """Return the sun conditions."""
return trace_condition_function(platform.async_condition_from_config(hass, config)) return CONDITIONS

View File

@ -11,6 +11,7 @@ from homeassistant.const import CONF_CONDITION, SUN_EVENT_SUNRISE, SUN_EVENT_SUN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType, ConditionCheckerType,
condition_trace_set_result, condition_trace_set_result,
condition_trace_update_result, condition_trace_update_result,
@ -37,13 +38,6 @@ _CONDITION_SCHEMA = vol.All(
) )
async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
def sun( def sun(
hass: HomeAssistant, hass: HomeAssistant,
before: str | None = None, before: str | None = None,
@ -128,16 +122,41 @@ def sun(
return True return True
def async_condition_from_config(config: ConfigType) -> ConditionCheckerType: class SunCondition(Condition):
"""Wrap action method with sun based condition.""" """Sun condition."""
before = config.get("before")
after = config.get("after")
before_offset = config.get("before_offset")
after_offset = config.get("after_offset")
@trace_condition_function def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Initialize condition."""
"""Validate time based if-condition.""" self._config = config
return sun(hass, before, after, before_offset, after_offset) self._hass = hass
return sun_if @classmethod
async def async_validate_condition_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
async def async_condition_from_config(self) -> ConditionCheckerType:
"""Wrap action method with sun based condition."""
before = self._config.get("before")
after = self._config.get("after")
before_offset = self._config.get("before_offset")
after_offset = self._config.get("after_offset")
@trace_condition_function
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)
return sun_if
CONDITIONS: dict[str, type[Condition]] = {
"sun": SunCondition,
}
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
"""Return the sun conditions."""
return CONDITIONS

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import abc
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Container, Generator from collections.abc import Callable, Container, Generator
@ -75,7 +76,7 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config" FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config" VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_PLATFORM_ALIASES = { _PLATFORM_ALIASES: dict[str | None, str | None] = {
"and": None, "and": None,
"device": "device_automation", "device": "device_automation",
"not": None, "not": None,
@ -93,20 +94,33 @@ INPUT_ENTITY_ID = re.compile(
) )
class ConditionProtocol(Protocol): class Condition(abc.ABC):
"""Define the format of condition modules.""" """Condition class."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize condition."""
@classmethod
@abc.abstractmethod
async def async_validate_condition_config( async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType: ) -> ConfigType:
"""Validate config.""" """Validate config."""
def async_condition_from_config( @abc.abstractmethod
self, hass: HomeAssistant, config: ConfigType async def async_condition_from_config(self) -> ConditionCheckerType:
) -> ConditionCheckerType:
"""Evaluate state based on configuration.""" """Evaluate state based on configuration."""
class ConditionProtocol(Protocol):
"""Define the format of condition modules."""
async def async_get_conditions(
self, hass: HomeAssistant
) -> dict[str, type[Condition]]:
"""Return the conditions provided by this integration."""
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None] type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
@ -179,7 +193,9 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
async def _async_get_condition_platform( async def _async_get_condition_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConditionProtocol | None: ) -> ConditionProtocol | None:
platform = config[CONF_CONDITION] condition_key: str = config[CONF_CONDITION]
platform_and_sub_type = condition_key.partition(".")
platform: str | None = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform) platform = _PLATFORM_ALIASES.get(platform, platform)
if platform is None: if platform is None:
return None return None
@ -187,7 +203,7 @@ async def _async_get_condition_platform(
integration = await async_get_integration(hass, platform) integration = await async_get_integration(hass, platform)
except IntegrationNotFound: except IntegrationNotFound:
raise HomeAssistantError( raise HomeAssistantError(
f'Invalid condition "{platform}" specified {config}' f'Invalid condition "{condition_key}" specified {config}'
) from None ) from None
try: try:
return await integration.async_get_platform("condition") return await integration.async_get_platform("condition")
@ -205,19 +221,6 @@ async def async_from_config(
Should be run on the event loop. Should be run on the event loop.
""" """
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is None:
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
else:
factory = platform.async_condition_from_config
# Check if condition is not enabled # Check if condition is not enabled
if CONF_ENABLED in config: if CONF_ENABLED in config:
enabled = config[CONF_ENABLED] enabled = config[CONF_ENABLED]
@ -239,6 +242,21 @@ async def async_from_config(
return disabled_condition return disabled_condition
condition: str = config[CONF_CONDITION]
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is not None:
condition_descriptors = await platform.async_get_conditions(hass)
condition_instance = condition_descriptors[condition](hass, config)
return await condition_instance.async_condition_from_config()
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
check_factory = factory check_factory = factory
while isinstance(check_factory, ft.partial): while isinstance(check_factory, ft.partial):
@ -936,7 +954,7 @@ async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConfigType: ) -> ConfigType:
"""Validate config.""" """Validate config."""
condition = config[CONF_CONDITION] condition: str = config[CONF_CONDITION]
if condition in ("and", "not", "or"): if condition in ("and", "not", "or"):
conditions = [] conditions = []
for sub_cond in config["conditions"]: for sub_cond in config["conditions"]:
@ -947,7 +965,10 @@ async def async_validate_condition_config(
platform = await _async_get_condition_platform(hass, config) platform = await _async_get_condition_platform(hass, config)
if platform is not None: if platform is not None:
return await platform.async_validate_condition_config(hass, config) condition_descriptors = await platform.async_get_conditions(hass)
if not (condition_class := condition_descriptors.get(condition)):
raise vol.Invalid(f"Invalid condition '{condition}' specified")
return await condition_class.async_validate_condition_config(hass, config)
if platform is None and condition in ("numeric_state", "state"): if platform is None and condition in ("numeric_state", "state"):
validator = cast( validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType], Callable[[HomeAssistant, ConfigType], ConfigType],

View File

@ -2,7 +2,7 @@
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
from freezegun import freeze_time from freezegun import freeze_time
import pytest import pytest
@ -26,9 +26,12 @@ from homeassistant.helpers import (
trace, trace,
) )
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from tests.common import MockModule, mock_integration, mock_platform
def assert_element(trace_element, expected_element, path): def assert_element(trace_element, expected_element, path):
"""Assert a trace element is as expected. """Assert a trace element is as expected.
@ -2251,15 +2254,78 @@ async def test_trigger(hass: HomeAssistant) -> None:
assert test(hass, {"trigger": {"id": "123456"}}) assert test(hass, {"trigger": {"id": "123456"}})
async def test_platform_async_validate_condition_config(hass: HomeAssistant) -> None: async def test_platform_async_get_conditions(hass: HomeAssistant) -> None:
"""Test platform.async_validate_condition_config will be called if it exists.""" """Test platform.async_get_conditions will be called if it exists."""
config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"} config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"}
with patch( with patch(
"homeassistant.components.device_automation.condition.async_validate_condition_config", "homeassistant.components.device_automation.condition.async_get_conditions",
AsyncMock(), AsyncMock(return_value={"device": AsyncMock()}),
) as device_automation_validate_condition_mock: ) as device_automation_async_get_conditions_mock:
await condition.async_validate_condition_config(hass, config) await condition.async_validate_condition_config(hass, config)
device_automation_validate_condition_mock.assert_awaited() device_automation_async_get_conditions_mock.assert_awaited()
async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
"""Test a condition platform with multiple conditions."""
class MockCondition(condition.Condition):
"""Mock condition."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize condition."""
@classmethod
async def async_validate_condition_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return config
class MockCondition1(MockCondition):
"""Mock condition 1."""
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
return lambda hass, vars: True
class MockCondition2(MockCondition):
"""Mock condition 2."""
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
return lambda hass, vars: False
async def async_get_conditions(
hass: HomeAssistant,
) -> dict[str, type[condition.Condition]]:
return {
"test": MockCondition1,
"test.cond_2": MockCondition2,
}
mock_integration(hass, MockModule("test"))
mock_platform(
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
)
config_1 = {CONF_CONDITION: "test"}
config_2 = {CONF_CONDITION: "test.cond_2"}
config_3 = {CONF_CONDITION: "test.unknown_cond"}
assert await condition.async_validate_condition_config(hass, config_1) == config_1
assert await condition.async_validate_condition_config(hass, config_2) == config_2
with pytest.raises(
vol.Invalid, match="Invalid condition 'test.unknown_cond' specified"
):
await condition.async_validate_condition_config(hass, config_3)
cond_func = await condition.async_from_config(hass, config_1)
assert cond_func(hass, {}) is True
cond_func = await condition.async_from_config(hass, config_2)
assert cond_func(hass, {}) is False
with pytest.raises(KeyError):
await condition.async_from_config(hass, config_3)
@pytest.mark.parametrize("enabled_value", [True, "{{ 1 == 1 }}"]) @pytest.mark.parametrize("enabled_value", [True, "{{ 1 == 1 }}"])