mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Small cleanups to the websocket api handler (#108274)
This commit is contained in:
parent
c656024365
commit
b4b041d4bf
@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
|
|||||||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||||
|
from homeassistant.helpers.json import json_bytes
|
||||||
from homeassistant.util.json import JsonValueType
|
from homeassistant.util.json import JsonValueType
|
||||||
|
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
@ -34,15 +35,10 @@ AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
|
||||||
def auth_ok_message() -> dict[str, str]:
|
AUTH_REQUIRED_MESSAGE = json_bytes(
|
||||||
"""Return an auth_ok message."""
|
{"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
||||||
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
|
)
|
||||||
|
|
||||||
|
|
||||||
def auth_required_message() -> dict[str, str]:
|
|
||||||
"""Return an auth_required message."""
|
|
||||||
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
|
||||||
|
|
||||||
|
|
||||||
def auth_invalid_message(message: str) -> dict[str, str]:
|
def auth_invalid_message(message: str) -> dict[str, str]:
|
||||||
@ -104,7 +100,7 @@ class AuthPhase:
|
|||||||
"""Create an active connection."""
|
"""Create an active connection."""
|
||||||
self._logger.debug("Auth OK")
|
self._logger.debug("Auth OK")
|
||||||
process_success_login(self._request)
|
process_success_login(self._request)
|
||||||
self._send_message(auth_ok_message())
|
self._send_message(AUTH_OK_MESSAGE)
|
||||||
return ActiveConnection(
|
return ActiveConnection(
|
||||||
self._logger, self._hass, self._send_message, user, refresh_token
|
self._logger, self._hass, self._send_message, user, refresh_token
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.util.json import json_loads
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
from .auth import AuthPhase, auth_required_message
|
from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
|
||||||
from .const import (
|
from .const import (
|
||||||
DATA_CONNECTIONS,
|
DATA_CONNECTIONS,
|
||||||
MAX_PENDING_MSG,
|
MAX_PENDING_MSG,
|
||||||
@ -266,6 +266,11 @@ class WebSocketHandler:
|
|||||||
if self._writer_task is not None:
|
if self._writer_task is not None:
|
||||||
self._writer_task.cancel()
|
self._writer_task.cancel()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_handle_hass_stop(self, event: Event) -> None:
|
||||||
|
"""Cancel this connection."""
|
||||||
|
self._cancel()
|
||||||
|
|
||||||
async def async_handle(self) -> web.WebSocketResponse:
|
async def async_handle(self) -> web.WebSocketResponse:
|
||||||
"""Handle a websocket response."""
|
"""Handle a websocket response."""
|
||||||
request = self._request
|
request = self._request
|
||||||
@ -286,12 +291,9 @@ class WebSocketHandler:
|
|||||||
debug("%s: Connected from %s", self.description, request.remote)
|
debug("%s: Connected from %s", self.description, request.remote)
|
||||||
self._handle_task = asyncio.current_task()
|
self._handle_task = asyncio.current_task()
|
||||||
|
|
||||||
@callback
|
unsub_stop = hass.bus.async_listen(
|
||||||
def handle_hass_stop(event: Event) -> None:
|
EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
|
||||||
"""Cancel this connection."""
|
)
|
||||||
self._cancel()
|
|
||||||
|
|
||||||
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
|
||||||
|
|
||||||
# As the webserver is now started before the start
|
# As the webserver is now started before the start
|
||||||
# event we do not want to block for websocket responses
|
# event we do not want to block for websocket responses
|
||||||
@ -302,7 +304,7 @@ class WebSocketHandler:
|
|||||||
disconnect_warn = None
|
disconnect_warn = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._send_message(auth_required_message())
|
self._send_message(AUTH_REQUIRED_MESSAGE)
|
||||||
|
|
||||||
# Auth Phase
|
# Auth Phase
|
||||||
try:
|
try:
|
||||||
@ -379,7 +381,7 @@ class WebSocketHandler:
|
|||||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||||
break
|
break
|
||||||
|
|
||||||
if msg.type == WSMsgType.BINARY:
|
if msg.type is WSMsgType.BINARY:
|
||||||
if len(msg.data) < 1:
|
if len(msg.data) < 1:
|
||||||
disconnect_warn = "Received invalid binary message."
|
disconnect_warn = "Received invalid binary message."
|
||||||
break
|
break
|
||||||
@ -388,7 +390,7 @@ class WebSocketHandler:
|
|||||||
async_handle_binary(handler, payload)
|
async_handle_binary(handler, payload)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if msg.type != WSMsgType.TEXT:
|
if msg.type is not WSMsgType.TEXT:
|
||||||
disconnect_warn = "Received non-Text message."
|
disconnect_warn = "Received non-Text message."
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -401,7 +403,8 @@ class WebSocketHandler:
|
|||||||
if is_enabled_for(logging_debug):
|
if is_enabled_for(logging_debug):
|
||||||
debug("%s: Received %s", self.description, command_msg_data)
|
debug("%s: Received %s", self.description, command_msg_data)
|
||||||
|
|
||||||
if not isinstance(command_msg_data, list):
|
# command_msg_data is always deserialized from JSON as a list
|
||||||
|
if type(command_msg_data) is not list: # noqa: E721
|
||||||
async_handle_str(command_msg_data)
|
async_handle_str(command_msg_data)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user