Allow conditions to be implemented in platforms (#88509)

* Allow conditions to be implemented in platforms

* Update tests

* Tweak typing

* Rebase fixes
This commit is contained in:
Erik Montnemery 2023-02-24 04:30:51 +01:00 committed by GitHub
parent 2f826a6f86
commit d90ee85118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 52 deletions

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.const import CONF_DOMAIN from homeassistant.const import CONF_DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ConditionProtocol, trace_condition_function
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import DeviceAutomationType, async_get_device_automation_platform from . import DeviceAutomationType, async_get_device_automation_platform
@ -17,24 +18,13 @@ if TYPE_CHECKING:
from homeassistant.helpers import condition from homeassistant.helpers import condition
class DeviceAutomationConditionProtocol(Protocol): class DeviceAutomationConditionProtocol(ConditionProtocol, Protocol):
"""Define the format of device_condition modules. """Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config. Each module must define either CONDITION_SCHEMA or async_validate_condition_config
from ConditionProtocol.
""" """
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
async def async_get_condition_capabilities( async def async_get_condition_capabilities(
self, hass: HomeAssistant, config: ConfigType self, hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]: ) -> dict[str, vol.Schema]:
@ -62,4 +52,4 @@ async def async_condition_from_config(
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
return platform.async_condition_from_config(hass, config) return trace_condition_function(platform.async_condition_from_config(hass, config))

View File

@ -7,15 +7,13 @@ from collections.abc import Callable, Container, Generator
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta from datetime import datetime, time as dt_time, timedelta
import functools as ft import functools as ft
import logging
import re import re
import sys import sys
from typing import Any, cast from typing import Any, Protocol, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import condition as device_condition
from homeassistant.components.sensor import SensorDeviceClass from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_CLASS, ATTR_DEVICE_CLASS,
@ -55,6 +53,7 @@ from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
TemplateError, TemplateError,
) )
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -77,12 +76,44 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config" FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config" VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_LOGGER = logging.getLogger(__name__) _PLATFORM_ALIASES = {
"and": None,
"device": "device_automation",
"not": None,
"numeric_state": None,
"or": None,
"state": None,
"sun": None,
"template": None,
"time": None,
"trigger": None,
"zone": None,
}
INPUT_ENTITY_ID = re.compile( INPUT_ENTITY_ID = re.compile(
r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$" r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$"
) )
class ConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Evaluate state based on configuration."""
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None] ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
@ -152,6 +183,27 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
return wrapper return wrapper
async def _async_get_condition_platform(
hass: HomeAssistant, config: ConfigType
) -> ConditionProtocol | None:
platform = config[CONF_CONDITION]
platform = _PLATFORM_ALIASES.get(platform, platform)
if platform is None:
return None
try:
integration = await async_get_integration(hass, platform)
except IntegrationNotFound:
raise HomeAssistantError(
f'Invalid condition "{platform}" specified {config}'
) from None
try:
return integration.get_platform("condition")
except ImportError:
raise HomeAssistantError(
f"Integration '{platform}' does not provide condition support"
) from None
async def async_from_config( async def async_from_config(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
@ -160,15 +212,18 @@ async def async_from_config(
Should be run on the event loop. Should be run on the event loop.
""" """
condition = config.get(CONF_CONDITION) factory: Any = None
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): platform = await _async_get_condition_platform(hass, config)
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory: if platform is None:
break condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory is None: if factory:
raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}') break
else:
factory = platform.async_condition_from_config
# Check if condition is not enabled # Check if condition is not enabled
if not config.get(CONF_ENABLED, True): if not config.get(CONF_ENABLED, True):
@ -928,14 +983,6 @@ def zone_from_config(config: ConfigType) -> ConditionCheckerType:
return if_in_zone return if_in_zone
async def async_device_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Test a device condition."""
checker = await device_condition.async_condition_from_config(hass, config)
return trace_condition_function(checker)
async def async_trigger_from_config( async def async_trigger_from_config(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType: ) -> ConditionCheckerType:
@ -991,10 +1038,10 @@ async def async_validate_condition_config(
config["conditions"] = conditions config["conditions"] = conditions
return config return config
if condition == "device": platform = await _async_get_condition_platform(hass, config)
return await device_condition.async_validate_condition_config(hass, config) if platform is not None and hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
if condition in ("numeric_state", "state"): if platform is None and condition in ("numeric_state", "state"):
validator = cast( validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType], Callable[[HomeAssistant, ConfigType], ConfigType],
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)), getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),

View File

@ -2406,13 +2406,7 @@ async def test_repeat_var_in_condition(hass: HomeAssistant, condition) -> None:
script_obj = script.Script( script_obj = script.Script(
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain" hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
) )
await script_obj.async_run(context=Context())
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run(context=Context())
assert len(events) == 2 assert len(events) == 2
if condition == "while": if condition == "while":
@ -2545,13 +2539,7 @@ async def test_repeat_nested(
] ]
) )
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
await script_obj.async_run(variables, Context())
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run(variables, Context())
assert len(events) == 10 assert len(events) == 10
assert events[0].data == first_last assert events[0].data == first_last
assert events[-1].data == first_last assert events[-1].data == first_last