Add support for condition platforms to provide multiple conditions (#147376)

This commit is contained in:
Erik Montnemery
2025-06-25 19:10:30 +02:00
committed by GitHub
parent 2b5f5f641d
commit 26e3caea9a
4 changed files with 189 additions and 64 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import abc
import asyncio
from collections import deque
from collections.abc import Callable, Container, Generator
@@ -75,7 +76,7 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_PLATFORM_ALIASES = {
_PLATFORM_ALIASES: dict[str | None, str | None] = {
"and": None,
"device": "device_automation",
"not": None,
@@ -93,20 +94,33 @@ INPUT_ENTITY_ID = re.compile(
)
class ConditionProtocol(Protocol):
"""Define the format of condition modules."""
class Condition(abc.ABC):
"""Condition class."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize condition."""
@classmethod
@abc.abstractmethod
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
@abc.abstractmethod
async def async_condition_from_config(self) -> ConditionCheckerType:
"""Evaluate state based on configuration."""
class ConditionProtocol(Protocol):
"""Define the format of condition modules."""
async def async_get_conditions(
self, hass: HomeAssistant
) -> dict[str, type[Condition]]:
"""Return the conditions provided by this integration."""
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
@@ -179,7 +193,9 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
async def _async_get_condition_platform(
hass: HomeAssistant, config: ConfigType
) -> ConditionProtocol | None:
platform = config[CONF_CONDITION]
condition_key: str = config[CONF_CONDITION]
platform_and_sub_type = condition_key.partition(".")
platform: str | None = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform)
if platform is None:
return None
@@ -187,7 +203,7 @@ async def _async_get_condition_platform(
integration = await async_get_integration(hass, platform)
except IntegrationNotFound:
raise HomeAssistantError(
f'Invalid condition "{platform}" specified {config}'
f'Invalid condition "{condition_key}" specified {config}'
) from None
try:
return await integration.async_get_platform("condition")
@@ -205,19 +221,6 @@ async def async_from_config(
Should be run on the event loop.
"""
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is None:
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
else:
factory = platform.async_condition_from_config
# Check if condition is not enabled
if CONF_ENABLED in config:
enabled = config[CONF_ENABLED]
@@ -239,6 +242,21 @@ async def async_from_config(
return disabled_condition
condition: str = config[CONF_CONDITION]
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is not None:
condition_descriptors = await platform.async_get_conditions(hass)
condition_instance = condition_descriptors[condition](hass, config)
return await condition_instance.async_condition_from_config()
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
# Check for partials to properly determine if coroutine function
check_factory = factory
while isinstance(check_factory, ft.partial):
@@ -936,7 +954,7 @@ async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
condition = config[CONF_CONDITION]
condition: str = config[CONF_CONDITION]
if condition in ("and", "not", "or"):
conditions = []
for sub_cond in config["conditions"]:
@@ -947,7 +965,10 @@ async def async_validate_condition_config(
platform = await _async_get_condition_platform(hass, config)
if platform is not None:
return await platform.async_validate_condition_config(hass, config)
condition_descriptors = await platform.async_get_conditions(hass)
if not (condition_class := condition_descriptors.get(condition)):
raise vol.Invalid(f"Invalid condition '{condition}' specified")
return await condition_class.async_validate_condition_config(hass, config)
if platform is None and condition in ("numeric_state", "state"):
validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType],