From 756e7118507388ae6a53c3c5eb588cf8aabf168e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 22 Feb 2022 13:28:37 -0800 Subject: [PATCH] Add a new validate config WS command (#67057) --- .../components/websocket_api/commands.py | 88 +++++++++++-------- homeassistant/helpers/config_validation.py | 12 ++- .../components/websocket_api/test_commands.py | 57 ++++++++++++ 3 files changed, 120 insertions(+), 37 deletions(-) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 0b7e355ef24..650013dda7f 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -63,6 +63,7 @@ def async_register_commands( async_reg(hass, handle_subscribe_trigger) async_reg(hass, handle_test_condition) async_reg(hass, handle_unsubscribe_events) + async_reg(hass, handle_validate_config) def pong_message(iden: int) -> dict[str, Any]: @@ -116,7 +117,7 @@ def handle_subscribe_events( event_type, forward_events ) - connection.send_message(messages.result_message(msg["id"])) + connection.send_result(msg["id"]) @callback @@ -139,7 +140,7 @@ def handle_subscribe_bootstrap_integrations( hass, SIGNAL_BOOTSTRAP_INTEGRATONS, forward_bootstrap_integrations ) - connection.send_message(messages.result_message(msg["id"])) + connection.send_result(msg["id"]) @callback @@ -157,13 +158,9 @@ def handle_unsubscribe_events( if subscription in connection.subscriptions: connection.subscriptions.pop(subscription)() - connection.send_message(messages.result_message(msg["id"])) + connection.send_result(msg["id"]) else: - connection.send_message( - messages.error_message( - msg["id"], const.ERR_NOT_FOUND, "Subscription not found." - ) - ) + connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Subscription not found.") @decorators.websocket_command( @@ -196,36 +193,20 @@ async def handle_call_service( context, target=target, ) - connection.send_message( - messages.result_message(msg["id"], {"context": context}) - ) + connection.send_result(msg["id"], {"context": context}) except ServiceNotFound as err: if err.domain == msg["domain"] and err.service == msg["service"]: - connection.send_message( - messages.error_message( - msg["id"], const.ERR_NOT_FOUND, "Service not found." - ) - ) + connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Service not found.") else: - connection.send_message( - messages.error_message( - msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err) - ) - ) + connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)) except vol.Invalid as err: - connection.send_message( - messages.error_message(msg["id"], const.ERR_INVALID_FORMAT, str(err)) - ) + connection.send_error(msg["id"], const.ERR_INVALID_FORMAT, str(err)) except HomeAssistantError as err: connection.logger.exception(err) - connection.send_message( - messages.error_message(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)) - ) + connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)) except Exception as err: # pylint: disable=broad-except connection.logger.exception(err) - connection.send_message( - messages.error_message(msg["id"], const.ERR_UNKNOWN_ERROR, str(err)) - ) + connection.send_error(msg["id"], const.ERR_UNKNOWN_ERROR, str(err)) @callback @@ -244,7 +225,7 @@ def handle_get_states( if entity_perm(state.entity_id, "read") ] - connection.send_message(messages.result_message(msg["id"], states)) + connection.send_result(msg["id"], states) @decorators.websocket_command({vol.Required("type"): "get_services"}) @@ -254,7 +235,7 @@ async def handle_get_services( ) -> None: """Handle get services command.""" descriptions = await async_get_all_descriptions(hass) - connection.send_message(messages.result_message(msg["id"], descriptions)) + connection.send_result(msg["id"], descriptions) @callback @@ -263,7 +244,7 @@ def handle_get_config( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle get config command.""" - connection.send_message(messages.result_message(msg["id"], hass.config.as_dict())) + connection.send_result(msg["id"], hass.config.as_dict()) @decorators.websocket_command({vol.Required("type"): "manifest/list"}) @@ -417,7 +398,7 @@ def handle_entity_source( if entity_perm(entity_id, "read") } - connection.send_message(messages.result_message(msg["id"], sources)) + connection.send_result(msg["id"], sources) return sources = {} @@ -535,7 +516,7 @@ async def handle_execute_script( context = connection.context(msg) script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN) await script_obj.async_run(msg.get("variables"), context=context) - connection.send_message(messages.result_message(msg["id"], {"context": context})) + connection.send_result(msg["id"], {"context": context}) @decorators.websocket_command( @@ -555,3 +536,40 @@ async def handle_fire_event( hass.bus.async_fire(msg["event_type"], msg.get("event_data"), context=context) connection.send_result(msg["id"], {"context": context}) + + +@decorators.websocket_command( + { + vol.Required("type"): "validate_config", + vol.Optional("trigger"): cv.match_all, + vol.Optional("condition"): cv.match_all, + vol.Optional("action"): cv.match_all, + } +) +@decorators.async_response +async def handle_validate_config( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle validate config command.""" + # Circular dep + # pylint: disable=import-outside-toplevel + from homeassistant.helpers import condition, script, trigger + + result = {} + + for key, schema, validator in ( + ("trigger", cv.TRIGGER_SCHEMA, trigger.async_validate_trigger_config), + ("condition", cv.CONDITION_SCHEMA, condition.async_validate_condition_config), + ("action", cv.SCRIPT_SCHEMA, script.async_validate_actions_config), + ): + if key not in msg: + continue + + try: + await validator(hass, schema(msg[key])) # type: ignore + except vol.Invalid as err: + result[key] = {"valid": False, "error": str(err)} + else: + result[key] = {"valid": True, "error": None} + + connection.send_result(msg["id"], result) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 30ce647132e..7f33ec1f1ec 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1033,7 +1033,12 @@ def script_action(value: Any) -> dict: if not isinstance(value, dict): raise vol.Invalid("expected dictionary") - return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value) + try: + action = determine_script_action(value) + except ValueError as err: + raise vol.Invalid(str(err)) + + return ACTION_TYPE_SCHEMAS[action](value) SCRIPT_SCHEMA = vol.All(ensure_list, [script_action]) @@ -1444,7 +1449,10 @@ def determine_script_action(action: dict[str, Any]) -> str: if CONF_VARIABLES in action: return SCRIPT_ACTION_VARIABLES - return SCRIPT_ACTION_CALL_SERVICE + if CONF_SERVICE in action or CONF_SERVICE_TEMPLATE in action: + return SCRIPT_ACTION_CALL_SERVICE + + raise ValueError("Unable to determine action") ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = { diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 8304b093a14..130870f73f0 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -1286,3 +1286,60 @@ async def test_integration_setup_info(hass, websocket_client, hass_admin_user): {"domain": "august", "seconds": 12.5}, {"domain": "isy994", "seconds": 12.8}, ] + + +@pytest.mark.parametrize( + "key,config", + ( + ("trigger", {"platform": "event", "event_type": "hello"}), + ( + "condition", + {"condition": "state", "entity_id": "hello.world", "state": "paulus"}, + ), + ("action", {"service": "domain_test.test_service"}), + ), +) +async def test_validate_config_works(websocket_client, key, config): + """Test config validation.""" + await websocket_client.send_json({"id": 7, "type": "validate_config", key: config}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == {key: {"valid": True, "error": None}} + + +@pytest.mark.parametrize( + "key,config,error", + ( + ( + "trigger", + {"platform": "non_existing", "event_type": "hello"}, + "Invalid platform 'non_existing' specified", + ), + ( + "condition", + { + "condition": "non_existing", + "entity_id": "hello.world", + "state": "paulus", + }, + "Unexpected value for condition: 'non_existing'. Expected and, device, not, numeric_state, or, state, sun, template, time, trigger, zone", + ), + ( + "action", + {"non_existing": "domain_test.test_service"}, + "Unable to determine action @ data[0]", + ), + ), +) +async def test_validate_config_invalid(websocket_client, key, config, error): + """Test config validation.""" + await websocket_client.send_json({"id": 7, "type": "validate_config", key: config}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == {key: {"valid": False, "error": error}}