mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 16:27:08 +00:00
Add support for condition platforms to provide multiple conditions (#147376)
This commit is contained in:
parent
2b5f5f641d
commit
26e3caea9a
@ -10,6 +10,7 @@ from homeassistant.const import CONF_DOMAIN
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionCheckerType,
|
||||||
trace_condition_function,
|
trace_condition_function,
|
||||||
)
|
)
|
||||||
@ -51,20 +52,38 @@ class DeviceAutomationConditionProtocol(Protocol):
|
|||||||
"""List conditions."""
|
"""List conditions."""
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_condition_config(
|
class DeviceCondition(Condition):
|
||||||
hass: HomeAssistant, config: ConfigType
|
"""Device condition."""
|
||||||
) -> ConfigType:
|
|
||||||
"""Validate device condition config."""
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
return await async_validate_device_automation_config(
|
"""Initialize condition."""
|
||||||
hass, config, cv.DEVICE_CONDITION_SCHEMA, DeviceAutomationType.CONDITION
|
self._config = config
|
||||||
)
|
self._hass = hass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_validate_condition_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate device condition config."""
|
||||||
|
return await async_validate_device_automation_config(
|
||||||
|
hass, config, cv.DEVICE_CONDITION_SCHEMA, DeviceAutomationType.CONDITION
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
|
||||||
|
"""Test a device condition."""
|
||||||
|
platform = await async_get_device_automation_platform(
|
||||||
|
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||||
|
)
|
||||||
|
return trace_condition_function(
|
||||||
|
platform.async_condition_from_config(self._hass, self._config)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_condition_from_config(
|
CONDITIONS: dict[str, type[Condition]] = {
|
||||||
hass: HomeAssistant, config: ConfigType
|
"device": DeviceCondition,
|
||||||
) -> condition.ConditionCheckerType:
|
}
|
||||||
"""Test a device condition."""
|
|
||||||
platform = await async_get_device_automation_platform(
|
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
||||||
)
|
"""Return the sun conditions."""
|
||||||
return trace_condition_function(platform.async_condition_from_config(hass, config))
|
return CONDITIONS
|
||||||
|
@ -11,6 +11,7 @@ from homeassistant.const import CONF_CONDITION, SUN_EVENT_SUNRISE, SUN_EVENT_SUN
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionCheckerType,
|
||||||
condition_trace_set_result,
|
condition_trace_set_result,
|
||||||
condition_trace_update_result,
|
condition_trace_update_result,
|
||||||
@ -37,13 +38,6 @@ _CONDITION_SCHEMA = vol.All(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_condition_config(
|
|
||||||
hass: HomeAssistant, config: ConfigType
|
|
||||||
) -> ConfigType:
|
|
||||||
"""Validate config."""
|
|
||||||
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
|
|
||||||
def sun(
|
def sun(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
before: str | None = None,
|
before: str | None = None,
|
||||||
@ -128,16 +122,41 @@ def sun(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def async_condition_from_config(config: ConfigType) -> ConditionCheckerType:
|
class SunCondition(Condition):
|
||||||
"""Wrap action method with sun based condition."""
|
"""Sun condition."""
|
||||||
before = config.get("before")
|
|
||||||
after = config.get("after")
|
|
||||||
before_offset = config.get("before_offset")
|
|
||||||
after_offset = config.get("after_offset")
|
|
||||||
|
|
||||||
@trace_condition_function
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
"""Initialize condition."""
|
||||||
"""Validate time based if-condition."""
|
self._config = config
|
||||||
return sun(hass, before, after, before_offset, after_offset)
|
self._hass = hass
|
||||||
|
|
||||||
return sun_if
|
@classmethod
|
||||||
|
async def async_validate_condition_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
async def async_condition_from_config(self) -> ConditionCheckerType:
|
||||||
|
"""Wrap action method with sun based condition."""
|
||||||
|
before = self._config.get("before")
|
||||||
|
after = self._config.get("after")
|
||||||
|
before_offset = self._config.get("before_offset")
|
||||||
|
after_offset = self._config.get("after_offset")
|
||||||
|
|
||||||
|
@trace_condition_function
|
||||||
|
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||||
|
"""Validate time based if-condition."""
|
||||||
|
return sun(hass, before, after, before_offset, after_offset)
|
||||||
|
|
||||||
|
return sun_if
|
||||||
|
|
||||||
|
|
||||||
|
CONDITIONS: dict[str, type[Condition]] = {
|
||||||
|
"sun": SunCondition,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
||||||
|
"""Return the sun conditions."""
|
||||||
|
return CONDITIONS
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable, Container, Generator
|
from collections.abc import Callable, Container, Generator
|
||||||
@ -75,7 +76,7 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
|
|||||||
FROM_CONFIG_FORMAT = "{}_from_config"
|
FROM_CONFIG_FORMAT = "{}_from_config"
|
||||||
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
|
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
|
||||||
|
|
||||||
_PLATFORM_ALIASES = {
|
_PLATFORM_ALIASES: dict[str | None, str | None] = {
|
||||||
"and": None,
|
"and": None,
|
||||||
"device": "device_automation",
|
"device": "device_automation",
|
||||||
"not": None,
|
"not": None,
|
||||||
@ -93,20 +94,33 @@ INPUT_ENTITY_ID = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionProtocol(Protocol):
|
class Condition(abc.ABC):
|
||||||
"""Define the format of condition modules."""
|
"""Condition class."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
|
"""Initialize condition."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
async def async_validate_condition_config(
|
async def async_validate_condition_config(
|
||||||
self, hass: HomeAssistant, config: ConfigType
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
) -> ConfigType:
|
) -> ConfigType:
|
||||||
"""Validate config."""
|
"""Validate config."""
|
||||||
|
|
||||||
def async_condition_from_config(
|
@abc.abstractmethod
|
||||||
self, hass: HomeAssistant, config: ConfigType
|
async def async_condition_from_config(self) -> ConditionCheckerType:
|
||||||
) -> ConditionCheckerType:
|
|
||||||
"""Evaluate state based on configuration."""
|
"""Evaluate state based on configuration."""
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionProtocol(Protocol):
|
||||||
|
"""Define the format of condition modules."""
|
||||||
|
|
||||||
|
async def async_get_conditions(
|
||||||
|
self, hass: HomeAssistant
|
||||||
|
) -> dict[str, type[Condition]]:
|
||||||
|
"""Return the conditions provided by this integration."""
|
||||||
|
|
||||||
|
|
||||||
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
||||||
|
|
||||||
|
|
||||||
@ -179,7 +193,9 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
|
|||||||
async def _async_get_condition_platform(
|
async def _async_get_condition_platform(
|
||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> ConditionProtocol | None:
|
) -> ConditionProtocol | None:
|
||||||
platform = config[CONF_CONDITION]
|
condition_key: str = config[CONF_CONDITION]
|
||||||
|
platform_and_sub_type = condition_key.partition(".")
|
||||||
|
platform: str | None = platform_and_sub_type[0]
|
||||||
platform = _PLATFORM_ALIASES.get(platform, platform)
|
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||||
if platform is None:
|
if platform is None:
|
||||||
return None
|
return None
|
||||||
@ -187,7 +203,7 @@ async def _async_get_condition_platform(
|
|||||||
integration = await async_get_integration(hass, platform)
|
integration = await async_get_integration(hass, platform)
|
||||||
except IntegrationNotFound:
|
except IntegrationNotFound:
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f'Invalid condition "{platform}" specified {config}'
|
f'Invalid condition "{condition_key}" specified {config}'
|
||||||
) from None
|
) from None
|
||||||
try:
|
try:
|
||||||
return await integration.async_get_platform("condition")
|
return await integration.async_get_platform("condition")
|
||||||
@ -205,19 +221,6 @@ async def async_from_config(
|
|||||||
|
|
||||||
Should be run on the event loop.
|
Should be run on the event loop.
|
||||||
"""
|
"""
|
||||||
factory: Any = None
|
|
||||||
platform = await _async_get_condition_platform(hass, config)
|
|
||||||
|
|
||||||
if platform is None:
|
|
||||||
condition = config.get(CONF_CONDITION)
|
|
||||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
|
||||||
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
|
||||||
|
|
||||||
if factory:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
factory = platform.async_condition_from_config
|
|
||||||
|
|
||||||
# Check if condition is not enabled
|
# Check if condition is not enabled
|
||||||
if CONF_ENABLED in config:
|
if CONF_ENABLED in config:
|
||||||
enabled = config[CONF_ENABLED]
|
enabled = config[CONF_ENABLED]
|
||||||
@ -239,6 +242,21 @@ async def async_from_config(
|
|||||||
|
|
||||||
return disabled_condition
|
return disabled_condition
|
||||||
|
|
||||||
|
condition: str = config[CONF_CONDITION]
|
||||||
|
factory: Any = None
|
||||||
|
platform = await _async_get_condition_platform(hass, config)
|
||||||
|
|
||||||
|
if platform is not None:
|
||||||
|
condition_descriptors = await platform.async_get_conditions(hass)
|
||||||
|
condition_instance = condition_descriptors[condition](hass, config)
|
||||||
|
return await condition_instance.async_condition_from_config()
|
||||||
|
|
||||||
|
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||||
|
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
||||||
|
|
||||||
|
if factory:
|
||||||
|
break
|
||||||
|
|
||||||
# Check for partials to properly determine if coroutine function
|
# Check for partials to properly determine if coroutine function
|
||||||
check_factory = factory
|
check_factory = factory
|
||||||
while isinstance(check_factory, ft.partial):
|
while isinstance(check_factory, ft.partial):
|
||||||
@ -936,7 +954,7 @@ async def async_validate_condition_config(
|
|||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> ConfigType:
|
) -> ConfigType:
|
||||||
"""Validate config."""
|
"""Validate config."""
|
||||||
condition = config[CONF_CONDITION]
|
condition: str = config[CONF_CONDITION]
|
||||||
if condition in ("and", "not", "or"):
|
if condition in ("and", "not", "or"):
|
||||||
conditions = []
|
conditions = []
|
||||||
for sub_cond in config["conditions"]:
|
for sub_cond in config["conditions"]:
|
||||||
@ -947,7 +965,10 @@ async def async_validate_condition_config(
|
|||||||
|
|
||||||
platform = await _async_get_condition_platform(hass, config)
|
platform = await _async_get_condition_platform(hass, config)
|
||||||
if platform is not None:
|
if platform is not None:
|
||||||
return await platform.async_validate_condition_config(hass, config)
|
condition_descriptors = await platform.async_get_conditions(hass)
|
||||||
|
if not (condition_class := condition_descriptors.get(condition)):
|
||||||
|
raise vol.Invalid(f"Invalid condition '{condition}' specified")
|
||||||
|
return await condition_class.async_validate_condition_config(hass, config)
|
||||||
if platform is None and condition in ("numeric_state", "state"):
|
if platform is None and condition in ("numeric_state", "state"):
|
||||||
validator = cast(
|
validator = cast(
|
||||||
Callable[[HomeAssistant, ConfigType], ConfigType],
|
Callable[[HomeAssistant, ConfigType], ConfigType],
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
import pytest
|
import pytest
|
||||||
@ -26,9 +26,12 @@ from homeassistant.helpers import (
|
|||||||
trace,
|
trace,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.template import Template
|
from homeassistant.helpers.template import Template
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from tests.common import MockModule, mock_integration, mock_platform
|
||||||
|
|
||||||
|
|
||||||
def assert_element(trace_element, expected_element, path):
|
def assert_element(trace_element, expected_element, path):
|
||||||
"""Assert a trace element is as expected.
|
"""Assert a trace element is as expected.
|
||||||
@ -2251,15 +2254,78 @@ async def test_trigger(hass: HomeAssistant) -> None:
|
|||||||
assert test(hass, {"trigger": {"id": "123456"}})
|
assert test(hass, {"trigger": {"id": "123456"}})
|
||||||
|
|
||||||
|
|
||||||
async def test_platform_async_validate_condition_config(hass: HomeAssistant) -> None:
|
async def test_platform_async_get_conditions(hass: HomeAssistant) -> None:
|
||||||
"""Test platform.async_validate_condition_config will be called if it exists."""
|
"""Test platform.async_get_conditions will be called if it exists."""
|
||||||
config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"}
|
config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"}
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.device_automation.condition.async_validate_condition_config",
|
"homeassistant.components.device_automation.condition.async_get_conditions",
|
||||||
AsyncMock(),
|
AsyncMock(return_value={"device": AsyncMock()}),
|
||||||
) as device_automation_validate_condition_mock:
|
) as device_automation_async_get_conditions_mock:
|
||||||
await condition.async_validate_condition_config(hass, config)
|
await condition.async_validate_condition_config(hass, config)
|
||||||
device_automation_validate_condition_mock.assert_awaited()
|
device_automation_async_get_conditions_mock.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
|
||||||
|
"""Test a condition platform with multiple conditions."""
|
||||||
|
|
||||||
|
class MockCondition(condition.Condition):
|
||||||
|
"""Mock condition."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
|
"""Initialize condition."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_validate_condition_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
return config
|
||||||
|
|
||||||
|
class MockCondition1(MockCondition):
|
||||||
|
"""Mock condition 1."""
|
||||||
|
|
||||||
|
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
|
||||||
|
"""Evaluate state based on configuration."""
|
||||||
|
return lambda hass, vars: True
|
||||||
|
|
||||||
|
class MockCondition2(MockCondition):
|
||||||
|
"""Mock condition 2."""
|
||||||
|
|
||||||
|
async def async_condition_from_config(self) -> condition.ConditionCheckerType:
|
||||||
|
"""Evaluate state based on configuration."""
|
||||||
|
return lambda hass, vars: False
|
||||||
|
|
||||||
|
async def async_get_conditions(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> dict[str, type[condition.Condition]]:
|
||||||
|
return {
|
||||||
|
"test": MockCondition1,
|
||||||
|
"test.cond_2": MockCondition2,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule("test"))
|
||||||
|
mock_platform(
|
||||||
|
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
|
||||||
|
)
|
||||||
|
|
||||||
|
config_1 = {CONF_CONDITION: "test"}
|
||||||
|
config_2 = {CONF_CONDITION: "test.cond_2"}
|
||||||
|
config_3 = {CONF_CONDITION: "test.unknown_cond"}
|
||||||
|
assert await condition.async_validate_condition_config(hass, config_1) == config_1
|
||||||
|
assert await condition.async_validate_condition_config(hass, config_2) == config_2
|
||||||
|
with pytest.raises(
|
||||||
|
vol.Invalid, match="Invalid condition 'test.unknown_cond' specified"
|
||||||
|
):
|
||||||
|
await condition.async_validate_condition_config(hass, config_3)
|
||||||
|
|
||||||
|
cond_func = await condition.async_from_config(hass, config_1)
|
||||||
|
assert cond_func(hass, {}) is True
|
||||||
|
|
||||||
|
cond_func = await condition.async_from_config(hass, config_2)
|
||||||
|
assert cond_func(hass, {}) is False
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await condition.async_from_config(hass, config_3)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("enabled_value", [True, "{{ 1 == 1 }}"])
|
@pytest.mark.parametrize("enabled_value", [True, "{{ 1 == 1 }}"])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user