Add a new validate config WS command (#67057)

This commit is contained in:
Paulus Schoutsen 2022-02-22 13:28:37 -08:00 committed by GitHub
parent c2e62e4d9f
commit 756e711850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 37 deletions

View File

@ -63,6 +63,7 @@ def async_register_commands(
async_reg(hass, handle_subscribe_trigger)
async_reg(hass, handle_test_condition)
async_reg(hass, handle_unsubscribe_events)
async_reg(hass, handle_validate_config)
def pong_message(iden: int) -> dict[str, Any]:
@ -116,7 +117,7 @@ def handle_subscribe_events(
event_type, forward_events
)
connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])
@callback
@ -139,7 +140,7 @@ def handle_subscribe_bootstrap_integrations(
hass, SIGNAL_BOOTSTRAP_INTEGRATONS, forward_bootstrap_integrations
)
connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])
@callback
@ -157,13 +158,9 @@ def handle_unsubscribe_events(
if subscription in connection.subscriptions:
connection.subscriptions.pop(subscription)()
connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])
else:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_NOT_FOUND, "Subscription not found."
)
)
connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Subscription not found.")
@decorators.websocket_command(
@ -196,36 +193,20 @@ async def handle_call_service(
context,
target=target,
)
connection.send_message(
messages.result_message(msg["id"], {"context": context})
)
connection.send_result(msg["id"], {"context": context})
except ServiceNotFound as err:
if err.domain == msg["domain"] and err.service == msg["service"]:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_NOT_FOUND, "Service not found."
)
)
connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Service not found.")
else:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)
)
)
connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
except vol.Invalid as err:
connection.send_message(
messages.error_message(msg["id"], const.ERR_INVALID_FORMAT, str(err))
)
connection.send_error(msg["id"], const.ERR_INVALID_FORMAT, str(err))
except HomeAssistantError as err:
connection.logger.exception(err)
connection.send_message(
messages.error_message(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
)
connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
except Exception as err: # pylint: disable=broad-except
connection.logger.exception(err)
connection.send_message(
messages.error_message(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))
)
connection.send_error(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))
@callback
@ -244,7 +225,7 @@ def handle_get_states(
if entity_perm(state.entity_id, "read")
]
connection.send_message(messages.result_message(msg["id"], states))
connection.send_result(msg["id"], states)
@decorators.websocket_command({vol.Required("type"): "get_services"})
@ -254,7 +235,7 @@ async def handle_get_services(
) -> None:
"""Handle get services command."""
descriptions = await async_get_all_descriptions(hass)
connection.send_message(messages.result_message(msg["id"], descriptions))
connection.send_result(msg["id"], descriptions)
@callback
@ -263,7 +244,7 @@ def handle_get_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get config command."""
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
connection.send_result(msg["id"], hass.config.as_dict())
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
@ -417,7 +398,7 @@ def handle_entity_source(
if entity_perm(entity_id, "read")
}
connection.send_message(messages.result_message(msg["id"], sources))
connection.send_result(msg["id"], sources)
return
sources = {}
@ -535,7 +516,7 @@ async def handle_execute_script(
context = connection.context(msg)
script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN)
await script_obj.async_run(msg.get("variables"), context=context)
connection.send_message(messages.result_message(msg["id"], {"context": context}))
connection.send_result(msg["id"], {"context": context})
@decorators.websocket_command(
@ -555,3 +536,40 @@ async def handle_fire_event(
hass.bus.async_fire(msg["event_type"], msg.get("event_data"), context=context)
connection.send_result(msg["id"], {"context": context})
@decorators.websocket_command(
{
vol.Required("type"): "validate_config",
vol.Optional("trigger"): cv.match_all,
vol.Optional("condition"): cv.match_all,
vol.Optional("action"): cv.match_all,
}
)
@decorators.async_response
async def handle_validate_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle validate config command."""
# Circular dep
# pylint: disable=import-outside-toplevel
from homeassistant.helpers import condition, script, trigger
result = {}
for key, schema, validator in (
("trigger", cv.TRIGGER_SCHEMA, trigger.async_validate_trigger_config),
("condition", cv.CONDITION_SCHEMA, condition.async_validate_condition_config),
("action", cv.SCRIPT_SCHEMA, script.async_validate_actions_config),
):
if key not in msg:
continue
try:
await validator(hass, schema(msg[key])) # type: ignore
except vol.Invalid as err:
result[key] = {"valid": False, "error": str(err)}
else:
result[key] = {"valid": True, "error": None}
connection.send_result(msg["id"], result)

View File

@ -1033,7 +1033,12 @@ def script_action(value: Any) -> dict:
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
try:
action = determine_script_action(value)
except ValueError as err:
raise vol.Invalid(str(err))
return ACTION_TYPE_SCHEMAS[action](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
@ -1444,7 +1449,10 @@ def determine_script_action(action: dict[str, Any]) -> str:
if CONF_VARIABLES in action:
return SCRIPT_ACTION_VARIABLES
return SCRIPT_ACTION_CALL_SERVICE
if CONF_SERVICE in action or CONF_SERVICE_TEMPLATE in action:
return SCRIPT_ACTION_CALL_SERVICE
raise ValueError("Unable to determine action")
ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {

View File

@ -1286,3 +1286,60 @@ async def test_integration_setup_info(hass, websocket_client, hass_admin_user):
{"domain": "august", "seconds": 12.5},
{"domain": "isy994", "seconds": 12.8},
]
@pytest.mark.parametrize(
"key,config",
(
("trigger", {"platform": "event", "event_type": "hello"}),
(
"condition",
{"condition": "state", "entity_id": "hello.world", "state": "paulus"},
),
("action", {"service": "domain_test.test_service"}),
),
)
async def test_validate_config_works(websocket_client, key, config):
"""Test config validation."""
await websocket_client.send_json({"id": 7, "type": "validate_config", key: config})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {key: {"valid": True, "error": None}}
@pytest.mark.parametrize(
"key,config,error",
(
(
"trigger",
{"platform": "non_existing", "event_type": "hello"},
"Invalid platform 'non_existing' specified",
),
(
"condition",
{
"condition": "non_existing",
"entity_id": "hello.world",
"state": "paulus",
},
"Unexpected value for condition: 'non_existing'. Expected and, device, not, numeric_state, or, state, sun, template, time, trigger, zone",
),
(
"action",
{"non_existing": "domain_test.test_service"},
"Unable to determine action @ data[0]",
),
),
)
async def test_validate_config_invalid(websocket_client, key, config, error):
"""Test config validation."""
await websocket_client.send_json({"id": 7, "type": "validate_config", key: config})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {key: {"valid": False, "error": error}}