mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Allow conditions to be implemented in platforms (#88509)
* Allow conditions to be implemented in platforms * Update tests * Tweak typing * Rebase fixes
This commit is contained in:
parent
2f826a6f86
commit
d90ee85118
@ -8,6 +8,7 @@ import voluptuous as vol
|
|||||||
from homeassistant.const import CONF_DOMAIN
|
from homeassistant.const import CONF_DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.condition import ConditionProtocol, trace_condition_function
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import DeviceAutomationType, async_get_device_automation_platform
|
from . import DeviceAutomationType, async_get_device_automation_platform
|
||||||
@ -17,24 +18,13 @@ if TYPE_CHECKING:
|
|||||||
from homeassistant.helpers import condition
|
from homeassistant.helpers import condition
|
||||||
|
|
||||||
|
|
||||||
class DeviceAutomationConditionProtocol(Protocol):
|
class DeviceAutomationConditionProtocol(ConditionProtocol, Protocol):
|
||||||
"""Define the format of device_condition modules.
|
"""Define the format of device_condition modules.
|
||||||
|
|
||||||
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
Each module must define either CONDITION_SCHEMA or async_validate_condition_config
|
||||||
|
from ConditionProtocol.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CONDITION_SCHEMA: vol.Schema
|
|
||||||
|
|
||||||
async def async_validate_condition_config(
|
|
||||||
self, hass: HomeAssistant, config: ConfigType
|
|
||||||
) -> ConfigType:
|
|
||||||
"""Validate config."""
|
|
||||||
|
|
||||||
def async_condition_from_config(
|
|
||||||
self, hass: HomeAssistant, config: ConfigType
|
|
||||||
) -> condition.ConditionCheckerType:
|
|
||||||
"""Evaluate state based on configuration."""
|
|
||||||
|
|
||||||
async def async_get_condition_capabilities(
|
async def async_get_condition_capabilities(
|
||||||
self, hass: HomeAssistant, config: ConfigType
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
) -> dict[str, vol.Schema]:
|
) -> dict[str, vol.Schema]:
|
||||||
@ -62,4 +52,4 @@ async def async_condition_from_config(
|
|||||||
platform = await async_get_device_automation_platform(
|
platform = await async_get_device_automation_platform(
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||||
)
|
)
|
||||||
return platform.async_condition_from_config(hass, config)
|
return trace_condition_function(platform.async_condition_from_config(hass, config))
|
||||||
|
@ -7,15 +7,13 @@ from collections.abc import Callable, Container, Generator
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime, time as dt_time, timedelta
|
from datetime import datetime, time as dt_time, timedelta
|
||||||
import functools as ft
|
import functools as ft
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, cast
|
from typing import Any, Protocol, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import zone as zone_cmp
|
from homeassistant.components import zone as zone_cmp
|
||||||
from homeassistant.components.device_automation import condition as device_condition
|
|
||||||
from homeassistant.components.sensor import SensorDeviceClass
|
from homeassistant.components.sensor import SensorDeviceClass
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DEVICE_CLASS,
|
ATTR_DEVICE_CLASS,
|
||||||
@ -55,6 +53,7 @@ from homeassistant.exceptions import (
|
|||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
TemplateError,
|
TemplateError,
|
||||||
)
|
)
|
||||||
|
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
@ -77,12 +76,44 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
|
|||||||
FROM_CONFIG_FORMAT = "{}_from_config"
|
FROM_CONFIG_FORMAT = "{}_from_config"
|
||||||
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
|
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_PLATFORM_ALIASES = {
|
||||||
|
"and": None,
|
||||||
|
"device": "device_automation",
|
||||||
|
"not": None,
|
||||||
|
"numeric_state": None,
|
||||||
|
"or": None,
|
||||||
|
"state": None,
|
||||||
|
"sun": None,
|
||||||
|
"template": None,
|
||||||
|
"time": None,
|
||||||
|
"trigger": None,
|
||||||
|
"zone": None,
|
||||||
|
}
|
||||||
|
|
||||||
INPUT_ENTITY_ID = re.compile(
|
INPUT_ENTITY_ID = re.compile(
|
||||||
r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$"
|
r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionProtocol(Protocol):
|
||||||
|
"""Define the format of device_condition modules.
|
||||||
|
|
||||||
|
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONDITION_SCHEMA: vol.Schema
|
||||||
|
|
||||||
|
async def async_validate_condition_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
|
||||||
|
def async_condition_from_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConditionCheckerType:
|
||||||
|
"""Evaluate state based on configuration."""
|
||||||
|
|
||||||
|
|
||||||
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
||||||
|
|
||||||
|
|
||||||
@ -152,6 +183,27 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_get_condition_platform(
|
||||||
|
hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConditionProtocol | None:
|
||||||
|
platform = config[CONF_CONDITION]
|
||||||
|
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||||
|
if platform is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
integration = await async_get_integration(hass, platform)
|
||||||
|
except IntegrationNotFound:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f'Invalid condition "{platform}" specified {config}'
|
||||||
|
) from None
|
||||||
|
try:
|
||||||
|
return integration.get_platform("condition")
|
||||||
|
except ImportError:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"Integration '{platform}' does not provide condition support"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
async def async_from_config(
|
async def async_from_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
@ -160,15 +212,18 @@ async def async_from_config(
|
|||||||
|
|
||||||
Should be run on the event loop.
|
Should be run on the event loop.
|
||||||
"""
|
"""
|
||||||
condition = config.get(CONF_CONDITION)
|
factory: Any = None
|
||||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
platform = await _async_get_condition_platform(hass, config)
|
||||||
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
|
||||||
|
|
||||||
if factory:
|
if platform is None:
|
||||||
break
|
condition = config.get(CONF_CONDITION)
|
||||||
|
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||||
|
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
||||||
|
|
||||||
if factory is None:
|
if factory:
|
||||||
raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}')
|
break
|
||||||
|
else:
|
||||||
|
factory = platform.async_condition_from_config
|
||||||
|
|
||||||
# Check if condition is not enabled
|
# Check if condition is not enabled
|
||||||
if not config.get(CONF_ENABLED, True):
|
if not config.get(CONF_ENABLED, True):
|
||||||
@ -928,14 +983,6 @@ def zone_from_config(config: ConfigType) -> ConditionCheckerType:
|
|||||||
return if_in_zone
|
return if_in_zone
|
||||||
|
|
||||||
|
|
||||||
async def async_device_from_config(
|
|
||||||
hass: HomeAssistant, config: ConfigType
|
|
||||||
) -> ConditionCheckerType:
|
|
||||||
"""Test a device condition."""
|
|
||||||
checker = await device_condition.async_condition_from_config(hass, config)
|
|
||||||
return trace_condition_function(checker)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_trigger_from_config(
|
async def async_trigger_from_config(
|
||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> ConditionCheckerType:
|
) -> ConditionCheckerType:
|
||||||
@ -991,10 +1038,10 @@ async def async_validate_condition_config(
|
|||||||
config["conditions"] = conditions
|
config["conditions"] = conditions
|
||||||
return config
|
return config
|
||||||
|
|
||||||
if condition == "device":
|
platform = await _async_get_condition_platform(hass, config)
|
||||||
return await device_condition.async_validate_condition_config(hass, config)
|
if platform is not None and hasattr(platform, "async_validate_condition_config"):
|
||||||
|
return await platform.async_validate_condition_config(hass, config)
|
||||||
if condition in ("numeric_state", "state"):
|
if platform is None and condition in ("numeric_state", "state"):
|
||||||
validator = cast(
|
validator = cast(
|
||||||
Callable[[HomeAssistant, ConfigType], ConfigType],
|
Callable[[HomeAssistant, ConfigType], ConfigType],
|
||||||
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
|
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
|
||||||
|
@ -2406,13 +2406,7 @@ async def test_repeat_var_in_condition(hass: HomeAssistant, condition) -> None:
|
|||||||
script_obj = script.Script(
|
script_obj = script.Script(
|
||||||
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
|
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
|
||||||
)
|
)
|
||||||
|
await script_obj.async_run(context=Context())
|
||||||
with mock.patch(
|
|
||||||
"homeassistant.helpers.condition._LOGGER.error",
|
|
||||||
side_effect=AssertionError("Template Error"),
|
|
||||||
):
|
|
||||||
await script_obj.async_run(context=Context())
|
|
||||||
|
|
||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
|
||||||
if condition == "while":
|
if condition == "while":
|
||||||
@ -2545,13 +2539,7 @@ async def test_repeat_nested(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
|
await script_obj.async_run(variables, Context())
|
||||||
with mock.patch(
|
|
||||||
"homeassistant.helpers.condition._LOGGER.error",
|
|
||||||
side_effect=AssertionError("Template Error"),
|
|
||||||
):
|
|
||||||
await script_obj.async_run(variables, Context())
|
|
||||||
|
|
||||||
assert len(events) == 10
|
assert len(events) == 10
|
||||||
assert events[0].data == first_last
|
assert events[0].data == first_last
|
||||||
assert events[-1].data == first_last
|
assert events[-1].data == first_last
|
||||||
|
Loading…
x
Reference in New Issue
Block a user