Add Home Assistant Content user (#64337)

This commit is contained in:
Paulus Schoutsen 2022-01-21 10:06:39 -08:00 committed by GitHub
parent b3cda6b681
commit 63f8e437ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 246 additions and 96 deletions

View File

@ -588,7 +588,6 @@ def websocket_sign_path(
{ {
"path": async_sign_path( "path": async_sign_path(
hass, hass,
connection.refresh_token_id,
msg["path"], msg["path"],
timedelta(seconds=msg["expires"]), timedelta(seconds=msg["expires"]),
) )

View File

@ -21,7 +21,6 @@ from pychromecast.socket_client import (
) )
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.models import RefreshToken
from homeassistant.components import media_source, zeroconf from homeassistant.components import media_source, zeroconf
from homeassistant.components.http.auth import async_sign_path from homeassistant.components.http.auth import async_sign_path
from homeassistant.components.media_player import MediaPlayerEntity from homeassistant.components.media_player import MediaPlayerEntity
@ -472,17 +471,8 @@ class CastDevice(MediaPlayerEntity):
# If media ID is a relative URL, we serve it from HA. # If media ID is a relative URL, we serve it from HA.
# Create a signed path. # Create a signed path.
if media_id[0] == "/": if media_id[0] == "/":
# Sign URL with Home Assistant Cast User
config_entry_id = self.registry_entry.config_entry_id
config_entry = self.hass.config_entries.async_get_entry(config_entry_id)
user_id = config_entry.data["user_id"]
user = await self.hass.auth.async_get_user(user_id)
if user.refresh_tokens:
refresh_token: RefreshToken = list(user.refresh_tokens.values())[0]
media_id = async_sign_path( media_id = async_sign_path(
self.hass, self.hass,
refresh_token.id,
quote(media_id), quote(media_id),
timedelta(seconds=media_source.DEFAULT_EXPIRY_TIME), timedelta(seconds=media_source.DEFAULT_EXPIRY_TIME),
) )

View File

@ -22,7 +22,7 @@ from homeassistant.loader import bind_hass
from homeassistant.setup import async_start_setup, async_when_setup_or_start from homeassistant.setup import async_start_setup, async_when_setup_or_start
from homeassistant.util import ssl as ssl_util from homeassistant.util import ssl as ssl_util
from .auth import setup_auth from .auth import async_setup_auth
from .ban import setup_bans from .ban import setup_bans
from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER # noqa: F401 from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER # noqa: F401
from .cors import setup_cors from .cors import setup_cors
@ -165,12 +165,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
ssl_certificate=ssl_certificate, ssl_certificate=ssl_certificate,
ssl_peer_certificate=ssl_peer_certificate, ssl_peer_certificate=ssl_peer_certificate,
ssl_key=ssl_key, ssl_key=ssl_key,
trusted_proxies=trusted_proxies,
ssl_profile=ssl_profile,
)
await server.async_initialize(
cors_origins=cors_origins, cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for, use_x_forwarded_for=use_x_forwarded_for,
trusted_proxies=trusted_proxies,
login_threshold=login_threshold, login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled, is_ban_enabled=is_ban_enabled,
ssl_profile=ssl_profile,
) )
async def stop_server(event: Event) -> None: async def stop_server(event: Event) -> None:
@ -214,34 +216,11 @@ class HomeAssistantHTTP:
ssl_key: str | None, ssl_key: str | None,
server_host: list[str] | None, server_host: list[str] | None,
server_port: int, server_port: int,
cors_origins: list[str],
use_x_forwarded_for: bool,
trusted_proxies: list[str], trusted_proxies: list[str],
login_threshold: int,
is_ban_enabled: bool,
ssl_profile: str, ssl_profile: str,
) -> None: ) -> None:
"""Initialize the HTTP Home Assistant server.""" """Initialize the HTTP Home Assistant server."""
app = self.app = web.Application( self.app = web.Application(middlewares=[], client_max_size=MAX_CLIENT_SIZE)
middlewares=[], client_max_size=MAX_CLIENT_SIZE
)
app[KEY_HASS] = hass
# Order matters, security filters middle ware needs to go first,
# forwarded middleware needs to go second.
setup_security_filter(app)
async_setup_forwarded(app, use_x_forwarded_for, trusted_proxies)
setup_request_context(app, current_request)
if is_ban_enabled:
setup_bans(hass, app, login_threshold)
setup_auth(hass, app)
setup_cors(app, cors_origins)
self.hass = hass self.hass = hass
self.ssl_certificate = ssl_certificate self.ssl_certificate = ssl_certificate
self.ssl_peer_certificate = ssl_peer_certificate self.ssl_peer_certificate = ssl_peer_certificate
@ -249,12 +228,36 @@ class HomeAssistantHTTP:
self.server_host = server_host self.server_host = server_host
self.server_port = server_port self.server_port = server_port
self.trusted_proxies = trusted_proxies self.trusted_proxies = trusted_proxies
self.is_ban_enabled = is_ban_enabled
self.ssl_profile = ssl_profile self.ssl_profile = ssl_profile
self._handler = None
self.runner: web.AppRunner | None = None self.runner: web.AppRunner | None = None
self.site: HomeAssistantTCPSite | None = None self.site: HomeAssistantTCPSite | None = None
async def async_initialize(
self,
*,
cors_origins: list[str],
use_x_forwarded_for: bool,
login_threshold: int,
is_ban_enabled: bool,
) -> None:
"""Initialize the server."""
self.app[KEY_HASS] = self.hass
# Order matters, security filters middleware needs to go first,
# forwarded middleware needs to go second.
setup_security_filter(self.app)
async_setup_forwarded(self.app, use_x_forwarded_for, self.trusted_proxies)
setup_request_context(self.app, current_request)
if is_ban_enabled:
setup_bans(self.hass, self.app, login_threshold)
await async_setup_auth(self.hass, self.app)
setup_cors(self.app, cors_origins)
def register_view(self, view: HomeAssistantView | type[HomeAssistantView]) -> None: def register_view(self, view: HomeAssistantView | type[HomeAssistantView]) -> None:
"""Register a view with the WSGI server. """Register a view with the WSGI server.

View File

@ -13,7 +13,9 @@ from aiohttp import hdrs
from aiohttp.web import Application, Request, StreamResponse, middleware from aiohttp.web import Application, Request, StreamResponse, middleware
import jwt import jwt
from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User from homeassistant.auth.models import User
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.network import is_local from homeassistant.util.network import is_local
@ -27,15 +29,33 @@ DATA_API_PASSWORD: Final = "api_password"
DATA_SIGN_SECRET: Final = "http.auth.sign_secret" DATA_SIGN_SECRET: Final = "http.auth.sign_secret"
SIGN_QUERY_PARAM: Final = "authSig" SIGN_QUERY_PARAM: Final = "authSig"
STORAGE_VERSION = 1
STORAGE_KEY = "http.auth"
CONTENT_USER_NAME = "Home Assistant Content"
@callback @callback
def async_sign_path( def async_sign_path(
hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta hass: HomeAssistant,
path: str,
expiration: timedelta,
*,
refresh_token_id: str | None = None,
) -> str: ) -> str:
"""Sign a path for temporary access without auth header.""" """Sign a path for temporary access without auth header."""
if (secret := hass.data.get(DATA_SIGN_SECRET)) is None: if (secret := hass.data.get(DATA_SIGN_SECRET)) is None:
secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex() secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex()
if refresh_token_id is None:
if connection := websocket_api.current_connection.get():
refresh_token_id = connection.refresh_token_id
elif (
request := current_request.get()
) and KEY_HASS_REFRESH_TOKEN_ID in request:
refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID]
else:
refresh_token_id = hass.data[STORAGE_KEY]
now = dt_util.utcnow() now = dt_util.utcnow()
encoded = jwt.encode( encoded = jwt.encode(
{ {
@ -86,9 +106,27 @@ def async_user_not_allowed_do_auth(
return "User cannot authenticate remotely" return "User cannot authenticate remotely"
@callback async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
def setup_auth(hass: HomeAssistant, app: Application) -> None:
"""Create auth middleware for the app.""" """Create auth middleware for the app."""
store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
if (data := await store.async_load()) is None:
data = {}
refresh_token = None
if "content_user" in data:
user = await hass.auth.async_get_user(data["content_user"])
if user and user.refresh_tokens:
refresh_token = list(user.refresh_tokens.values())[0]
if refresh_token is None:
user = await hass.auth.async_create_system_user(
CONTENT_USER_NAME, group_ids=[GROUP_ID_READ_ONLY]
)
refresh_token = await hass.auth.async_create_refresh_token(user)
data["content_user"] = user.id
await store.async_save(data)
hass.data[STORAGE_KEY] = refresh_token.id
async def async_validate_auth_header(request: Request) -> bool: async def async_validate_auth_header(request: Request) -> bool:
""" """

View File

@ -132,7 +132,6 @@ async def websocket_resolve_media(
if url[0] == "/": if url[0] == "/":
url = async_sign_path( url = async_sign_path(
hass, hass,
connection.refresh_token_id,
quote(url), quote(url),
timedelta(seconds=msg["expires"]), timedelta(seconds=msg["expires"]),
) )

View File

@ -10,7 +10,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import commands, connection, const, decorators, http, messages # noqa: F401 from . import commands, connection, const, decorators, http, messages # noqa: F401
from .connection import ActiveConnection # noqa: F401 from .connection import ActiveConnection, current_connection # noqa: F401
from .const import ( # noqa: F401 from .const import ( # noqa: F401
ERR_HOME_ASSISTANT_ERROR, ERR_HOME_ASSISTANT_ERROR,
ERR_INVALID_FORMAT, ERR_INVALID_FORMAT,

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Hashable from collections.abc import Callable, Hashable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -17,6 +18,11 @@ if TYPE_CHECKING:
from .http import WebSocketAdapter from .http import WebSocketAdapter
current_connection = ContextVar["ActiveConnection | None"](
"current_connection", default=None
)
class ActiveConnection: class ActiveConnection:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
@ -36,6 +42,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
current_connection.set(self)
def context(self, msg: dict[str, Any]) -> Context: def context(self, msg: dict[str, Any]) -> Context:
"""Return a context.""" """Return a context."""

View File

@ -484,8 +484,6 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token):
assert await async_setup_component(hass, "auth", {"http": {}}) assert await async_setup_component(hass, "auth", {"http": {}})
ws_client = await hass_ws_client(hass, hass_access_token) ws_client = await hass_ws_client(hass, hass_access_token)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
with patch( with patch(
"homeassistant.components.auth.async_sign_path", return_value="hello_world" "homeassistant.components.auth.async_sign_path", return_value="hello_world"
) as mock_sign: ) as mock_sign:
@ -502,7 +500,6 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token):
assert result["success"], result assert result["success"], result
assert result["result"] == {"path": "hello_world"} assert result["result"] == {"path": "hello_world"}
assert len(mock_sign.mock_calls) == 1 assert len(mock_sign.mock_calls) == 1
hass, p_refresh_token, path, expires = mock_sign.mock_calls[0][1] hass, path, expires = mock_sign.mock_calls[0][1]
assert p_refresh_token == refresh_token.id
assert path == "/api/hello" assert path == "/api/hello"
assert expires.total_seconds() == 20 assert expires.total_seconds() == 20

View File

@ -59,7 +59,7 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
result = await client.receive_json() result = await client.receive_json()
assert result["success"], result assert result["success"], result
data = result["result"] data = result["result"]
assert len(data) == 4 assert len(data) == 5
assert data[0] == { assert data[0] == {
"id": hass_admin_user.id, "id": hass_admin_user.id,
"username": "admin", "username": "admin",
@ -151,7 +151,7 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
test_user = MockUser(id="efg").add_to_hass(hass) test_user = MockUser(id="efg").add_to_hass(hass)
assert len(await hass.auth.async_get_users()) == 2 cur_users = len(await hass.auth.async_get_users())
await client.send_json( await client.send_json(
{"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": test_user.id} {"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": test_user.id}
@ -159,20 +159,20 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
result = await client.receive_json() result = await client.receive_json()
assert result["success"], result assert result["success"], result
assert len(await hass.auth.async_get_users()) == 1 assert len(await hass.auth.async_get_users()) == cur_users - 1
async def test_create(hass, hass_ws_client, hass_access_token): async def test_create(hass, hass_ws_client, hass_access_token):
"""Test create command works.""" """Test create command works."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
assert len(await hass.auth.async_get_users()) == 1 cur_users = len(await hass.auth.async_get_users())
await client.send_json({"id": 5, "type": "config/auth/create", "name": "Paulus"}) await client.send_json({"id": 5, "type": "config/auth/create", "name": "Paulus"})
result = await client.receive_json() result = await client.receive_json()
assert result["success"], result assert result["success"], result
assert len(await hass.auth.async_get_users()) == 2 assert len(await hass.auth.async_get_users()) == cur_users + 1
data_user = result["result"]["user"] data_user = result["result"]["user"]
user = await hass.auth.async_get_user(data_user["id"]) user = await hass.auth.async_get_user(data_user["id"])
assert user is not None assert user is not None
@ -188,7 +188,7 @@ async def test_create_user_group(hass, hass_ws_client, hass_access_token):
"""Test create user with a group.""" """Test create user with a group."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
assert len(await hass.auth.async_get_users()) == 1 cur_users = len(await hass.auth.async_get_users())
await client.send_json( await client.send_json(
{ {
@ -201,7 +201,7 @@ async def test_create_user_group(hass, hass_ws_client, hass_access_token):
result = await client.receive_json() result = await client.receive_json()
assert result["success"], result assert result["success"], result
assert len(await hass.auth.async_get_users()) == 2 assert len(await hass.auth.async_get_users()) == cur_users + 1
data_user = result["result"]["user"] data_user = result["result"]["user"]
user = await hass.auth.async_get_user(data_user["id"]) user = await hass.auth.async_get_user(data_user["id"])
assert user is not None assert user is not None

View File

@ -6,16 +6,29 @@ from unittest.mock import Mock, patch
from aiohttp import BasicAuth, web from aiohttp import BasicAuth, web
from aiohttp.web_exceptions import HTTPUnauthorized from aiohttp.web_exceptions import HTTPUnauthorized
import jwt
import pytest import pytest
import yarl
from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.auth.providers import trusted_networks from homeassistant.auth.providers import trusted_networks
from homeassistant.components import websocket_api
from homeassistant.components.http.auth import ( from homeassistant.components.http.auth import (
CONTENT_USER_NAME,
DATA_SIGN_SECRET,
STORAGE_KEY,
async_setup_auth,
async_sign_path, async_sign_path,
async_user_not_allowed_do_auth, async_user_not_allowed_do_auth,
setup_auth,
) )
from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.forwarded import async_setup_forwarded from homeassistant.components.http.forwarded import async_setup_forwarded
from homeassistant.components.http.request_context import (
current_request,
setup_request_context,
)
from homeassistant.core import callback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import HTTP_HEADER_HA_AUTH, mock_real_ip from . import HTTP_HEADER_HA_AUTH, mock_real_ip
@ -86,7 +99,7 @@ def trusted_networks_auth(hass):
async def test_auth_middleware_loaded_by_default(hass): async def test_auth_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch("homeassistant.components.http.setup_auth") as mock_setup: with patch("homeassistant.components.http.async_setup_auth") as mock_setup:
await async_setup_component(hass, "http", {"http": {}}) await async_setup_component(hass, "http", {"http": {}})
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
@ -96,7 +109,7 @@ async def test_cant_access_with_password_in_header(
app, aiohttp_client, legacy_auth, hass app, aiohttp_client, legacy_auth, hass
): ):
"""Test access with password in header.""" """Test access with password in header."""
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@ -110,7 +123,7 @@ async def test_cant_access_with_password_in_query(
app, aiohttp_client, legacy_auth, hass app, aiohttp_client, legacy_auth, hass
): ):
"""Test access with password in URL.""" """Test access with password in URL."""
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
resp = await client.get("/", params={"api_password": API_PASSWORD}) resp = await client.get("/", params={"api_password": API_PASSWORD})
@ -125,7 +138,7 @@ async def test_cant_access_with_password_in_query(
async def test_basic_auth_does_not_work(app, aiohttp_client, hass, legacy_auth): async def test_basic_auth_does_not_work(app, aiohttp_client, hass, legacy_auth):
"""Test access with basic authentication.""" """Test access with basic authentication."""
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD)) req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD))
@ -145,7 +158,7 @@ async def test_cannot_access_with_trusted_ip(
hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user
): ):
"""Test access with an untrusted ip address.""" """Test access with an untrusted ip address."""
setup_auth(hass, app2) await async_setup_auth(hass, app2)
set_mock_ip = mock_real_ip(app2) set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2) client = await aiohttp_client(app2)
@ -170,7 +183,7 @@ async def test_auth_active_access_with_access_token_in_header(
): ):
"""Test access with access token in header.""" """Test access with access token in header."""
token = hass_access_token token = hass_access_token
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
@ -202,7 +215,7 @@ async def test_auth_active_access_with_trusted_ip(
hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user hass, app2, trusted_networks_auth, aiohttp_client, hass_owner_user
): ):
"""Test access with an untrusted ip address.""" """Test access with an untrusted ip address."""
setup_auth(hass, app2) await async_setup_auth(hass, app2)
set_mock_ip = mock_real_ip(app2) set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2) client = await aiohttp_client(app2)
@ -226,7 +239,7 @@ async def test_auth_legacy_support_api_password_cannot_access(
app, aiohttp_client, legacy_auth, hass app, aiohttp_client, legacy_auth, hass
): ):
"""Test access using api_password if auth.support_legacy.""" """Test access using api_password if auth.support_legacy."""
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@ -239,16 +252,20 @@ async def test_auth_legacy_support_api_password_cannot_access(
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_token): async def test_auth_access_signed_path_with_refresh_token(
hass, app, aiohttp_client, hass_access_token
):
"""Test access with signed url.""" """Test access with signed url."""
app.router.add_post("/", mock_handler) app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler) app.router.add_get("/another_path", mock_handler)
setup_auth(hass, app) await async_setup_auth(hass, app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
signed_path = async_sign_path(hass, refresh_token.id, "/", timedelta(seconds=5)) signed_path = async_sign_path(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
)
req = await client.get(signed_path) req = await client.get(signed_path)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
@ -265,7 +282,7 @@ async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_to
# Never valid as expired in the past. # Never valid as expired in the past.
expired_signed_path = async_sign_path( expired_signed_path = async_sign_path(
hass, refresh_token.id, "/", timedelta(seconds=-5) hass, "/", timedelta(seconds=-5), refresh_token_id=refresh_token.id
) )
req = await client.get(expired_signed_path) req = await client.get(expired_signed_path)
@ -277,10 +294,94 @@ async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_to
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
async def test_auth_access_signed_path_via_websocket(
hass, app, hass_ws_client, hass_read_only_access_token
):
"""Test signed url via websockets uses connection user."""
@websocket_api.websocket_command({"type": "diagnostics/list"})
@callback
def get_signed_path(hass, connection, msg):
connection.send_result(
msg["id"], {"path": async_sign_path(hass, "/", timedelta(seconds=5))}
)
websocket_api.async_register_command(hass, get_signed_path)
# We use hass_read_only_access_token to make sure the connection WS is used.
client = await hass_ws_client(access_token=hass_read_only_access_token)
await client.send_json({"id": 5, "type": "diagnostics/list"})
msg = await client.receive_json()
assert msg["id"] == 5
assert msg["success"]
refresh_token = await hass.auth.async_validate_access_token(
hass_read_only_access_token
)
signature = yarl.URL(msg["result"]["path"]).query["authSig"]
claims = jwt.decode(
signature,
hass.data[DATA_SIGN_SECRET],
algorithms=["HS256"],
options={"verify_signature": False},
)
assert claims["iss"] == refresh_token.id
async def test_auth_access_signed_path_with_http(
hass, app, aiohttp_client, hass_access_token
):
"""Test signed url via HTTP uses HTTP user."""
setup_request_context(app, current_request)
async def mock_handler(request):
"""Return signed path."""
return web.json_response(
data={"path": async_sign_path(hass, "/", timedelta(seconds=-5))}
)
app.router.add_get("/hello", mock_handler)
await async_setup_auth(hass, app)
client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
req = await client.get(
"/hello", headers={"Authorization": f"Bearer {hass_access_token}"}
)
assert req.status == HTTPStatus.OK
data = await req.json()
signature = yarl.URL(data["path"]).query["authSig"]
claims = jwt.decode(
signature,
hass.data[DATA_SIGN_SECRET],
algorithms=["HS256"],
options={"verify_signature": False},
)
assert claims["iss"] == refresh_token.id
async def test_auth_access_signed_path_with_content_user(hass, app, aiohttp_client):
"""Test access signed url uses content user."""
await async_setup_auth(hass, app)
signed_path = async_sign_path(hass, "/", timedelta(seconds=5))
signature = yarl.URL(signed_path).query["authSig"]
claims = jwt.decode(
signature,
hass.data[DATA_SIGN_SECRET],
algorithms=["HS256"],
options={"verify_signature": False},
)
assert claims["iss"] == hass.data[STORAGE_KEY]
async def test_local_only_user_rejected(hass, app, aiohttp_client, hass_access_token): async def test_local_only_user_rejected(hass, app, aiohttp_client, hass_access_token):
"""Test access with access token in header.""" """Test access with access token in header."""
token = hass_access_token token = hass_access_token
setup_auth(hass, app) await async_setup_auth(hass, app)
set_mock_ip = mock_real_ip(app) set_mock_ip = mock_real_ip(app)
client = await aiohttp_client(app) client = await aiohttp_client(app)
refresh_token = await hass.auth.async_validate_access_token(hass_access_token) refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
@ -340,3 +441,25 @@ async def test_async_user_not_allowed_do_auth(hass, app):
async_user_not_allowed_do_auth(hass, user, trusted_request) async_user_not_allowed_do_auth(hass, user, trusted_request)
== "User is local only" == "User is local only"
) )
async def test_create_user_once(hass):
"""Test that we reuse the user."""
cur_users = len(await hass.auth.async_get_users())
app = web.Application()
await async_setup_auth(hass, app)
users = await hass.auth.async_get_users()
assert len(users) == cur_users + 1
user: User = next((user for user in users if user.name == CONTENT_USER_NAME), None)
assert user is not None, users
assert len(user.groups) == 1
assert user.groups[0].id == GROUP_ID_READ_ONLY
assert len(user.refresh_tokens) == 1
assert user.system_generated
await async_setup_auth(hass, app)
# test it did not create a user
assert len(await hass.auth.async_get_users()) == cur_users + 1

View File

@ -139,6 +139,7 @@ async def test_onboarding_user(hass, hass_storage, hass_client_no_auth):
assert await async_setup_component(hass, "onboarding", {}) assert await async_setup_component(hass, "onboarding", {})
await hass.async_block_till_done() await hass.async_block_till_done()
cur_users = len(await hass.auth.async_get_users())
client = await hass_client_no_auth() client = await hass_client_no_auth()
resp = await client.post( resp = await client.post(
@ -159,9 +160,9 @@ async def test_onboarding_user(hass, hass_storage, hass_client_no_auth):
assert "auth_code" in data assert "auth_code" in data
users = await hass.auth.async_get_users() users = await hass.auth.async_get_users()
assert len(users) == 1 assert len(await hass.auth.async_get_users()) == cur_users + 1
user = users[0] user = next((user for user in users if user.name == "Test Name"), None)
assert user.name == "Test Name" assert user is not None
assert len(user.credentials) == 1 assert len(user.credentials) == 1
assert user.credentials[0].data["username"] == "test-user" assert user.credentials[0].data["username"] == "test-user"
assert len(hass.data["person"][1].async_items()) == 1 assert len(hass.data["person"][1].async_items()) == 1
@ -287,7 +288,7 @@ async def test_onboarding_integration(hass, hass_storage, hass_client, hass_admi
) )
# Onboarding refresh token and new refresh token # Onboarding refresh token and new refresh token
for user in await hass.auth.async_get_users(): user = await hass.auth.async_get_user(hass_admin_user.id)
assert len(user.refresh_tokens) == 2, user assert len(user.refresh_tokens) == 2, user

View File

@ -503,22 +503,15 @@ def hass_ws_client(aiohttp_client, hass_access_token, hass, socket_enabled):
async def create_client(hass=hass, access_token=hass_access_token): async def create_client(hass=hass, access_token=hass_access_token):
"""Create a websocket client.""" """Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {}) assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL) websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json() auth_resp = await websocket.receive_json()
assert auth_resp["type"] == TYPE_AUTH_REQUIRED assert auth_resp["type"] == TYPE_AUTH_REQUIRED
if access_token is None: if access_token is None:
await websocket.send_json( await websocket.send_json({"type": TYPE_AUTH, "access_token": "incorrect"})
{"type": TYPE_AUTH, "access_token": "incorrect"}
)
else: else:
await websocket.send_json( await websocket.send_json({"type": TYPE_AUTH, "access_token": access_token})
{"type": TYPE_AUTH, "access_token": access_token}
)
auth_ok = await websocket.receive_json() auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK assert auth_ok["type"] == TYPE_AUTH_OK