Small cleanups to the websocket api handler (#108274)

This commit is contained in:
J. Nick Koston 2024-01-17 21:39:49 -10:00 committed by GitHub
parent c656024365
commit b4b041d4bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 21 deletions

View File

@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.json import json_bytes
from homeassistant.util.json import JsonValueType from homeassistant.util.json import JsonValueType
from .connection import ActiveConnection from .connection import ActiveConnection
@ -34,15 +35,10 @@ AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
} }
) )
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
def auth_ok_message() -> dict[str, str]: AUTH_REQUIRED_MESSAGE = json_bytes(
"""Return an auth_ok message.""" {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
return {"type": TYPE_AUTH_OK, "ha_version": __version__} )
def auth_required_message() -> dict[str, str]:
"""Return an auth_required message."""
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
def auth_invalid_message(message: str) -> dict[str, str]: def auth_invalid_message(message: str) -> dict[str, str]:
@ -104,7 +100,7 @@ class AuthPhase:
"""Create an active connection.""" """Create an active connection."""
self._logger.debug("Auth OK") self._logger.debug("Auth OK")
process_success_login(self._request) process_success_login(self._request)
self._send_message(auth_ok_message()) self._send_message(AUTH_OK_MESSAGE)
return ActiveConnection( return ActiveConnection(
self._logger, self._hass, self._send_message, user, refresh_token self._logger, self._hass, self._send_message, user, refresh_token
) )

View File

@ -18,7 +18,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.util.json import json_loads from homeassistant.util.json import json_loads
from .auth import AuthPhase, auth_required_message from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
from .const import ( from .const import (
DATA_CONNECTIONS, DATA_CONNECTIONS,
MAX_PENDING_MSG, MAX_PENDING_MSG,
@ -266,6 +266,11 @@ class WebSocketHandler:
if self._writer_task is not None: if self._writer_task is not None:
self._writer_task.cancel() self._writer_task.cancel()
@callback
def _async_handle_hass_stop(self, event: Event) -> None:
"""Cancel this connection."""
self._cancel()
async def async_handle(self) -> web.WebSocketResponse: async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response.""" """Handle a websocket response."""
request = self._request request = self._request
@ -286,12 +291,9 @@ class WebSocketHandler:
debug("%s: Connected from %s", self.description, request.remote) debug("%s: Connected from %s", self.description, request.remote)
self._handle_task = asyncio.current_task() self._handle_task = asyncio.current_task()
@callback unsub_stop = hass.bus.async_listen(
def handle_hass_stop(event: Event) -> None: EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
"""Cancel this connection.""" )
self._cancel()
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
# As the webserver is now started before the start # As the webserver is now started before the start
# event we do not want to block for websocket responses # event we do not want to block for websocket responses
@ -302,7 +304,7 @@ class WebSocketHandler:
disconnect_warn = None disconnect_warn = None
try: try:
self._send_message(auth_required_message()) self._send_message(AUTH_REQUIRED_MESSAGE)
# Auth Phase # Auth Phase
try: try:
@ -379,7 +381,7 @@ class WebSocketHandler:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
break break
if msg.type == WSMsgType.BINARY: if msg.type is WSMsgType.BINARY:
if len(msg.data) < 1: if len(msg.data) < 1:
disconnect_warn = "Received invalid binary message." disconnect_warn = "Received invalid binary message."
break break
@ -388,7 +390,7 @@ class WebSocketHandler:
async_handle_binary(handler, payload) async_handle_binary(handler, payload)
continue continue
if msg.type != WSMsgType.TEXT: if msg.type is not WSMsgType.TEXT:
disconnect_warn = "Received non-Text message." disconnect_warn = "Received non-Text message."
break break
@ -401,7 +403,8 @@ class WebSocketHandler:
if is_enabled_for(logging_debug): if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, command_msg_data) debug("%s: Received %s", self.description, command_msg_data)
if not isinstance(command_msg_data, list): # command_msg_data is always deserialized from JSON as a list
if type(command_msg_data) is not list: # noqa: E721
async_handle_str(command_msg_data) async_handle_str(command_msg_data)
continue continue