From fb391854209f53d396e350ae4d4eb13c2eaa50c3 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 27 Jan 2021 15:20:22 +0100 Subject: [PATCH] Add schema error handling to websocket_api (#45602) Co-authored-by: Martin Hjelmare --- .../components/websocket_api/commands.py | 4 + .../components/websocket_api/test_commands.py | 73 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 2dd6ff47e3c..77521c1ed98 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -156,6 +156,10 @@ async def handle_call_service(hass, connection, msg): msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err) ) ) + except vol.Invalid as err: + connection.send_message( + messages.error_message(msg["id"], const.ERR_INVALID_FORMAT, str(err)) + ) except HomeAssistantError as err: connection.logger.exception(err) connection.send_message( diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 721d178430e..a7aa17db6d3 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -1,5 +1,6 @@ """Tests for WebSocket API commands.""" from async_timeout import timeout +import voluptuous as vol from homeassistant.components.websocket_api import const from homeassistant.components.websocket_api.auth import ( @@ -11,6 +12,7 @@ from homeassistant.components.websocket_api.const import URL from homeassistant.core import Context, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity +from homeassistant.helpers.typing import HomeAssistantType from homeassistant.loader import async_get_integration from homeassistant.setup import async_setup_component @@ -94,6 +96,77 @@ async def test_call_service_child_not_found(hass, websocket_client): assert msg["error"]["code"] == const.ERR_HOME_ASSISTANT_ERROR +async def test_call_service_schema_validation_error( + hass: HomeAssistantType, websocket_client +): + """Test call service command with invalid service data.""" + + calls = [] + service_schema = vol.Schema( + { + vol.Required("message"): str, + } + ) + + @callback + def service_call(call): + calls.append(call) + + hass.services.async_register( + "domain_test", + "test_service", + service_call, + schema=service_schema, + ) + + await websocket_client.send_json( + { + "id": 5, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {}, + } + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_INVALID_FORMAT + + await websocket_client.send_json( + { + "id": 6, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {"extra_key": "not allowed"}, + } + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 6 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_INVALID_FORMAT + + await websocket_client.send_json( + { + "id": 7, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {"message": []}, + } + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_INVALID_FORMAT + + assert len(calls) == 0 + + async def test_call_service_error(hass, websocket_client): """Test call service command with error."""