Teach state and numeric_state conditions about entity registry ids (#60841)

This commit is contained in:
Erik Montnemery
2021-12-02 23:55:12 +01:00
committed by GitHub
parent a07f75c6b0
commit 0e3bc21d54
23 changed files with 236 additions and 31 deletions

View File

@@ -51,7 +51,7 @@ from homeassistant.exceptions import (
HomeAssistantError,
TemplateError,
)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.sun import get_astral_event_date
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
@@ -71,8 +71,9 @@ from .trace import (
# mypy: disallow-any-generics
FROM_CONFIG_FORMAT = "{}_from_config"
ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_LOGGER = logging.getLogger(__name__)
@@ -885,7 +886,7 @@ async def async_device_from_config(
return trace_condition_function(
cast(
ConditionCheckerType,
platform.async_condition_from_config(config), # type: ignore
platform.async_condition_from_config(hass, config), # type: ignore
)
)
@@ -908,6 +909,30 @@ async def async_trigger_from_config(
return trigger_if
def numeric_state_validate_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate numeric_state condition config."""
registry = er.async_get(hass)
config = dict(config)
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
)
return config
def state_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType:
"""Validate state condition config."""
registry = er.async_get(hass)
config = dict(config)
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
)
return config
async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType | Template
) -> ConfigType | Template:
@@ -933,6 +958,12 @@ async def async_validate_condition_config(
return await platform.async_validate_condition_config(hass, config) # type: ignore
return cast(ConfigType, platform.CONDITION_SCHEMA(config)) # type: ignore
if condition in ("numeric_state", "state"):
validator = getattr(
sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)
)
return validator(hass, config) # type: ignore
return config