mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Improve validation of entity service schemas (#124102)
* Improve validation of entity service schemas * Update tests/helpers/test_entity_platform.py Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> --------- Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
parent
0dc1eb8757
commit
55c42fde88
@ -1305,9 +1305,28 @@ TARGET_SERVICE_FIELDS = {
|
||||
_HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS)
|
||||
|
||||
|
||||
def is_entity_service_schema(validator: VolSchemaType) -> bool:
|
||||
"""Check if the passed validator is an entity schema validator.
|
||||
|
||||
The validator must be either of:
|
||||
- A validator returned by cv._make_entity_service_schema
|
||||
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema
|
||||
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.All
|
||||
Nesting is allowed.
|
||||
"""
|
||||
if hasattr(validator, "_entity_service_schema"):
|
||||
return True
|
||||
if isinstance(validator, (vol.All)):
|
||||
return any(is_entity_service_schema(val) for val in validator.validators)
|
||||
if isinstance(validator, (vol.Schema)):
|
||||
return is_entity_service_schema(validator.schema)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
|
||||
"""Create an entity service schema."""
|
||||
return vol.All(
|
||||
validator = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
# The frontend stores data here. Don't use in core.
|
||||
@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
|
||||
),
|
||||
_HAS_ENTITY_SERVICE_FIELD,
|
||||
)
|
||||
setattr(validator, "_entity_service_schema", True)
|
||||
return validator
|
||||
|
||||
|
||||
BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA)
|
||||
|
@ -1267,17 +1267,8 @@ def async_register_entity_service(
|
||||
# Do a sanity check to check this is a valid entity service schema,
|
||||
# the check could be extended to require All/Any to have sub schema(s)
|
||||
# with all entity service fields
|
||||
elif (
|
||||
# Don't check All/Any
|
||||
not isinstance(schema, (vol.All, vol.Any))
|
||||
# Don't check All/Any wrapped in schema
|
||||
and not isinstance(schema.schema, (vol.All, vol.Any))
|
||||
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
"The schema does not include all required keys: "
|
||||
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
||||
)
|
||||
elif not cv.is_entity_service_schema(schema):
|
||||
raise HomeAssistantError("The schema is not an entity service schema")
|
||||
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
@ -1805,3 +1805,27 @@ async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> Non
|
||||
"string": [hass.loop_thread_id],
|
||||
}
|
||||
validator_calls = {}
|
||||
|
||||
|
||||
async def test_is_entity_service_schema(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test cv.is_entity_service_schema."""
|
||||
for schema in (
|
||||
vol.Schema({"some": str}),
|
||||
vol.All(vol.Schema({"some": str})),
|
||||
vol.Any(vol.Schema({"some": str})),
|
||||
vol.Any(cv.make_entity_service_schema({"some": str})),
|
||||
):
|
||||
assert cv.is_entity_service_schema(schema) is False
|
||||
|
||||
for schema in (
|
||||
cv.make_entity_service_schema({"some": str}),
|
||||
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||
vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))),
|
||||
vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))),
|
||||
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||
vol.All(vol.All(cv.make_entity_service_schema({"some": str}))),
|
||||
vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))),
|
||||
):
|
||||
assert cv.is_entity_service_schema(schema) is True
|
||||
|
@ -23,7 +23,7 @@ from homeassistant.core import (
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.helpers import discovery
|
||||
from homeassistant.helpers import config_validation as cv, discovery
|
||||
from homeassistant.helpers.entity_component import EntityComponent, async_update_entity
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
@ -559,28 +559,28 @@ async def test_register_entity_service(
|
||||
async def test_register_entity_service_non_entity_service_schema(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test attempting to register a service with an incomplete schema."""
|
||||
"""Test attempting to register a service with a non entity service schema."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match=(
|
||||
"The schema does not include all required keys: entity_id, device_id, area_id, "
|
||||
"floor_id, label_id"
|
||||
),
|
||||
for schema in (
|
||||
vol.Schema({"some": str}),
|
||||
vol.All(vol.Schema({"some": str})),
|
||||
vol.Any(vol.Schema({"some": str})),
|
||||
):
|
||||
component.async_register_entity_service(
|
||||
"hello", vol.Schema({"some": str}), Mock()
|
||||
)
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match=("The schema is not an entity service schema"),
|
||||
):
|
||||
component.async_register_entity_service("hello", schema, Mock())
|
||||
|
||||
# The check currently does not recurse into vol.All or vol.Any allowing these
|
||||
# non-compliant schemas to pass
|
||||
component.async_register_entity_service(
|
||||
"hello", vol.All(vol.Schema({"some": str})), Mock()
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
"hello", vol.Any(vol.Schema({"some": str})), Mock()
|
||||
)
|
||||
for idx, schema in enumerate(
|
||||
(
|
||||
cv.make_entity_service_schema({"some": str}),
|
||||
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||
)
|
||||
):
|
||||
component.async_register_entity_service(f"test_service_{idx}", schema, Mock())
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
||||
|
@ -23,6 +23,7 @@ from homeassistant.core import (
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_platform,
|
||||
entity_registry as er,
|
||||
@ -1812,31 +1813,32 @@ async def test_register_entity_service_none_schema(
|
||||
async def test_register_entity_service_non_entity_service_schema(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test attempting to register a service with an incomplete schema."""
|
||||
"""Test attempting to register a service with a non entity service schema."""
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match=(
|
||||
"The schema does not include all required keys: entity_id, device_id, area_id, "
|
||||
"floor_id, label_id"
|
||||
),
|
||||
for schema in (
|
||||
vol.Schema({"some": str}),
|
||||
vol.All(vol.Schema({"some": str})),
|
||||
vol.Any(vol.Schema({"some": str})),
|
||||
):
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="The schema is not an entity service schema",
|
||||
):
|
||||
entity_platform.async_register_entity_service("hello", schema, Mock())
|
||||
|
||||
for idx, schema in enumerate(
|
||||
(
|
||||
cv.make_entity_service_schema({"some": str}),
|
||||
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||
)
|
||||
):
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello",
|
||||
vol.Schema({"some": str}),
|
||||
Mock(),
|
||||
f"test_service_{idx}", schema, Mock()
|
||||
)
|
||||
# The check currently does not recurse into vol.All or vol.Any allowing these
|
||||
# non-compliant schemas to pass
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello", vol.All(vol.Schema({"some": str})), Mock()
|
||||
)
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello", vol.Any(vol.Schema({"some": str})), Mock()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_before_add", [True, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user