diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 3e8328c117c..90c7e9906e4 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -46,6 +46,7 @@ class ActiveConnection: self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 self.supported_features: dict[str, float] = {} + self.handlers = self.hass.data[const.DOMAIN] current_connection.set(self) def get_description(self, request: web.Request | None) -> str: @@ -72,12 +73,17 @@ class ActiveConnection: @callback def async_handle(self, msg: dict[str, Any]) -> None: """Handle a single incoming message.""" - handlers = self.hass.data[const.DOMAIN] - - try: - msg = messages.MINIMAL_MESSAGE_SCHEMA(msg) - cur_id = msg["id"] - except vol.Invalid: + if ( + # Not using isinstance as we don't care about children + # as these are always coming from JSON + type(msg) is not dict # pylint: disable=unidiomatic-typecheck + or ( + not (cur_id := msg.get("id")) + or type(cur_id) is not int # pylint: disable=unidiomatic-typecheck + or not (type_ := msg.get("type")) + or type(type_) is not str # pylint: disable=unidiomatic-typecheck + ) + ): self.logger.error("Received invalid command", msg) self.send_message( messages.error_message( @@ -96,8 +102,8 @@ class ActiveConnection: ) return - if msg["type"] not in handlers: - self.logger.info("Received unknown command: {}".format(msg["type"])) + if not (handler_schema := self.handlers.get(type_)): + self.logger.info(f"Received unknown command: {type_}") self.send_message( messages.error_message( cur_id, const.ERR_UNKNOWN_COMMAND, "Unknown command." @@ -105,7 +111,7 @@ class ActiveConnection: ) return - handler, schema = handlers[msg["type"]] + handler, schema = handler_schema try: handler(self.hass, self, schema(msg)) diff --git a/tests/components/websocket_api/test_connection.py b/tests/components/websocket_api/test_connection.py index ce484939d8d..53baab98b4f 100644 --- a/tests/components/websocket_api/test_connection.py +++ b/tests/components/websocket_api/test_connection.py @@ -10,6 +10,8 @@ import voluptuous as vol from homeassistant import exceptions from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.const import DOMAIN +from homeassistant.core import HomeAssistant from tests.common import MockUser @@ -56,6 +58,7 @@ from tests.common import MockUser ], ) async def test_exception_handling( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, exc: Exception, code: str, @@ -67,6 +70,7 @@ async def test_exception_handling( user = MockUser() refresh_token = Mock() current_request = AsyncMock() + hass.data[DOMAIN] = {} def get_extra_info(key: str) -> Any: if key == "sslcontext": @@ -89,7 +93,7 @@ async def test_exception_handling( ) as current_request: current_request.get.return_value = mocked_request conn = websocket_api.ActiveConnection( - logging.getLogger(__name__), None, send_messages.append, user, refresh_token + logging.getLogger(__name__), hass, send_messages.append, user, refresh_token ) conn.async_handle_exception({"id": 5}, exc) diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index e68a6d4c492..fce6eb428ae 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -17,7 +17,7 @@ from tests.typing import WebSocketGenerator @pytest.fixture def mock_low_queue(): """Mock a low queue.""" - with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5): + with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 1): yield