diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index bfcfc796bad..2d591455eaf 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -6,7 +6,31 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, callback from homeassistant.loader import bind_hass -from . import commands, connection, const, decorators, http, messages +from . import commands, connection, const, decorators, http, messages # noqa +from .connection import ActiveConnection # noqa +from .const import ( # noqa + ERR_HOME_ASSISTANT_ERROR, + ERR_INVALID_FORMAT, + ERR_NOT_FOUND, + ERR_NOT_SUPPORTED, + ERR_TEMPLATE_ERROR, + ERR_TIMEOUT, + ERR_UNAUTHORIZED, + ERR_UNKNOWN_COMMAND, + ERR_UNKNOWN_ERROR, +) +from .decorators import ( # noqa + async_response, + require_admin, + websocket_command, + ws_require_user, +) +from .messages import ( # noqa + BASE_COMMAND_MESSAGE_SCHEMA, + error_message, + event_message, + result_message, +) # mypy: allow-untyped-calls, allow-untyped-defs @@ -14,17 +38,6 @@ DOMAIN = const.DOMAIN DEPENDENCIES = ("http",) -# Backwards compat / Make it easier to integrate -ActiveConnection = connection.ActiveConnection -BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA -error_message = messages.error_message -result_message = messages.result_message -event_message = messages.event_message -async_response = decorators.async_response -require_admin = decorators.require_admin -ws_require_user = decorators.ws_require_user -websocket_command = decorators.websocket_command - @bind_hass @callback diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index ae2bb16c6d2..108d4de5ada 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, Hashable, Optional import voluptuous as vol from homeassistant.core import Context, callback -from homeassistant.exceptions import Unauthorized +from homeassistant.exceptions import HomeAssistantError, Unauthorized from . import const, messages @@ -118,6 +118,9 @@ class ActiveConnection: elif isinstance(err, asyncio.TimeoutError): code = const.ERR_TIMEOUT err_message = "Timeout" + elif isinstance(err, HomeAssistantError): + code = const.ERR_UNKNOWN_ERROR + err_message = str(err) else: code = const.ERR_UNKNOWN_ERROR err_message = "Unknown error" diff --git a/tests/components/websocket_api/test_connection.py b/tests/components/websocket_api/test_connection.py index 5890fc9a2fa..55126ff1333 100644 --- a/tests/components/websocket_api/test_connection.py +++ b/tests/components/websocket_api/test_connection.py @@ -1,4 +1,10 @@ """Test WebSocket Connection class.""" +import asyncio +import logging + +import voluptuous as vol + +from homeassistant import exceptions from homeassistant.components import websocket_api from homeassistant.components.websocket_api import const @@ -20,3 +26,32 @@ async def test_send_big_result(hass, websocket_client): assert msg["type"] == const.TYPE_RESULT assert msg["success"] assert msg["result"] == {"big": "result"} + + +async def test_exception_handling(): + """Test handling of exceptions.""" + send_messages = [] + conn = websocket_api.ActiveConnection( + logging.getLogger(__name__), None, send_messages.append, None, None + ) + + for (exc, code, err) in ( + (exceptions.Unauthorized(), websocket_api.ERR_UNAUTHORIZED, "Unauthorized"), + ( + vol.Invalid("Invalid something"), + websocket_api.ERR_INVALID_FORMAT, + "Invalid something. Got {'id': 5}", + ), + (asyncio.TimeoutError(), websocket_api.ERR_TIMEOUT, "Timeout"), + ( + exceptions.HomeAssistantError("Failed to do X"), + websocket_api.ERR_UNKNOWN_ERROR, + "Failed to do X", + ), + (ValueError("Really bad"), websocket_api.ERR_UNKNOWN_ERROR, "Unknown error"), + ): + send_messages.clear() + conn.async_handle_exception({"id": 5}, exc) + assert len(send_messages) == 1 + assert send_messages[0]["error"]["code"] == code + assert send_messages[0]["error"]["message"] == err