Add async friendly helper for validating config schemas (#123800)

* Add async friendly helper for validating config schemas

* Improve docstrings

* Add tests
This commit is contained in:
Erik Montnemery
2024-08-17 11:01:49 +02:00
committed by GitHub
parent a7bca9bcea
commit 533442f33e
4 changed files with 161 additions and 7 deletions

View File

@@ -6,6 +6,7 @@
from collections.abc import Callable, Hashable
import contextlib
from contextvars import ContextVar
from datetime import (
date as date_sys,
datetime as datetime_sys,
@@ -13,6 +14,7 @@ from datetime import (
timedelta,
)
from enum import Enum, StrEnum
import functools
import logging
from numbers import Number
import os
@@ -20,6 +22,7 @@ import re
from socket import ( # type: ignore[attr-defined] # private, not in typeshed
_GLOBAL_DEFAULT_TIMEOUT,
)
import threading
from typing import Any, cast, overload
from urllib.parse import urlparse
from uuid import UUID
@@ -94,6 +97,7 @@ from homeassistant.const import (
)
from homeassistant.core import (
DOMAIN as HOMEASSISTANT_DOMAIN,
HomeAssistant,
async_get_hass,
async_get_hass_or_none,
split_entity_id,
@@ -114,6 +118,51 @@ from .typing import VolDictType, VolSchemaType
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'"
class MustValidateInExecutor(HomeAssistantError):
"""Raised when validation must happen in an executor thread."""
class _Hass(threading.local):
"""Container which makes a HomeAssistant instance available to validators."""
hass: HomeAssistant | None = None
_hass = _Hass()
"""Set when doing async friendly schema validation."""
def _async_get_hass_or_none() -> HomeAssistant | None:
"""Return the HomeAssistant instance or None.
First tries core.async_get_hass_or_none, then _hass which is
set when doing async friendly schema validation.
"""
return async_get_hass_or_none() or _hass.hass
_validating_async: ContextVar[bool] = ContextVar("_validating_async", default=False)
"""Set to True when doing async friendly schema validation."""
def not_async_friendly[**_P, _R](validator: Callable[_P, _R]) -> Callable[_P, _R]:
"""Mark a validator as not async friendly.
This makes validation happen in an executor thread if validation is done by
async_validate, otherwise does nothing.
"""
@functools.wraps(validator)
def _not_async_friendly(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if _validating_async.get() and async_get_hass_or_none():
# Raise if doing async friendly validation and validation
# is happening in the event loop
raise MustValidateInExecutor
return validator(*args, **kwargs)
return _not_async_friendly
class UrlProtocolSchema(StrEnum):
"""Valid URL protocol schema values."""
@@ -217,6 +266,7 @@ def whitespace(value: Any) -> str:
raise vol.Invalid(f"contains non-whitespace: {value}")
@not_async_friendly
def isdevice(value: Any) -> str:
"""Validate that value is a real device."""
try:
@@ -258,6 +308,7 @@ def is_regex(value: Any) -> re.Pattern[Any]:
return r
@not_async_friendly
def isfile(value: Any) -> str:
"""Validate that the value is an existing file."""
if value is None:
@@ -271,6 +322,7 @@ def isfile(value: Any) -> str:
return file_in
@not_async_friendly
def isdir(value: Any) -> str:
"""Validate that the value is an existing dir."""
if value is None:
@@ -664,7 +716,7 @@ def template(value: Any | None) -> template_helper.Template:
if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string")
template_value = template_helper.Template(str(value), async_get_hass_or_none())
template_value = template_helper.Template(str(value), _async_get_hass_or_none())
try:
template_value.ensure_valid()
@@ -682,7 +734,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
if not template_helper.is_template_string(str(value)):
raise vol.Invalid("template value does not contain a dynamic template")
template_value = template_helper.Template(str(value), async_get_hass_or_none())
template_value = template_helper.Template(str(value), _async_get_hass_or_none())
try:
template_value.ensure_valid()
@@ -1918,3 +1970,32 @@ historic_currency = vol.In(
country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country")
language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language")
async def async_validate(
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
) -> Any:
"""Async friendly schema validation.
If a validator decorated with @not_async_friendly is called, validation will be
deferred to an executor. If not, validation will happen in the event loop.
"""
_validating_async.set(True)
try:
return validator(value)
except MustValidateInExecutor:
return await hass.async_add_executor_job(
_validate_in_executor, hass, validator, value
)
finally:
_validating_async.set(False)
def _validate_in_executor(
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
) -> Any:
_hass.hass = hass
try:
return validator(value)
finally:
_hass.hass = None