diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 7c5fedf888f..4ddfccd032c 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -588,7 +588,6 @@ def websocket_sign_path( { "path": async_sign_path( hass, - connection.refresh_token_id, msg["path"], timedelta(seconds=msg["expires"]), ) diff --git a/homeassistant/components/cast/media_player.py b/homeassistant/components/cast/media_player.py index c1d227c3147..5fdacb7daa6 100644 --- a/homeassistant/components/cast/media_player.py +++ b/homeassistant/components/cast/media_player.py @@ -21,7 +21,6 @@ from pychromecast.socket_client import ( ) import voluptuous as vol -from homeassistant.auth.models import RefreshToken from homeassistant.components import media_source, zeroconf from homeassistant.components.http.auth import async_sign_path from homeassistant.components.media_player import MediaPlayerEntity @@ -472,20 +471,11 @@ class CastDevice(MediaPlayerEntity): # If media ID is a relative URL, we serve it from HA. # Create a signed path. if media_id[0] == "/": - # Sign URL with Home Assistant Cast User - config_entry_id = self.registry_entry.config_entry_id - config_entry = self.hass.config_entries.async_get_entry(config_entry_id) - user_id = config_entry.data["user_id"] - user = await self.hass.auth.async_get_user(user_id) - if user.refresh_tokens: - refresh_token: RefreshToken = list(user.refresh_tokens.values())[0] - - media_id = async_sign_path( - self.hass, - refresh_token.id, - quote(media_id), - timedelta(seconds=media_source.DEFAULT_EXPIRY_TIME), - ) + media_id = async_sign_path( + self.hass, + quote(media_id), + timedelta(seconds=media_source.DEFAULT_EXPIRY_TIME), + ) # prepend external URL hass_url = get_url(self.hass, prefer_external=True) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 60bc6833caf..9e77563f7a2 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -22,7 +22,7 @@ from homeassistant.loader import bind_hass from homeassistant.setup import async_start_setup, async_when_setup_or_start from homeassistant.util import ssl as ssl_util -from .auth import setup_auth +from .auth import async_setup_auth from .ban import setup_bans from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER # noqa: F401 from .cors import setup_cors @@ -165,12 +165,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ssl_certificate=ssl_certificate, ssl_peer_certificate=ssl_peer_certificate, ssl_key=ssl_key, + trusted_proxies=trusted_proxies, + ssl_profile=ssl_profile, + ) + await server.async_initialize( cors_origins=cors_origins, use_x_forwarded_for=use_x_forwarded_for, - trusted_proxies=trusted_proxies, login_threshold=login_threshold, is_ban_enabled=is_ban_enabled, - ssl_profile=ssl_profile, ) async def stop_server(event: Event) -> None: @@ -214,34 +216,11 @@ class HomeAssistantHTTP: ssl_key: str | None, server_host: list[str] | None, server_port: int, - cors_origins: list[str], - use_x_forwarded_for: bool, trusted_proxies: list[str], - login_threshold: int, - is_ban_enabled: bool, ssl_profile: str, ) -> None: """Initialize the HTTP Home Assistant server.""" - app = self.app = web.Application( - middlewares=[], client_max_size=MAX_CLIENT_SIZE - ) - app[KEY_HASS] = hass - - # Order matters, security filters middle ware needs to go first, - # forwarded middleware needs to go second. - setup_security_filter(app) - - async_setup_forwarded(app, use_x_forwarded_for, trusted_proxies) - - setup_request_context(app, current_request) - - if is_ban_enabled: - setup_bans(hass, app, login_threshold) - - setup_auth(hass, app) - - setup_cors(app, cors_origins) - + self.app = web.Application(middlewares=[], client_max_size=MAX_CLIENT_SIZE) self.hass = hass self.ssl_certificate = ssl_certificate self.ssl_peer_certificate = ssl_peer_certificate @@ -249,12 +228,36 @@ class HomeAssistantHTTP: self.server_host = server_host self.server_port = server_port self.trusted_proxies = trusted_proxies - self.is_ban_enabled = is_ban_enabled self.ssl_profile = ssl_profile - self._handler = None self.runner: web.AppRunner | None = None self.site: HomeAssistantTCPSite | None = None + async def async_initialize( + self, + *, + cors_origins: list[str], + use_x_forwarded_for: bool, + login_threshold: int, + is_ban_enabled: bool, + ) -> None: + """Initialize the server.""" + self.app[KEY_HASS] = self.hass + + # Order matters, security filters middleware needs to go first, + # forwarded middleware needs to go second. + setup_security_filter(self.app) + + async_setup_forwarded(self.app, use_x_forwarded_for, self.trusted_proxies) + + setup_request_context(self.app, current_request) + + if is_ban_enabled: + setup_bans(self.hass, self.app, login_threshold) + + await async_setup_auth(self.hass, self.app) + + setup_cors(self.app, cors_origins) + def register_view(self, view: HomeAssistantView | type[HomeAssistantView]) -> None: """Register a view with the WSGI server. diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 19f7c429a1e..117d1b2d92e 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -13,7 +13,9 @@ from aiohttp import hdrs from aiohttp.web import Application, Request, StreamResponse, middleware import jwt +from homeassistant.auth.const import GROUP_ID_READ_ONLY from homeassistant.auth.models import User +from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback from homeassistant.util import dt as dt_util from homeassistant.util.network import is_local @@ -27,15 +29,33 @@ DATA_API_PASSWORD: Final = "api_password" DATA_SIGN_SECRET: Final = "http.auth.sign_secret" SIGN_QUERY_PARAM: Final = "authSig" +STORAGE_VERSION = 1 +STORAGE_KEY = "http.auth" +CONTENT_USER_NAME = "Home Assistant Content" + @callback def async_sign_path( - hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta + hass: HomeAssistant, + path: str, + expiration: timedelta, + *, + refresh_token_id: str | None = None, ) -> str: """Sign a path for temporary access without auth header.""" if (secret := hass.data.get(DATA_SIGN_SECRET)) is None: secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex() + if refresh_token_id is None: + if connection := websocket_api.current_connection.get(): + refresh_token_id = connection.refresh_token_id + elif ( + request := current_request.get() + ) and KEY_HASS_REFRESH_TOKEN_ID in request: + refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID] + else: + refresh_token_id = hass.data[STORAGE_KEY] + now = dt_util.utcnow() encoded = jwt.encode( { @@ -86,9 +106,27 @@ def async_user_not_allowed_do_auth( return "User cannot authenticate remotely" -@callback -def setup_auth(hass: HomeAssistant, app: Application) -> None: +async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: """Create auth middleware for the app.""" + store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + if (data := await store.async_load()) is None: + data = {} + + refresh_token = None + if "content_user" in data: + user = await hass.auth.async_get_user(data["content_user"]) + if user and user.refresh_tokens: + refresh_token = list(user.refresh_tokens.values())[0] + + if refresh_token is None: + user = await hass.auth.async_create_system_user( + CONTENT_USER_NAME, group_ids=[GROUP_ID_READ_ONLY] + ) + refresh_token = await hass.auth.async_create_refresh_token(user) + data["content_user"] = user.id + await store.async_save(data) + + hass.data[STORAGE_KEY] = refresh_token.id async def async_validate_auth_header(request: Request) -> bool: """ diff --git a/homeassistant/components/media_source/__init__.py b/homeassistant/components/media_source/__init__.py index 717a4ad29d0..e20e4b33690 100644 --- a/homeassistant/components/media_source/__init__.py +++ b/homeassistant/components/media_source/__init__.py @@ -132,7 +132,6 @@ async def websocket_resolve_media( if url[0] == "/": url = async_sign_path( hass, - connection.refresh_token_id, quote(url), timedelta(seconds=msg["expires"]), ) diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index 13939338c3e..c98ca54d25a 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -10,7 +10,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from . import commands, connection, const, decorators, http, messages # noqa: F401 -from .connection import ActiveConnection # noqa: F401 +from .connection import ActiveConnection, current_connection # noqa: F401 from .const import ( # noqa: F401 ERR_HOME_ASSISTANT_ERROR, ERR_INVALID_FORMAT, diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index aec56fdfbf2..075aed86453 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Hashable +from contextvars import ContextVar from typing import TYPE_CHECKING, Any import voluptuous as vol @@ -17,6 +18,11 @@ if TYPE_CHECKING: from .http import WebSocketAdapter +current_connection = ContextVar["ActiveConnection | None"]( + "current_connection", default=None +) + + class ActiveConnection: """Handle an active websocket client connection.""" @@ -36,6 +42,7 @@ class ActiveConnection: self.refresh_token_id = refresh_token.id self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 + current_connection.set(self) def context(self, msg: dict[str, Any]) -> Context: """Return a context.""" diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 39c7c4897c4..f6d0695d97d 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -484,8 +484,6 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token): assert await async_setup_component(hass, "auth", {"http": {}}) ws_client = await hass_ws_client(hass, hass_access_token) - refresh_token = await hass.auth.async_validate_access_token(hass_access_token) - with patch( "homeassistant.components.auth.async_sign_path", return_value="hello_world" ) as mock_sign: @@ -502,7 +500,6 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token): assert result["success"], result assert result["result"] == {"path": "hello_world"} assert len(mock_sign.mock_calls) == 1 - hass, p_refresh_token, path, expires = mock_sign.mock_calls[0][1] - assert p_refresh_token == refresh_token.id + hass, path, expires = mock_sign.mock_calls[0][1] assert path == "/api/hello" assert expires.total_seconds() == 20 diff --git a/tests/components/config/test_auth.py b/tests/components/config/test_auth.py index 7460de6a751..16f6fa7336b 100644 --- a/tests/components/config/test_auth.py +++ b/tests/components/config/test_auth.py @@ -59,7 +59,7 @@ async def test_list(hass, hass_ws_client, hass_admin_user): result = await client.receive_json() assert result["success"], result data = result["result"] - assert len(data) == 4 + assert len(data) == 5 assert data[0] == { "id": hass_admin_user.id, "username": "admin", @@ -151,7 +151,7 @@ async def test_delete(hass, hass_ws_client, hass_access_token): client = await hass_ws_client(hass, hass_access_token) test_user = MockUser(id="efg").add_to_hass(hass) - assert len(await hass.auth.async_get_users()) == 2 + cur_users = len(await hass.auth.async_get_users()) await client.send_json( {"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": test_user.id} @@ -159,20 +159,20 @@ async def test_delete(hass, hass_ws_client, hass_access_token): result = await client.receive_json() assert result["success"], result - assert len(await hass.auth.async_get_users()) == 1 + assert len(await hass.auth.async_get_users()) == cur_users - 1 async def test_create(hass, hass_ws_client, hass_access_token): """Test create command works.""" client = await hass_ws_client(hass, hass_access_token) - assert len(await hass.auth.async_get_users()) == 1 + cur_users = len(await hass.auth.async_get_users()) await client.send_json({"id": 5, "type": "config/auth/create", "name": "Paulus"}) result = await client.receive_json() assert result["success"], result - assert len(await hass.auth.async_get_users()) == 2 + assert len(await hass.auth.async_get_users()) == cur_users + 1 data_user = result["result"]["user"] user = await hass.auth.async_get_user(data_user["id"]) assert user is not None @@ -188,7 +188,7 @@ async def test_create_user_group(hass, hass_ws_client, hass_access_token): """Test create user with a group.""" client = await hass_ws_client(hass, hass_access_token) - assert len(await hass.auth.async_get_users()) == 1 + cur_users = len(await hass.auth.async_get_users()) await client.send_json( { @@ -201,7 +201,7 @@ async def test_create_user_group(hass, hass_ws_client, hass_access_token): result = await client.receive_json() assert result["success"], result - assert len(await hass.auth.async_get_users()) == 2 + assert len(await hass.auth.async_get_users()) == cur_users + 1 data_user = result["result"]["user"] user = await hass.auth.async_get_user(data_user["id"]) assert user is not None diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index 1f1a3d32d2c..4a2e1e8aed3 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -6,16 +6,29 @@ from unittest.mock import Mock, patch from aiohttp import BasicAuth, web from aiohttp.web_exceptions import HTTPUnauthorized +import jwt import pytest +import yarl +from homeassistant.auth.const import GROUP_ID_READ_ONLY +from homeassistant.auth.models import User from homeassistant.auth.providers import trusted_networks +from homeassistant.components import websocket_api from homeassistant.components.http.auth import ( + CONTENT_USER_NAME, + DATA_SIGN_SECRET, + STORAGE_KEY, + async_setup_auth, async_sign_path, async_user_not_allowed_do_auth, - setup_auth, ) from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.forwarded import async_setup_forwarded +from homeassistant.components.http.request_context import ( + current_request, + setup_request_context, +) +from homeassistant.core import callback from homeassistant.setup import async_setup_component from . import HTTP_HEADER_HA_AUTH, mock_real_ip @@ -86,7 +99,7 @@ def trusted_networks_auth(hass): async def test_auth_middleware_loaded_by_default(hass): """Test accessing to server from banned IP when feature is off.""" - with patch("homeassistant.components.http.setup_auth") as mock_setup: + with patch("homeassistant.components.http.async_setup_auth") as mock_setup: await async_setup_component(hass, "http", {"http": {}}) assert len(mock_setup.mock_calls) == 1 @@ -96,7 +109,7 @@ async def test_cant_access_with_password_in_header( app, aiohttp_client, legacy_auth, hass ): """Test access with password in header.""" - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) @@ -110,7 +123,7 @@ async def test_cant_access_with_password_in_query( app, aiohttp_client, legacy_auth, hass ): """Test access with password in URL.""" - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) resp = await client.get("/", params={"api_password": API_PASSWORD}) @@ -125,7 +138,7 @@ async def test_cant_access_with_password_in_query( async def test_basic_auth_does_not_work(app, aiohttp_client, hass, legacy_auth): """Test access with basic authentication.""" - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD)) @@ -145,7 +158,7 @@ async def test_cannot_access_with_trusted_ip( hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user ): """Test access with an untrusted ip address.""" - setup_auth(hass, app2) + await async_setup_auth(hass, app2) set_mock_ip = mock_real_ip(app2) client = await aiohttp_client(app2) @@ -170,7 +183,7 @@ async def test_auth_active_access_with_access_token_in_header( ): """Test access with access token in header.""" token = hass_access_token - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) refresh_token = await hass.auth.async_validate_access_token(hass_access_token) @@ -202,7 +215,7 @@ async def test_auth_active_access_with_trusted_ip( hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user ): """Test access with an untrusted ip address.""" - setup_auth(hass, app2) + await async_setup_auth(hass, app2) set_mock_ip = mock_real_ip(app2) client = await aiohttp_client(app2) @@ -226,7 +239,7 @@ async def test_auth_legacy_support_api_password_cannot_access( app, aiohttp_client, legacy_auth, hass ): """Test access using api_password if auth.support_legacy.""" - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) @@ -239,16 +252,20 @@ async def test_auth_legacy_support_api_password_cannot_access( assert req.status == HTTPStatus.UNAUTHORIZED -async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_token): +async def test_auth_access_signed_path_with_refresh_token( + hass, app, aiohttp_client, hass_access_token +): """Test access with signed url.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - setup_auth(hass, app) + await async_setup_auth(hass, app) client = await aiohttp_client(app) refresh_token = await hass.auth.async_validate_access_token(hass_access_token) - signed_path = async_sign_path(hass, refresh_token.id, "/", timedelta(seconds=5)) + signed_path = async_sign_path( + hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id + ) req = await client.get(signed_path) assert req.status == HTTPStatus.OK @@ -265,7 +282,7 @@ async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_to # Never valid as expired in the past. expired_signed_path = async_sign_path( - hass, refresh_token.id, "/", timedelta(seconds=-5) + hass, "/", timedelta(seconds=-5), refresh_token_id=refresh_token.id ) req = await client.get(expired_signed_path) @@ -277,10 +294,94 @@ async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_to assert req.status == HTTPStatus.UNAUTHORIZED +async def test_auth_access_signed_path_via_websocket( + hass, app, hass_ws_client, hass_read_only_access_token +): + """Test signed url via websockets uses connection user.""" + + @websocket_api.websocket_command({"type": "diagnostics/list"}) + @callback + def get_signed_path(hass, connection, msg): + connection.send_result( + msg["id"], {"path": async_sign_path(hass, "/", timedelta(seconds=5))} + ) + + websocket_api.async_register_command(hass, get_signed_path) + + # We use hass_read_only_access_token to make sure the connection WS is used. + client = await hass_ws_client(access_token=hass_read_only_access_token) + + await client.send_json({"id": 5, "type": "diagnostics/list"}) + + msg = await client.receive_json() + + assert msg["id"] == 5 + assert msg["success"] + + refresh_token = await hass.auth.async_validate_access_token( + hass_read_only_access_token + ) + signature = yarl.URL(msg["result"]["path"]).query["authSig"] + claims = jwt.decode( + signature, + hass.data[DATA_SIGN_SECRET], + algorithms=["HS256"], + options={"verify_signature": False}, + ) + assert claims["iss"] == refresh_token.id + + +async def test_auth_access_signed_path_with_http( + hass, app, aiohttp_client, hass_access_token +): + """Test signed url via HTTP uses HTTP user.""" + setup_request_context(app, current_request) + + async def mock_handler(request): + """Return signed path.""" + return web.json_response( + data={"path": async_sign_path(hass, "/", timedelta(seconds=-5))} + ) + + app.router.add_get("/hello", mock_handler) + await async_setup_auth(hass, app) + client = await aiohttp_client(app) + + refresh_token = await hass.auth.async_validate_access_token(hass_access_token) + + req = await client.get( + "/hello", headers={"Authorization": f"Bearer {hass_access_token}"} + ) + assert req.status == HTTPStatus.OK + data = await req.json() + signature = yarl.URL(data["path"]).query["authSig"] + claims = jwt.decode( + signature, + hass.data[DATA_SIGN_SECRET], + algorithms=["HS256"], + options={"verify_signature": False}, + ) + assert claims["iss"] == refresh_token.id + + +async def test_auth_access_signed_path_with_content_user(hass, app, aiohttp_client): + """Test access signed url uses content user.""" + await async_setup_auth(hass, app) + signed_path = async_sign_path(hass, "/", timedelta(seconds=5)) + signature = yarl.URL(signed_path).query["authSig"] + claims = jwt.decode( + signature, + hass.data[DATA_SIGN_SECRET], + algorithms=["HS256"], + options={"verify_signature": False}, + ) + assert claims["iss"] == hass.data[STORAGE_KEY] + + async def test_local_only_user_rejected(hass, app, aiohttp_client, hass_access_token): """Test access with access token in header.""" token = hass_access_token - setup_auth(hass, app) + await async_setup_auth(hass, app) set_mock_ip = mock_real_ip(app) client = await aiohttp_client(app) refresh_token = await hass.auth.async_validate_access_token(hass_access_token) @@ -340,3 +441,25 @@ async def test_async_user_not_allowed_do_auth(hass, app): async_user_not_allowed_do_auth(hass, user, trusted_request) == "User is local only" ) + + +async def test_create_user_once(hass): + """Test that we reuse the user.""" + cur_users = len(await hass.auth.async_get_users()) + app = web.Application() + await async_setup_auth(hass, app) + users = await hass.auth.async_get_users() + assert len(users) == cur_users + 1 + + user: User = next((user for user in users if user.name == CONTENT_USER_NAME), None) + assert user is not None, users + + assert len(user.groups) == 1 + assert user.groups[0].id == GROUP_ID_READ_ONLY + assert len(user.refresh_tokens) == 1 + assert user.system_generated + + await async_setup_auth(hass, app) + + # test it did not create a user + assert len(await hass.auth.async_get_users()) == cur_users + 1 diff --git a/tests/components/onboarding/test_views.py b/tests/components/onboarding/test_views.py index 45fe9a19546..9605fb9e71c 100644 --- a/tests/components/onboarding/test_views.py +++ b/tests/components/onboarding/test_views.py @@ -139,6 +139,7 @@ async def test_onboarding_user(hass, hass_storage, hass_client_no_auth): assert await async_setup_component(hass, "onboarding", {}) await hass.async_block_till_done() + cur_users = len(await hass.auth.async_get_users()) client = await hass_client_no_auth() resp = await client.post( @@ -159,9 +160,9 @@ async def test_onboarding_user(hass, hass_storage, hass_client_no_auth): assert "auth_code" in data users = await hass.auth.async_get_users() - assert len(users) == 1 - user = users[0] - assert user.name == "Test Name" + assert len(await hass.auth.async_get_users()) == cur_users + 1 + user = next((user for user in users if user.name == "Test Name"), None) + assert user is not None assert len(user.credentials) == 1 assert user.credentials[0].data["username"] == "test-user" assert len(hass.data["person"][1].async_items()) == 1 @@ -287,8 +288,8 @@ async def test_onboarding_integration(hass, hass_storage, hass_client, hass_admi ) # Onboarding refresh token and new refresh token - for user in await hass.auth.async_get_users(): - assert len(user.refresh_tokens) == 2, user + user = await hass.auth.async_get_user(hass_admin_user.id) + assert len(user.refresh_tokens) == 2, user async def test_onboarding_integration_missing_credential( diff --git a/tests/conftest.py b/tests/conftest.py index 56be04edeeb..9f0958e6ace 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -503,25 +503,18 @@ def hass_ws_client(aiohttp_client, hass_access_token, hass, socket_enabled): async def create_client(hass=hass, access_token=hass_access_token): """Create a websocket client.""" assert await async_setup_component(hass, "websocket_api", {}) - client = await aiohttp_client(hass.http.app) + websocket = await client.ws_connect(URL) + auth_resp = await websocket.receive_json() + assert auth_resp["type"] == TYPE_AUTH_REQUIRED - with patch("homeassistant.components.http.auth.setup_auth"): - websocket = await client.ws_connect(URL) - auth_resp = await websocket.receive_json() - assert auth_resp["type"] == TYPE_AUTH_REQUIRED + if access_token is None: + await websocket.send_json({"type": TYPE_AUTH, "access_token": "incorrect"}) + else: + await websocket.send_json({"type": TYPE_AUTH, "access_token": access_token}) - if access_token is None: - await websocket.send_json( - {"type": TYPE_AUTH, "access_token": "incorrect"} - ) - else: - await websocket.send_json( - {"type": TYPE_AUTH, "access_token": access_token} - ) - - auth_ok = await websocket.receive_json() - assert auth_ok["type"] == TYPE_AUTH_OK + auth_ok = await websocket.receive_json() + assert auth_ok["type"] == TYPE_AUTH_OK # wrap in client websocket.client = client