diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 29dc6113350..11aca19bab9 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -36,6 +36,8 @@ from .error import Disconnect from .messages import message_to_json_bytes from .util import describe_request +CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING} + if TYPE_CHECKING: from .connection import ActiveConnection @@ -344,7 +346,7 @@ class WebSocketHandler: try: connection = await self._async_handle_auth_phase(auth, send_bytes_text) self._async_increase_writer_limit(writer) - await self._async_websocket_command_phase(connection, send_bytes_text) + await self._async_websocket_command_phase(connection) except asyncio.CancelledError: logger.debug("%s: Connection cancelled", self.description) raise @@ -454,9 +456,7 @@ class WebSocketHandler: writer._limit = 2**20 # noqa: SLF001 async def _async_websocket_command_phase( - self, - connection: ActiveConnection, - send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], + self, connection: ActiveConnection ) -> None: """Handle the command phase of the websocket connection.""" wsock = self._wsock @@ -467,24 +467,26 @@ class WebSocketHandler: # Command phase while not wsock.closed: msg = await wsock.receive() + msg_type = msg.type + msg_data = msg.data - if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + if msg_type in CLOSE_MSG_TYPES: break - if msg.type is WSMsgType.BINARY: - if len(msg.data) < 1: + if msg_type is WSMsgType.BINARY: + if len(msg_data) < 1: raise Disconnect("Received invalid binary message.") - handler = msg.data[0] - payload = msg.data[1:] + handler = msg_data[0] + payload = msg_data[1:] async_handle_binary(handler, payload) continue - if msg.type is not WSMsgType.TEXT: + if msg_type is not WSMsgType.TEXT: raise Disconnect("Received non-Text message.") try: - command_msg_data = json_loads(msg.data) + command_msg_data = json_loads(msg_data) except ValueError as ex: raise Disconnect("Received invalid JSON.") from ex