Add async friendly helper for validating config schemas (#123800)

* Add async friendly helper for validating config schemas

* Improve docstrings

* Add tests
This commit is contained in:
Erik Montnemery 2024-08-17 11:01:49 +02:00 committed by GitHub
parent a7bca9bcea
commit 533442f33e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 161 additions and 7 deletions

View File

@ -1535,7 +1535,9 @@ async def async_process_component_config(
# No custom config validator, proceed with schema validation # No custom config validator, proceed with schema validation
if hasattr(component, "CONFIG_SCHEMA"): if hasattr(component, "CONFIG_SCHEMA"):
try: try:
return IntegrationConfigInfo(component.CONFIG_SCHEMA(config), []) return IntegrationConfigInfo(
await cv.async_validate(hass, component.CONFIG_SCHEMA, config), []
)
except vol.Invalid as exc: except vol.Invalid as exc:
exc_info = ConfigExceptionInfo( exc_info = ConfigExceptionInfo(
exc, exc,
@ -1570,7 +1572,9 @@ async def async_process_component_config(
# Validate component specific platform schema # Validate component specific platform schema
platform_path = f"{p_name}.{domain}" platform_path = f"{p_name}.{domain}"
try: try:
p_validated = component_platform_schema(p_config) p_validated = await cv.async_validate(
hass, component_platform_schema, p_config
)
except vol.Invalid as exc: except vol.Invalid as exc:
exc_info = ConfigExceptionInfo( exc_info = ConfigExceptionInfo(
exc, exc,

View File

@ -234,7 +234,7 @@ async def async_check_ha_config_file( # noqa: C901
config_schema = getattr(component, "CONFIG_SCHEMA", None) config_schema = getattr(component, "CONFIG_SCHEMA", None)
if config_schema is not None: if config_schema is not None:
try: try:
validated_config = config_schema(config) validated_config = await cv.async_validate(hass, config_schema, config)
# Don't fail if the validator removed the domain from the config # Don't fail if the validator removed the domain from the config
if domain in validated_config: if domain in validated_config:
result[domain] = validated_config[domain] result[domain] = validated_config[domain]
@ -255,7 +255,9 @@ async def async_check_ha_config_file( # noqa: C901
for p_name, p_config in config_per_platform(config, domain): for p_name, p_config in config_per_platform(config, domain):
# Validate component specific platform schema # Validate component specific platform schema
try: try:
p_validated = component_platform_schema(p_config) p_validated = await cv.async_validate(
hass, component_platform_schema, p_config
)
except vol.Invalid as ex: except vol.Invalid as ex:
_comp_error(ex, domain, p_config, p_config) _comp_error(ex, domain, p_config, p_config)
continue continue

View File

@ -6,6 +6,7 @@
from collections.abc import Callable, Hashable from collections.abc import Callable, Hashable
import contextlib import contextlib
from contextvars import ContextVar
from datetime import ( from datetime import (
date as date_sys, date as date_sys,
datetime as datetime_sys, datetime as datetime_sys,
@ -13,6 +14,7 @@ from datetime import (
timedelta, timedelta,
) )
from enum import Enum, StrEnum from enum import Enum, StrEnum
import functools
import logging import logging
from numbers import Number from numbers import Number
import os import os
@ -20,6 +22,7 @@ import re
from socket import ( # type: ignore[attr-defined] # private, not in typeshed from socket import ( # type: ignore[attr-defined] # private, not in typeshed
_GLOBAL_DEFAULT_TIMEOUT, _GLOBAL_DEFAULT_TIMEOUT,
) )
import threading
from typing import Any, cast, overload from typing import Any, cast, overload
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
@ -94,6 +97,7 @@ from homeassistant.const import (
) )
from homeassistant.core import ( from homeassistant.core import (
DOMAIN as HOMEASSISTANT_DOMAIN, DOMAIN as HOMEASSISTANT_DOMAIN,
HomeAssistant,
async_get_hass, async_get_hass,
async_get_hass_or_none, async_get_hass_or_none,
split_entity_id, split_entity_id,
@ -114,6 +118,51 @@ from .typing import VolDictType, VolSchemaType
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'" TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'"
class MustValidateInExecutor(HomeAssistantError):
"""Raised when validation must happen in an executor thread."""
class _Hass(threading.local):
"""Container which makes a HomeAssistant instance available to validators."""
hass: HomeAssistant | None = None
_hass = _Hass()
"""Set when doing async friendly schema validation."""
def _async_get_hass_or_none() -> HomeAssistant | None:
"""Return the HomeAssistant instance or None.
First tries core.async_get_hass_or_none, then _hass which is
set when doing async friendly schema validation.
"""
return async_get_hass_or_none() or _hass.hass
_validating_async: ContextVar[bool] = ContextVar("_validating_async", default=False)
"""Set to True when doing async friendly schema validation."""
def not_async_friendly[**_P, _R](validator: Callable[_P, _R]) -> Callable[_P, _R]:
"""Mark a validator as not async friendly.
This makes validation happen in an executor thread if validation is done by
async_validate, otherwise does nothing.
"""
@functools.wraps(validator)
def _not_async_friendly(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if _validating_async.get() and async_get_hass_or_none():
# Raise if doing async friendly validation and validation
# is happening in the event loop
raise MustValidateInExecutor
return validator(*args, **kwargs)
return _not_async_friendly
class UrlProtocolSchema(StrEnum): class UrlProtocolSchema(StrEnum):
"""Valid URL protocol schema values.""" """Valid URL protocol schema values."""
@ -217,6 +266,7 @@ def whitespace(value: Any) -> str:
raise vol.Invalid(f"contains non-whitespace: {value}") raise vol.Invalid(f"contains non-whitespace: {value}")
@not_async_friendly
def isdevice(value: Any) -> str: def isdevice(value: Any) -> str:
"""Validate that value is a real device.""" """Validate that value is a real device."""
try: try:
@ -258,6 +308,7 @@ def is_regex(value: Any) -> re.Pattern[Any]:
return r return r
@not_async_friendly
def isfile(value: Any) -> str: def isfile(value: Any) -> str:
"""Validate that the value is an existing file.""" """Validate that the value is an existing file."""
if value is None: if value is None:
@ -271,6 +322,7 @@ def isfile(value: Any) -> str:
return file_in return file_in
@not_async_friendly
def isdir(value: Any) -> str: def isdir(value: Any) -> str:
"""Validate that the value is an existing dir.""" """Validate that the value is an existing dir."""
if value is None: if value is None:
@ -664,7 +716,7 @@ def template(value: Any | None) -> template_helper.Template:
if isinstance(value, (list, dict, template_helper.Template)): if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string") raise vol.Invalid("template value should be a string")
template_value = template_helper.Template(str(value), async_get_hass_or_none()) template_value = template_helper.Template(str(value), _async_get_hass_or_none())
try: try:
template_value.ensure_valid() template_value.ensure_valid()
@ -682,7 +734,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
if not template_helper.is_template_string(str(value)): if not template_helper.is_template_string(str(value)):
raise vol.Invalid("template value does not contain a dynamic template") raise vol.Invalid("template value does not contain a dynamic template")
template_value = template_helper.Template(str(value), async_get_hass_or_none()) template_value = template_helper.Template(str(value), _async_get_hass_or_none())
try: try:
template_value.ensure_valid() template_value.ensure_valid()
@ -1918,3 +1970,32 @@ historic_currency = vol.In(
country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country") country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country")
language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language") language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language")
async def async_validate(
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
) -> Any:
"""Async friendly schema validation.
If a validator decorated with @not_async_friendly is called, validation will be
deferred to an executor. If not, validation will happen in the event loop.
"""
_validating_async.set(True)
try:
return validator(value)
except MustValidateInExecutor:
return await hass.async_add_executor_job(
_validate_in_executor, hass, validator, value
)
finally:
_validating_async.set(False)
def _validate_in_executor(
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
) -> Any:
_hass.hass = hass
try:
return validator(value)
finally:
_hass.hass = None

View File

@ -3,13 +3,16 @@
from collections import OrderedDict from collections import OrderedDict
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import enum import enum
from functools import partial
import logging import logging
import os import os
from socket import _GLOBAL_DEFAULT_TIMEOUT from socket import _GLOBAL_DEFAULT_TIMEOUT
import threading
from typing import Any from typing import Any
from unittest.mock import Mock, patch from unittest.mock import ANY, Mock, patch
import uuid import uuid
import py
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -1738,3 +1741,67 @@ def test_determine_script_action_ambiguous() -> None:
def test_determine_script_action_non_ambiguous() -> None: def test_determine_script_action_non_ambiguous() -> None:
"""Test determine script action with a non ambiguous action.""" """Test determine script action with a non ambiguous action."""
assert cv.determine_script_action({"delay": "00:00:05"}) == "delay" assert cv.determine_script_action({"delay": "00:00:05"}) == "delay"
async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> None:
"""Test the async_validate helper."""
validator_calls: dict[str, list[int]] = {}
def _mock_validator_schema(real_func, *args):
calls = validator_calls.setdefault(real_func.__name__, [])
calls.append(threading.get_ident())
return real_func(*args)
CV_PREFIX = "homeassistant.helpers.config_validation"
with (
patch(f"{CV_PREFIX}.isdir", wraps=partial(_mock_validator_schema, cv.isdir)),
patch(f"{CV_PREFIX}.string", wraps=partial(_mock_validator_schema, cv.string)),
):
# Assert validation in event loop when not decorated with not_async_friendly
await cv.async_validate(hass, cv.string, "abcd")
assert validator_calls == {"string": [hass.loop_thread_id]}
validator_calls = {}
# Assert validation in executor when decorated with not_async_friendly
await cv.async_validate(hass, cv.isdir, tmpdir)
assert validator_calls == {"isdir": [hass.loop_thread_id, ANY]}
assert validator_calls["isdir"][1] != hass.loop_thread_id
validator_calls = {}
# Assert validation in executor when decorated with not_async_friendly
await cv.async_validate(hass, vol.All(cv.isdir, cv.string), tmpdir)
assert validator_calls == {"isdir": [hass.loop_thread_id, ANY], "string": [ANY]}
assert validator_calls["isdir"][1] != hass.loop_thread_id
assert validator_calls["string"][0] != hass.loop_thread_id
validator_calls = {}
# Assert validation in executor when decorated with not_async_friendly
await cv.async_validate(hass, vol.All(cv.string, cv.isdir), tmpdir)
assert validator_calls == {
"isdir": [hass.loop_thread_id, ANY],
"string": [hass.loop_thread_id, ANY],
}
assert validator_calls["isdir"][1] != hass.loop_thread_id
assert validator_calls["string"][1] != hass.loop_thread_id
validator_calls = {}
# Assert validation in event loop when not using cv.async_validate
cv.isdir(tmpdir)
assert validator_calls == {"isdir": [hass.loop_thread_id]}
validator_calls = {}
# Assert validation in event loop when not using cv.async_validate
vol.All(cv.isdir, cv.string)(tmpdir)
assert validator_calls == {
"isdir": [hass.loop_thread_id],
"string": [hass.loop_thread_id],
}
validator_calls = {}
# Assert validation in event loop when not using cv.async_validate
vol.All(cv.string, cv.isdir)(tmpdir)
assert validator_calls == {
"isdir": [hass.loop_thread_id],
"string": [hass.loop_thread_id],
}
validator_calls = {}