Set scripts which fail validation unavailable (#95381)

This commit is contained in:
Erik Montnemery 2023-06-27 18:24:34 +02:00 committed by GitHub
parent 17ac1a6d32
commit 1fec407a24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 357 additions and 39 deletions

View File

@ -1,6 +1,7 @@
"""Support for scripts."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
import logging
@ -94,12 +95,12 @@ def _scripts_with_x(
if DOMAIN not in hass.data:
return []
component: EntityComponent[ScriptEntity] = hass.data[DOMAIN]
component: EntityComponent[BaseScriptEntity] = hass.data[DOMAIN]
return [
script_entity.entity_id
for script_entity in component.entities
if referenced_id in getattr(script_entity.script, property_name)
if referenced_id in getattr(script_entity, property_name)
]
@ -108,12 +109,12 @@ def _x_in_script(hass: HomeAssistant, entity_id: str, property_name: str) -> lis
if DOMAIN not in hass.data:
return []
component: EntityComponent[ScriptEntity] = hass.data[DOMAIN]
component: EntityComponent[BaseScriptEntity] = hass.data[DOMAIN]
if (script_entity := component.get_entity(entity_id)) is None:
return []
return list(getattr(script_entity.script, property_name))
return list(getattr(script_entity, property_name))
@callback
@ -158,7 +159,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str
if DOMAIN not in hass.data:
return []
component: EntityComponent[ScriptEntity] = hass.data[DOMAIN]
component: EntityComponent[BaseScriptEntity] = hass.data[DOMAIN]
return [
script_entity.entity_id
@ -173,7 +174,7 @@ def blueprint_in_script(hass: HomeAssistant, entity_id: str) -> str | None:
if DOMAIN not in hass.data:
return None
component: EntityComponent[ScriptEntity] = hass.data[DOMAIN]
component: EntityComponent[BaseScriptEntity] = hass.data[DOMAIN]
if (script_entity := component.get_entity(entity_id)) is None:
return None
@ -183,7 +184,9 @@ def blueprint_in_script(hass: HomeAssistant, entity_id: str) -> str | None:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent[ScriptEntity](LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = component = EntityComponent[BaseScriptEntity](
LOGGER, DOMAIN, hass
)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED
@ -260,6 +263,7 @@ class ScriptEntityConfig:
key: str
raw_blueprint_inputs: ConfigType | None
raw_config: ConfigType | None
validation_failed: bool
async def _prepare_script_config(
@ -274,9 +278,12 @@ async def _prepare_script_config(
for key, config_block in conf.items():
raw_config = cast(ScriptConfig, config_block).raw_config
raw_blueprint_inputs = cast(ScriptConfig, config_block).raw_blueprint_inputs
validation_failed = cast(ScriptConfig, config_block).validation_failed
script_configs.append(
ScriptEntityConfig(config_block, key, raw_blueprint_inputs, raw_config)
ScriptEntityConfig(
config_block, key, raw_blueprint_inputs, raw_config, validation_failed
)
)
return script_configs
@ -284,11 +291,20 @@ async def _prepare_script_config(
async def _create_script_entities(
hass: HomeAssistant, script_configs: list[ScriptEntityConfig]
) -> list[ScriptEntity]:
) -> list[BaseScriptEntity]:
"""Create script entities from prepared configuration."""
entities: list[ScriptEntity] = []
entities: list[BaseScriptEntity] = []
for script_config in script_configs:
if script_config.validation_failed:
entities.append(
UnavailableScriptEntity(
script_config.key,
script_config.raw_config,
)
)
continue
entity = ScriptEntity(
hass,
script_config.key,
@ -302,16 +318,20 @@ async def _create_script_entities(
async def _async_process_config(
hass: HomeAssistant, config: ConfigType, component: EntityComponent[ScriptEntity]
hass: HomeAssistant,
config: ConfigType,
component: EntityComponent[BaseScriptEntity],
) -> None:
"""Process script configuration."""
entities = []
def script_matches_config(script: ScriptEntity, config: ScriptEntityConfig) -> bool:
def script_matches_config(
script: BaseScriptEntity, config: ScriptEntityConfig
) -> bool:
return script.unique_id == config.key and script.raw_config == config.raw_config
def find_matches(
scripts: list[ScriptEntity],
scripts: list[BaseScriptEntity],
script_configs: list[ScriptEntityConfig],
) -> tuple[set[int], set[int]]:
"""Find matches between a list of script entities and a list of configurations.
@ -338,7 +358,7 @@ async def _async_process_config(
return script_matches, config_matches
script_configs = await _prepare_script_config(hass, config)
scripts: list[ScriptEntity] = list(component.entities)
scripts: list[BaseScriptEntity] = list(component.entities)
# Find scripts and configurations which have matches
script_matches, config_matches = find_matches(scripts, script_configs)
@ -359,7 +379,78 @@ async def _async_process_config(
await component.async_add_entities(entities)
class ScriptEntity(ToggleEntity, RestoreEntity):
class BaseScriptEntity(ToggleEntity, ABC):
"""Base class for script entities."""
raw_config: ConfigType | None
@property
@abstractmethod
def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas."""
@property
@abstractmethod
def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None."""
@property
@abstractmethod
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""
@property
@abstractmethod
def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities."""
class UnavailableScriptEntity(BaseScriptEntity):
"""A non-functional script entity with its state set to unavailable.
This class is instatiated when an script fails to validate.
"""
_attr_should_poll = False
_attr_available = False
def __init__(
self,
key: str,
raw_config: ConfigType | None,
) -> None:
"""Initialize a script entity."""
self._name = raw_config.get(CONF_ALIAS, key) if raw_config else key
self._attr_unique_id = key
self.raw_config = raw_config
@property
def name(self) -> str:
"""Return the name of the entity."""
return self._name
@property
def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas."""
return set()
@property
def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None."""
return None
@property
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""
return set()
@property
def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities."""
return set()
class ScriptEntity(BaseScriptEntity, RestoreEntity):
"""Representation of a script entity."""
icon = None
@ -421,6 +512,11 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
"""Return true if script is on."""
return self.script.is_running
@property
def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas."""
return self.script.referenced_areas
@property
def referenced_blueprint(self):
"""Return referenced blueprint or None."""
@ -428,6 +524,16 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
return None
return self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]
@property
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""
return self.script.referenced_devices
@property
def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities."""
return self.script.referenced_entities
@callback
def async_change_listener(self):
"""Update state."""
@ -544,7 +650,7 @@ def websocket_config(
msg: dict[str, Any],
) -> None:
"""Get script config."""
component: EntityComponent[ScriptEntity] = hass.data[DOMAIN]
component: EntityComponent[BaseScriptEntity] = hass.data[DOMAIN]
script = component.get_entity(msg["entity_id"])

View File

@ -49,6 +49,15 @@ from .helpers import async_get_blueprints
PACKAGE_MERGE_HINT = "dict"
_MINIMAL_SCRIPT_ENTITY_SCHEMA = vol.Schema(
{
CONF_ALIAS: cv.string,
vol.Optional(CONF_DESCRIPTION): cv.string,
},
extra=vol.ALLOW_EXTRA,
)
SCRIPT_ENTITY_SCHEMA = make_script_schema(
{
vol.Optional(CONF_ALIAS): cv.string,
@ -74,7 +83,11 @@ SCRIPT_ENTITY_SCHEMA = make_script_schema(
async def _async_validate_config_item(
hass: HomeAssistant, object_id: str, config: ConfigType, warn_on_errors: bool
hass: HomeAssistant,
object_id: str,
config: ConfigType,
raise_on_errors: bool,
warn_on_errors: bool,
) -> ScriptConfig:
"""Validate config item."""
raw_config = None
@ -110,6 +123,15 @@ async def _async_validate_config_item(
)
return
def _minimal_config() -> ScriptConfig:
"""Try validating id, alias and description."""
minimal_config = _MINIMAL_SCRIPT_ENTITY_SCHEMA(config)
script_config = ScriptConfig(minimal_config)
script_config.raw_blueprint_inputs = raw_blueprint_inputs
script_config.raw_config = raw_config
script_config.validation_failed = True
return script_config
if is_blueprint_instance_config(config):
uses_blueprint = True
blueprints = async_get_blueprints(hass)
@ -121,7 +143,9 @@ async def _async_validate_config_item(
"Failed to generate script from blueprint: %s",
err,
)
raise
if raise_on_errors:
raise
return _minimal_config()
raw_blueprint_inputs = blueprint_inputs.config_with_inputs
@ -136,7 +160,9 @@ async def _async_validate_config_item(
blueprint_inputs.inputs,
err,
)
raise HomeAssistantError from err
if raise_on_errors:
raise HomeAssistantError(err) from err
return _minimal_config()
script_name = f"Script with object id '{object_id}'"
if isinstance(config, Mapping):
@ -152,10 +178,16 @@ async def _async_validate_config_item(
validated_config = SCRIPT_ENTITY_SCHEMA(config)
except vol.Invalid as err:
_log_invalid_script(err, script_name, "could not be validated", config)
raise
if raise_on_errors:
raise
return _minimal_config()
script_config = ScriptConfig(validated_config)
script_config.raw_blueprint_inputs = raw_blueprint_inputs
script_config.raw_config = raw_config
try:
validated_config[CONF_SEQUENCE] = await async_validate_actions_config(
script_config[CONF_SEQUENCE] = await async_validate_actions_config(
hass, validated_config[CONF_SEQUENCE]
)
except (
@ -165,11 +197,11 @@ async def _async_validate_config_item(
_log_invalid_script(
err, script_name, "failed to setup actions", validated_config
)
raise
if raise_on_errors:
raise
script_config.validation_failed = True
return script_config
script_config = ScriptConfig(validated_config)
script_config.raw_blueprint_inputs = raw_blueprint_inputs
script_config.raw_config = raw_config
return script_config
@ -178,6 +210,7 @@ class ScriptConfig(dict):
raw_config: ConfigType | None = None
raw_blueprint_inputs: ConfigType | None = None
validation_failed: bool = False
async def _try_async_validate_config_item(
@ -187,7 +220,7 @@ async def _try_async_validate_config_item(
) -> ScriptConfig | None:
"""Validate config item."""
try:
return await _async_validate_config_item(hass, object_id, config, True)
return await _async_validate_config_item(hass, object_id, config, False, True)
except (vol.Invalid, HomeAssistantError):
return None
@ -198,7 +231,7 @@ async def async_validate_config_item(
config: dict[str, Any],
) -> ScriptConfig | None:
"""Validate config item, called by EditScriptConfigView."""
return await _async_validate_config_item(hass, object_id, config, False)
return await _async_validate_config_item(hass, object_id, config, True, False)
async def async_validate_config(hass, config):

View File

@ -1,14 +1,17 @@
"""Tests for config/script."""
from http import HTTPStatus
import json
from typing import Any
from unittest.mock import patch
import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import config
from homeassistant.const import STATE_OFF, STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.util import yaml
from tests.typing import ClientSessionGenerator
@ -67,7 +70,12 @@ async def test_update_script_config(
data=json.dumps({"alias": "Moon updated", "sequence": []}),
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("script")) == ["script.moon"]
assert sorted(hass.states.async_entity_ids("script")) == [
"script.moon",
"script.sun",
]
assert hass.states.get("script.moon").state == STATE_OFF
assert hass.states.get("script.sun").state == STATE_UNAVAILABLE
assert resp.status == HTTPStatus.OK
result = await resp.json()
@ -79,11 +87,39 @@ async def test_update_script_config(
@pytest.mark.parametrize("script_config", ({},))
@pytest.mark.parametrize(
("updated_config", "validation_error"),
[
({}, "required key not provided @ data['sequence']"),
(
{
"sequence": {
"condition": "state",
# The UUID will fail being resolved to en entity_id
"entity_id": "abcdabcdabcdabcdabcdabcdabcdabcd",
"state": "blah",
}
},
"Unknown entity registry entry abcdabcdabcdabcdabcdabcdabcdabcd",
),
(
{
"use_blueprint": {
"path": "test_service.yaml",
"input": {},
},
},
"Missing input service_to_call",
),
],
)
async def test_update_script_config_with_error(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_config_store,
caplog: pytest.LogCaptureFixture,
updated_config: Any,
validation_error: str,
) -> None:
"""Test updating script config with errors."""
with patch.object(config, "SECTIONS", ["script"]):
@ -98,14 +134,68 @@ async def test_update_script_config_with_error(
resp = await client.post(
"/api/config/script/config/moon",
data=json.dumps({}),
data=json.dumps(updated_config),
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("script")) == []
assert resp.status != HTTPStatus.OK
result = await resp.json()
validation_error = "required key not provided @ data['sequence']"
assert result == {"message": f"Message malformed: {validation_error}"}
# Assert the validation error is not logged
assert validation_error not in caplog.text
@pytest.mark.parametrize("script_config", ({},))
@pytest.mark.parametrize(
("updated_config", "validation_error"),
[
(
{
"use_blueprint": {
"path": "test_service.yaml",
"input": {
"service_to_call": "test.automation",
},
},
},
"No substitution found for input blah",
),
],
)
async def test_update_script_config_with_blueprint_substitution_error(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_config_store,
# setup_automation,
caplog: pytest.LogCaptureFixture,
updated_config: Any,
validation_error: str,
) -> None:
"""Test updating script config with errors."""
with patch.object(config, "SECTIONS", ["script"]):
await async_setup_component(hass, "config", {})
assert sorted(hass.states.async_entity_ids("script")) == []
client = await hass_client()
orig_data = {"sun": {}, "moon": {}}
hass_config_store["scripts.yaml"] = orig_data
with patch(
"homeassistant.components.blueprint.models.BlueprintInputs.async_substitute",
side_effect=yaml.UndefinedSubstitution("blah"),
):
resp = await client.post(
"/api/config/script/config/moon",
data=json.dumps(updated_config),
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("script")) == []
assert resp.status != HTTPStatus.OK
result = await resp.json()
assert result == {"message": f"Message malformed: {validation_error}"}
# Assert the validation error is not logged
assert validation_error not in caplog.text
@ -131,7 +221,12 @@ async def test_update_remove_key_script_config(
data=json.dumps({"sequence": []}),
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("script")) == ["script.moon"]
assert sorted(hass.states.async_entity_ids("script")) == [
"script.moon",
"script.sun",
]
assert hass.states.get("script.moon").state == STATE_OFF
assert hass.states.get("script.sun").state == STATE_UNAVAILABLE
assert resp.status == HTTPStatus.OK
result = await resp.json()

View File

@ -16,6 +16,7 @@ from homeassistant.const import (
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
STATE_OFF,
STATE_UNAVAILABLE,
)
from homeassistant.core import (
Context,
@ -158,14 +159,32 @@ invalid_configs = [
]
@pytest.mark.parametrize("value", invalid_configs)
async def test_setup_with_invalid_configs(hass: HomeAssistant, value) -> None:
@pytest.mark.parametrize(
("config", "nbr_script_entities"),
[
({"test": {}}, 1),
# Invalid slug, entity can't be set up
({"test hello world": {"sequence": [{"event": "bla"}]}}, 0),
(
{
"test": {
"sequence": {
"event": "test_event",
"service": "homeassistant.turn_on",
}
}
},
1,
),
],
)
async def test_setup_with_invalid_configs(
hass: HomeAssistant, config, nbr_script_entities
) -> None:
"""Test setup with invalid configs."""
assert await async_setup_component(
hass, "script", {"script": value}
), f"Script loaded with wrong config {value}"
assert await async_setup_component(hass, "script", {"script": config})
assert len(hass.states.async_entity_ids("script")) == 0
assert len(hass.states.async_entity_ids("script")) == nbr_script_entities
@pytest.mark.parametrize(
@ -177,6 +196,47 @@ async def test_setup_with_invalid_configs(hass: HomeAssistant, value) -> None:
"has invalid object id",
"invalid slug Bad Script",
),
),
)
async def test_bad_config_validation_critical(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
object_id,
broken_config,
problem,
details,
) -> None:
"""Test bad script configuration which can be detected during validation."""
assert await async_setup_component(
hass,
script.DOMAIN,
{
script.DOMAIN: {
object_id: {"alias": "bad_script", **broken_config},
"good_script": {
"alias": "good_script",
"sequence": {
"service": "test.automation",
"entity_id": "hello.world",
},
},
}
},
)
# Check we get the expected error message
assert (
f"Script with alias 'bad_script' {problem} and has been disabled: {details}"
in caplog.text
)
# Make sure one bad script does not prevent other scripts from setting up
assert hass.states.async_entity_ids("script") == ["script.good_script"]
@pytest.mark.parametrize(
("object_id", "broken_config", "problem", "details"),
(
(
"bad_script",
{},
@ -230,8 +290,13 @@ async def test_bad_config_validation(
in caplog.text
)
# Make sure one bad script does not prevent other scripts from setting up
assert hass.states.async_entity_ids("script") == ["script.good_script"]
# Make sure both scripts are setup
assert set(hass.states.async_entity_ids("script")) == {
"script.bad_script",
"script.good_script",
}
# The script failing validation should be unavailable
assert hass.states.get("script.bad_script").state == STATE_UNAVAILABLE
@pytest.mark.parametrize("running", ["no", "same", "different"])
@ -614,6 +679,25 @@ async def test_extraction_functions_unknown_script(hass: HomeAssistant) -> None:
assert script.entities_in_script(hass, "script.unknown") == []
async def test_extraction_functions_unavailable_script(hass: HomeAssistant) -> None:
"""Test extraction functions for an unknown automation."""
entity_id = "script.test1"
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {"test1": {}}},
)
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
assert script.scripts_with_area(hass, "area-in-both") == []
assert script.areas_in_script(hass, entity_id) == []
assert script.scripts_with_blueprint(hass, "blabla.yaml") == []
assert script.blueprint_in_script(hass, entity_id) is None
assert script.scripts_with_device(hass, "device-in-both") == []
assert script.devices_in_script(hass, entity_id) == []
assert script.scripts_with_entity(hass, "light.in_both") == []
assert script.entities_in_script(hass, entity_id) == []
async def test_extraction_functions(hass: HomeAssistant) -> None:
"""Test extraction functions."""
assert await async_setup_component(