From 348e1df949bae31fd1d3f805d5e892bba608e120 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Fri, 12 Apr 2024 14:47:46 +0200 Subject: [PATCH] Add strict connection (#112387) Co-authored-by: Martin Hjelmare --- homeassistant/auth/__init__.py | 8 +- homeassistant/auth/session.py | 205 +++++++++++ homeassistant/components/auth/__init__.py | 32 +- homeassistant/components/hassio/ingress.py | 1 - homeassistant/components/http/__init__.py | 93 ++++- homeassistant/components/http/auth.py | 94 ++++- homeassistant/components/http/const.py | 9 + homeassistant/components/http/icons.json | 5 + homeassistant/components/http/services.yaml | 1 + homeassistant/components/http/session.py | 160 +++++++++ .../http/strict_connection_static_page.html | 46 +++ homeassistant/components/http/strings.json | 16 + homeassistant/package_constraints.txt | 1 + pyproject.toml | 1 + requirements.txt | 1 + tests/components/api/test_init.py | 2 +- tests/components/http/test_auth.py | 339 ++++++++++++++++-- tests/components/http/test_init.py | 79 ++++ tests/components/http/test_session.py | 107 ++++++ tests/components/stream/conftest.py | 20 +- .../components/websocket_api/test_commands.py | 2 +- tests/helpers/test_service.py | 27 +- tests/scripts/test_check_config.py | 2 + 23 files changed, 1187 insertions(+), 64 deletions(-) create mode 100644 homeassistant/auth/session.py create mode 100644 homeassistant/components/http/icons.json create mode 100644 homeassistant/components/http/services.yaml create mode 100644 homeassistant/components/http/session.py create mode 100644 homeassistant/components/http/strict_connection_static_page.html create mode 100644 homeassistant/components/http/strings.json create mode 100644 tests/components/http/test_session.py diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 969fcc3529e..2a9525181f6 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -28,6 +28,7 @@ from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRA from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .models import AuthFlowResult from .providers import AuthProvider, LoginFlow, auth_provider_from_config +from .session import SessionManager EVENT_USER_ADDED = "user_added" EVENT_USER_UPDATED = "user_updated" @@ -85,7 +86,7 @@ async def auth_manager_from_config( module_hash[module.id] = module manager = AuthManager(hass, store, provider_hash, module_hash) - manager.async_setup() + await manager.async_setup() return manager @@ -180,9 +181,9 @@ class AuthManager: self._remove_expired_job = HassJob( self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback ) + self.session = SessionManager(hass, self) - @callback - def async_setup(self) -> None: + async def async_setup(self) -> None: """Set up the auth manager.""" hass = self.hass hass.async_add_shutdown_job( @@ -191,6 +192,7 @@ class AuthManager: ) ) self._async_track_next_refresh_token_expiration() + await self.session.async_setup() @property def auth_providers(self) -> list[AuthProvider]: diff --git a/homeassistant/auth/session.py b/homeassistant/auth/session.py new file mode 100644 index 00000000000..88297b50d90 --- /dev/null +++ b/homeassistant/auth/session.py @@ -0,0 +1,205 @@ +"""Session auth module.""" + +from __future__ import annotations + +from datetime import datetime, timedelta +import secrets +from typing import TYPE_CHECKING, Final, TypedDict + +from aiohttp.web import Request +from aiohttp_session import Session, get_session, new_session +from cryptography.fernet import Fernet + +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback +from homeassistant.helpers.event import async_call_later +from homeassistant.helpers.storage import Store +from homeassistant.util import dt as dt_util + +from .models import RefreshToken + +if TYPE_CHECKING: + from . import AuthManager + + +TEMP_TIMEOUT = timedelta(minutes=5) +TEMP_TIMEOUT_SECONDS = TEMP_TIMEOUT.total_seconds() + +SESSION_ID = "id" +STORAGE_VERSION = 1 +STORAGE_KEY = "auth.session" + + +class StrictConnectionTempSessionData: + """Data for accessing unauthorized resources for a short period of time.""" + + __slots__ = ("cancel_remove", "absolute_expiry") + + def __init__(self, cancel_remove: CALLBACK_TYPE) -> None: + """Initialize the temp session data.""" + self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove + self.absolute_expiry: Final[datetime] = dt_util.utcnow() + TEMP_TIMEOUT + + +class StoreData(TypedDict): + """Data to store.""" + + unauthorized_sessions: dict[str, str] + key: str + + +class SessionManager: + """Session manager.""" + + def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None: + """Initialize the strict connection manager.""" + self._auth = auth + self._hass = hass + self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {} + self._strict_connection_sessions: dict[str, str] = {} + self._store = Store[StoreData]( + hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True + ) + self._key: str | None = None + self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {} + + @property + def key(self) -> str: + """Return the encryption key.""" + if self._key is None: + self._key = Fernet.generate_key().decode() + self._async_schedule_save() + return self._key + + async def async_validate_request_for_strict_connection_session( + self, + request: Request, + ) -> bool: + """Check if a request has a valid strict connection session.""" + session = await get_session(request) + if session.new or session.empty: + return False + result = self.async_validate_strict_connection_session(session) + if result is False: + session.invalidate() + return result + + @callback + def async_validate_strict_connection_session( + self, + session: Session, + ) -> bool: + """Validate a strict connection session.""" + if not (session_id := session.get(SESSION_ID)): + return False + + if token_id := self._strict_connection_sessions.get(session_id): + if self._auth.async_get_refresh_token(token_id): + return True + # refresh token is invalid, delete entry + self._strict_connection_sessions.pop(session_id) + self._async_schedule_save() + + if data := self._temp_sessions.get(session_id): + if dt_util.utcnow() <= data.absolute_expiry: + return True + # session expired, delete entry + self._temp_sessions.pop(session_id).cancel_remove() + + return False + + @callback + def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None: + """Register a callback to revoke all sessions for a refresh token.""" + if refresh_token_id in self._refresh_token_revoke_callbacks: + return + + @callback + def async_invalidate_auth_sessions() -> None: + """Invalidate all sessions for a refresh token.""" + self._strict_connection_sessions = { + session_id: token_id + for session_id, token_id in self._strict_connection_sessions.items() + if token_id != refresh_token_id + } + self._async_schedule_save() + + self._refresh_token_revoke_callbacks[refresh_token_id] = ( + self._auth.async_register_revoke_token_callback( + refresh_token_id, async_invalidate_auth_sessions + ) + ) + + async def async_create_session( + self, + request: Request, + refresh_token: RefreshToken, + ) -> None: + """Create new session for given refresh token. + + Caller needs to make sure that the refresh token is valid. + By creating a session, we are implicitly revoking all other + sessions for the given refresh token as there is one refresh + token per device/user case. + """ + self._strict_connection_sessions = { + session_id: token_id + for session_id, token_id in self._strict_connection_sessions.items() + if token_id != refresh_token.id + } + + self._async_register_revoke_token_callback(refresh_token.id) + session_id = await self._async_create_new_session(request) + self._strict_connection_sessions[session_id] = refresh_token.id + self._async_schedule_save() + + async def async_create_temp_unauthorized_session(self, request: Request) -> None: + """Create a temporary unauthorized session.""" + session_id = await self._async_create_new_session( + request, max_age=int(TEMP_TIMEOUT_SECONDS) + ) + + @callback + def remove(_: datetime) -> None: + self._temp_sessions.pop(session_id, None) + + self._temp_sessions[session_id] = StrictConnectionTempSessionData( + async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove) + ) + + async def _async_create_new_session( + self, + request: Request, + *, + max_age: int | None = None, + ) -> str: + session_id = secrets.token_hex(64) + + session = await new_session(request) + session[SESSION_ID] = session_id + if max_age is not None: + session.max_age = max_age + return session_id + + @callback + def _async_schedule_save(self, delay: float = 1) -> None: + """Save sessions.""" + self._store.async_delay_save(self._data_to_save, delay) + + @callback + def _data_to_save(self) -> StoreData: + """Return the data to store.""" + return StoreData( + unauthorized_sessions=self._strict_connection_sessions, + key=self.key, + ) + + async def async_setup(self) -> None: + """Set up session manager.""" + data = await self._store.async_load() + if data is None: + return + + self._key = data["key"] + self._strict_connection_sessions = data["unauthorized_sessions"] + for token_id in self._strict_connection_sessions.values(): + self._async_register_revoke_token_callback(token_id) diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index ff54971eb64..3d825cd99b5 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -162,6 +162,7 @@ from homeassistant.util import dt as dt_util from . import indieauth, login_flow, mfa_setup_flow DOMAIN = "auth" +STRICT_CONNECTION_URL = "/auth/strict_connection/temp_token" StoreResultType = Callable[[str, Credentials], str] RetrieveResultType = Callable[[str, str], Credentials | None] @@ -187,6 +188,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.http.register_view(RevokeTokenView()) hass.http.register_view(LinkUserView(retrieve_result)) hass.http.register_view(OAuth2AuthorizeCallbackView()) + hass.http.register_view(StrictConnectionTempTokenView()) websocket_api.async_register_command(hass, websocket_current_user) websocket_api.async_register_command(hass, websocket_create_long_lived_access_token) @@ -260,10 +262,10 @@ class TokenView(HomeAssistantView): return await RevokeTokenView.post(self, request) # type: ignore[arg-type] if grant_type == "authorization_code": - return await self._async_handle_auth_code(hass, data, request.remote) + return await self._async_handle_auth_code(hass, data, request) if grant_type == "refresh_token": - return await self._async_handle_refresh_token(hass, data, request.remote) + return await self._async_handle_refresh_token(hass, data, request) return self.json( {"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST @@ -273,7 +275,7 @@ class TokenView(HomeAssistantView): self, hass: HomeAssistant, data: MultiDictProxy[str], - remote_addr: str | None, + request: web.Request, ) -> web.Response: """Handle authorization code request.""" client_id = data.get("client_id") @@ -313,7 +315,7 @@ class TokenView(HomeAssistantView): ) try: access_token = hass.auth.async_create_access_token( - refresh_token, remote_addr + refresh_token, request.remote ) except InvalidAuthError as exc: return self.json( @@ -321,6 +323,7 @@ class TokenView(HomeAssistantView): status_code=HTTPStatus.FORBIDDEN, ) + await hass.auth.session.async_create_session(request, refresh_token) return self.json( { "access_token": access_token, @@ -341,9 +344,9 @@ class TokenView(HomeAssistantView): self, hass: HomeAssistant, data: MultiDictProxy[str], - remote_addr: str | None, + request: web.Request, ) -> web.Response: - """Handle authorization code request.""" + """Handle refresh token request.""" client_id = data.get("client_id") if client_id is not None and not indieauth.verify_client_id(client_id): return self.json( @@ -381,7 +384,7 @@ class TokenView(HomeAssistantView): try: access_token = hass.auth.async_create_access_token( - refresh_token, remote_addr + refresh_token, request.remote ) except InvalidAuthError as exc: return self.json( @@ -389,6 +392,7 @@ class TokenView(HomeAssistantView): status_code=HTTPStatus.FORBIDDEN, ) + await hass.auth.session.async_create_session(request, refresh_token) return self.json( { "access_token": access_token, @@ -437,6 +441,20 @@ class LinkUserView(HomeAssistantView): return self.json_message("User linked") +class StrictConnectionTempTokenView(HomeAssistantView): + """View to get temporary strict connection token.""" + + url = STRICT_CONNECTION_URL + name = "api:auth:strict_connection:temp_token" + requires_auth = False + + async def get(self, request: web.Request) -> web.Response: + """Get a temporary token and redirect to main page.""" + hass = request.app[KEY_HASS] + await hass.auth.session.async_create_temp_unauthorized_session(request) + raise web.HTTPSeeOther(location="/") + + @callback def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]: """Create an in memory store.""" diff --git a/homeassistant/components/hassio/ingress.py b/homeassistant/components/hassio/ingress.py index 6d6faa6fe75..ed6e47145dd 100644 --- a/homeassistant/components/hassio/ingress.py +++ b/homeassistant/components/hassio/ingress.py @@ -197,7 +197,6 @@ class HassIOIngress(HomeAssistantView): content_type or simple_response.content_type ): simple_response.enable_compression() - await simple_response.prepare(request) return simple_response # Stream response diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index e89031cb265..3e5f7333cbc 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -10,7 +10,8 @@ import os import socket import ssl from tempfile import NamedTemporaryFile -from typing import Any, Final, TypedDict, cast +from typing import Any, Final, Required, TypedDict, cast +from urllib.parse import quote_plus, urljoin from aiohttp import web from aiohttp.abc import AbstractStreamWriter @@ -30,8 +31,20 @@ from yarl import URL from homeassistant.components.network import async_get_source_ip from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT -from homeassistant.core import Event, HomeAssistant -from homeassistant.exceptions import HomeAssistantError +from homeassistant.core import ( + Event, + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, + callback, +) +from homeassistant.exceptions import ( + HomeAssistantError, + ServiceValidationError, + Unauthorized, + UnknownUser, +) from homeassistant.helpers import storage import homeassistant.helpers.config_validation as cv from homeassistant.helpers.http import ( @@ -53,9 +66,13 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util from homeassistant.util.async_ import create_eager_task from homeassistant.util.json import json_loads -from .auth import async_setup_auth +from .auth import async_setup_auth, async_sign_path from .ban import setup_bans -from .const import KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401 +from .const import ( # noqa: F401 + KEY_HASS_REFRESH_TOKEN_ID, + KEY_HASS_USER, + StrictConnectionMode, +) from .cors import setup_cors from .decorators import require_admin # noqa: F401 from .forwarded import async_setup_forwarded @@ -80,6 +97,7 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies" CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold" CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled" CONF_SSL_PROFILE: Final = "ssl_profile" +CONF_STRICT_CONNECTION: Final = "strict_connection" SSL_MODERN: Final = "modern" SSL_INTERMEDIATE: Final = "intermediate" @@ -129,6 +147,9 @@ HTTP_SCHEMA: Final = vol.All( [SSL_INTERMEDIATE, SSL_MODERN] ), vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean, + vol.Optional( + CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED + ): vol.In([e.value for e in StrictConnectionMode]), } ), ) @@ -152,6 +173,7 @@ class ConfData(TypedDict, total=False): login_attempts_threshold: int ip_ban_enabled: bool ssl_profile: str + strict_connection: Required[StrictConnectionMode] @bind_hass @@ -218,6 +240,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: login_threshold=login_threshold, is_ban_enabled=is_ban_enabled, use_x_frame_options=use_x_frame_options, + strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION], ) async def stop_server(event: Event) -> None: @@ -247,6 +270,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: local_ip, host, server_port, ssl_certificate is not None ) + _setup_services(hass, conf) return True @@ -331,6 +355,7 @@ class HomeAssistantHTTP: login_threshold: int, is_ban_enabled: bool, use_x_frame_options: bool, + strict_connection_non_cloud: StrictConnectionMode, ) -> None: """Initialize the server.""" self.app[KEY_HASS] = self.hass @@ -347,7 +372,7 @@ class HomeAssistantHTTP: if is_ban_enabled: setup_bans(self.hass, self.app, login_threshold) - await async_setup_auth(self.hass, self.app) + await async_setup_auth(self.hass, self.app, strict_connection_non_cloud) setup_headers(self.app, use_x_frame_options) setup_cors(self.app, cors_origins) @@ -577,3 +602,59 @@ async def start_http_server_and_save_config( ] store.async_delay_save(lambda: conf, SAVE_DELAY) + + +@callback +def _setup_services(hass: HomeAssistant, conf: ConfData) -> None: + """Set up services for HTTP component.""" + + async def create_temporary_strict_connection_url( + call: ServiceCall, + ) -> ServiceResponse: + """Create a strict connection url and return it.""" + # Copied form homeassistant/helpers/service.py#_async_admin_handler + # as the helper supports no responses yet + if call.context.user_id: + user = await hass.auth.async_get_user(call.context.user_id) + if user is None: + raise UnknownUser(context=call.context) + if not user.is_admin: + raise Unauthorized(context=call.context) + + if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="strict_connection_not_enabled_non_cloud", + ) + + try: + url = get_url(hass, prefer_external=True, allow_internal=False) + except NoURLAvailableError as ex: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="no_external_url_available", + ) from ex + + # to avoid circular import + # pylint: disable-next=import-outside-toplevel + from homeassistant.components.auth import STRICT_CONNECTION_URL + + path = async_sign_path( + hass, + STRICT_CONNECTION_URL, + datetime.timedelta(hours=1), + use_content_user=True, + ) + url = urljoin(url, path) + + return { + "url": f"https://login.home-assistant.io?u={quote_plus(url)}", + "direct_url": url, + } + + hass.services.async_register( + DOMAIN, + "create_temporary_strict_connection_url", + create_temporary_strict_connection_url, + supports_response=SupportsResponse.ONLY, + ) diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 2073c998384..1eb74289089 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -4,14 +4,18 @@ from __future__ import annotations from collections.abc import Awaitable, Callable from datetime import timedelta +from http import HTTPStatus from ipaddress import ip_address import logging +import os import secrets import time from typing import Any, Final from aiohttp import hdrs -from aiohttp.web import Application, Request, StreamResponse, middleware +from aiohttp.web import Application, Request, Response, StreamResponse, middleware +from aiohttp.web_exceptions import HTTPBadRequest +from aiohttp_session import session_middleware import jwt from jwt import api_jws from yarl import URL @@ -27,7 +31,13 @@ from homeassistant.helpers.network import is_cloud_connection from homeassistant.helpers.storage import Store from homeassistant.util.network import is_local -from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER +from .const import ( + KEY_AUTHENTICATED, + KEY_HASS_REFRESH_TOKEN_ID, + KEY_HASS_USER, + StrictConnectionMode, +) +from .session import HomeAssistantCookieStorage _LOGGER = logging.getLogger(__name__) @@ -39,6 +49,10 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"] STORAGE_VERSION = 1 STORAGE_KEY = "http.auth" CONTENT_USER_NAME = "Home Assistant Content" +STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/" +STRICT_CONNECTION_STATIC_PAGE = os.path.join( + os.path.dirname(__file__), "strict_connection_static_page.html" +) @callback @@ -48,13 +62,16 @@ def async_sign_path( expiration: timedelta, *, refresh_token_id: str | None = None, + use_content_user: bool = False, ) -> str: """Sign a path for temporary access without auth header.""" if (secret := hass.data.get(DATA_SIGN_SECRET)) is None: secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex() if refresh_token_id is None: - if connection := websocket_api.current_connection.get(): + if use_content_user: + refresh_token_id = hass.data[STORAGE_KEY] + elif connection := websocket_api.current_connection.get(): refresh_token_id = connection.refresh_token_id elif ( request := current_request.get() @@ -114,7 +131,11 @@ def async_user_not_allowed_do_auth( return "User cannot authenticate remotely" -async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: +async def async_setup_auth( + hass: HomeAssistant, + app: Application, + strict_connection_mode_non_cloud: StrictConnectionMode, +) -> None: """Create auth middleware for the app.""" store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) if (data := await store.async_load()) is None: @@ -135,6 +156,16 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: await store.async_save(data) hass.data[STORAGE_KEY] = refresh_token.id + strict_connection_static_file_content = None + if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE: + + def read_static_page() -> str: + with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file: + return file.read() + + strict_connection_static_file_content = await hass.async_add_executor_job( + read_static_page + ) @callback def async_validate_auth_header(request: Request) -> bool: @@ -224,6 +255,22 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: authenticated = True auth_type = "signed request" + if ( + not authenticated + and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED + and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH) + and not await hass.auth.session.async_validate_request_for_strict_connection_session( + request + ) + and ( + resp := _async_perform_action_on_non_local( + request, strict_connection_static_file_content + ) + ) + is not None + ): + return resp + if authenticated and _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug( "Authenticated %s for %s using %s", @@ -235,4 +282,43 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: request[KEY_AUTHENTICATED] = authenticated return await handler(request) + app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass))) app.middlewares.append(auth_middleware) + + +@callback +def _async_perform_action_on_non_local( + request: Request, + strict_connection_static_file_content: str | None, +) -> StreamResponse | None: + """Perform strict connection mode action if the request is not local. + + The function does the following: + - Try to get the IP address of the request. If it fails, assume it's not local + - If the request is local, return None (allow the request to continue) + - If strict_connection_static_file_content is set, return a response with the content + - Otherwise close the connection and raise an exception + """ + try: + ip_address_ = ip_address(request.remote) # type: ignore[arg-type] + except ValueError: + _LOGGER.debug("Invalid IP address: %s", request.remote) + ip_address_ = None + + if ip_address_ and is_local(ip_address_): + return None + + _LOGGER.debug("Perform strict connection action for %s", ip_address_) + if strict_connection_static_file_content: + return Response( + text=strict_connection_static_file_content, + content_type="text/html", + status=HTTPStatus.IM_A_TEAPOT, + ) + + if transport := request.transport: + # it should never happen that we don't have a transport + transport.close() + + # We need to raise an exception to stop processing the request + raise HTTPBadRequest diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py index 1254744f258..d02416c531b 100644 --- a/homeassistant/components/http/const.py +++ b/homeassistant/components/http/const.py @@ -1,8 +1,17 @@ """HTTP specific constants.""" +from enum import StrEnum from typing import Final from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401 KEY_HASS_USER: Final = "hass_user" KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id" + + +class StrictConnectionMode(StrEnum): + """Enum for strict connection mode.""" + + DISABLED = "disabled" + STATIC_PAGE = "static_page" + DROP_CONNECTION = "drop_connection" diff --git a/homeassistant/components/http/icons.json b/homeassistant/components/http/icons.json new file mode 100644 index 00000000000..8e8b6285db7 --- /dev/null +++ b/homeassistant/components/http/icons.json @@ -0,0 +1,5 @@ +{ + "services": { + "create_temporary_strict_connection_url": "mdi:login-variant" + } +} diff --git a/homeassistant/components/http/services.yaml b/homeassistant/components/http/services.yaml new file mode 100644 index 00000000000..16b0debb144 --- /dev/null +++ b/homeassistant/components/http/services.yaml @@ -0,0 +1 @@ +create_temporary_strict_connection_url: ~ diff --git a/homeassistant/components/http/session.py b/homeassistant/components/http/session.py new file mode 100644 index 00000000000..81668ec2ccc --- /dev/null +++ b/homeassistant/components/http/session.py @@ -0,0 +1,160 @@ +"""Session http module.""" + +from functools import lru_cache +import logging + +from aiohttp.web import Request, StreamResponse +from aiohttp_session import Session, SessionData +from aiohttp_session.cookie_storage import EncryptedCookieStorage +from cryptography.fernet import InvalidToken + +from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION +from homeassistant.core import HomeAssistant +from homeassistant.helpers.json import json_dumps +from homeassistant.helpers.network import is_cloud_connection +from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads + +from .ban import process_wrong_login + +_LOGGER = logging.getLogger(__name__) + +COOKIE_NAME = "SC" +PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}" +SESSION_CACHE_SIZE = 16 + + +def _get_cookie_name(is_secure: bool) -> str: + """Return the cookie name.""" + return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME + + +class HomeAssistantCookieStorage(EncryptedCookieStorage): + """Home Assistant cookie storage. + + Own class is required: + - to set the secure flag based on the connection type + - to use a LRU cache for session decryption + """ + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the cookie storage.""" + super().__init__( + hass.auth.session.key, + cookie_name=PREFIXED_COOKIE_NAME, + max_age=int(REFRESH_TOKEN_EXPIRATION), + httponly=True, + samesite="Lax", + secure=True, + encoder=json_dumps, + decoder=json_loads, + ) + self._hass = hass + + def _secure_connection(self, request: Request) -> bool: + """Return if the connection is secure (https).""" + return is_cloud_connection(self._hass) or request.secure + + def load_cookie(self, request: Request) -> str | None: + """Load cookie.""" + is_secure = self._secure_connection(request) + cookie_name = _get_cookie_name(is_secure) + return request.cookies.get(cookie_name) + + @lru_cache(maxsize=SESSION_CACHE_SIZE) + def _decrypt_cookie(self, cookie: str) -> Session | None: + """Decrypt and validate cookie.""" + try: + data = SessionData( # type: ignore[misc] + self._decoder( + self._fernet.decrypt( + cookie.encode("utf-8"), ttl=self.max_age + ).decode("utf-8") + ) + ) + except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS): + _LOGGER.warning("Cannot decrypt/parse cookie value") + return None + + session = Session(None, data=data, new=data is None, max_age=self.max_age) + + # Validate session if not empty + if ( + not session.empty + and not self._hass.auth.session.async_validate_strict_connection_session( + session + ) + ): + # Invalidate session as it is not valid + session.invalidate() + + return session + + async def new_session(self) -> Session: + """Create a new session and mark it as changed.""" + session = Session(None, data=None, new=True, max_age=self.max_age) + session.changed() + return session + + async def load_session(self, request: Request) -> Session: + """Load session.""" + # Split parent function to use lru_cache + if (cookie := self.load_cookie(request)) is None: + return await self.new_session() + + if (session := self._decrypt_cookie(cookie)) is None: + # Decrypting/parsing failed, log wrong login and create a new session + await process_wrong_login(request) + session = await self.new_session() + + return session + + async def save_session( + self, request: Request, response: StreamResponse, session: Session + ) -> None: + """Save session.""" + + is_secure = self._secure_connection(request) + cookie_name = _get_cookie_name(is_secure) + + if session.empty: + response.del_cookie(cookie_name) + else: + params = self.cookie_params.copy() + params["secure"] = is_secure + params["max_age"] = session.max_age + + cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8") + response.set_cookie( + cookie_name, + self._fernet.encrypt(cookie_data).decode("utf-8"), + **params, + ) + # Add Cache-Control header to not cache the cookie as it + # is used for session management + self._add_cache_control_header(response) + + @staticmethod + def _add_cache_control_header(response: StreamResponse) -> None: + """Add/set cache control header to no-cache="Set-Cookie".""" + # Structure of the Cache-Control header defined in + # https://datatracker.ietf.org/doc/html/rfc2068#section-14.9 + if header := response.headers.get("Cache-Control"): + directives = [] + for directive in header.split(","): + directive = directive.strip() + directive_lowered = directive.lower() + if directive_lowered.startswith("no-cache"): + if "set-cookie" in directive_lowered or directive.find("=") == -1: + # Set-Cookie is already in the no-cache directive or + # the whole request should not be cached -> Nothing to do + return + + # Add Set-Cookie to the no-cache + # [:-1] to remove the " at the end of the directive + directive = f"{directive[:-1]}, Set-Cookie" + + directives.append(directive) + header = ", ".join(directives) + else: + header = 'no-cache="Set-Cookie"' + response.headers["Cache-Control"] = header diff --git a/homeassistant/components/http/strict_connection_static_page.html b/homeassistant/components/http/strict_connection_static_page.html new file mode 100644 index 00000000000..24049d9a0eb --- /dev/null +++ b/homeassistant/components/http/strict_connection_static_page.html @@ -0,0 +1,46 @@ + + + + + + I'm a Teapot + + + +
+

Error 418: I'm a Teapot

+

+ Oops! Looks like the server is taking a coffee break.
+ Don't worry, it'll be back to brewing your requests in no time! +

+

+
+ + diff --git a/homeassistant/components/http/strings.json b/homeassistant/components/http/strings.json new file mode 100644 index 00000000000..7cd64f5f297 --- /dev/null +++ b/homeassistant/components/http/strings.json @@ -0,0 +1,16 @@ +{ + "exceptions": { + "strict_connection_not_enabled_non_cloud": { + "message": "Strict connection is not enabled for non-cloud requests" + }, + "no_external_url_available": { + "message": "No external URL available" + } + }, + "services": { + "create_temporary_strict_connection_url": { + "name": "Create a temporary strict connection URL", + "description": "Create a temporary strict connection URL, which can be used to login on another device." + } + } +} diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 090271e028e..b253d600a2d 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -7,6 +7,7 @@ aiohttp-fast-url-dispatcher==0.3.0 aiohttp-zlib-ng==0.3.1 aiohttp==3.9.4 aiohttp_cors==0.7.0 +aiohttp_session==2.12.0 astral==2.2 async-interrupt==1.1.1 async-upnp-client==0.38.3 diff --git a/pyproject.toml b/pyproject.toml index 5ea335115ca..79a66cc7d82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "aiodns==3.2.0", "aiohttp==3.9.4", "aiohttp_cors==0.7.0", + "aiohttp_session==2.12.0", "aiohttp-fast-url-dispatcher==0.3.0", "aiohttp-zlib-ng==0.3.1", "astral==2.2", diff --git a/requirements.txt b/requirements.txt index 3cd1e8edfa5..f2f26f9bb54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ aiodns==3.2.0 aiohttp==3.9.4 aiohttp_cors==0.7.0 +aiohttp_session==2.12.0 aiohttp-fast-url-dispatcher==0.3.0 aiohttp-zlib-ng==0.3.1 astral==2.2 diff --git a/tests/components/api/test_init.py b/tests/components/api/test_init.py index 0ac2e5973fe..5443d48452f 100644 --- a/tests/components/api/test_init.py +++ b/tests/components/api/test_init.py @@ -306,7 +306,7 @@ async def test_api_get_services( for serv_domain in data: local = local_services.pop(serv_domain["domain"]) - assert serv_domain["services"] == local + assert serv_domain["services"].keys() == local.keys() async def test_api_call_service_no_data( diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index de6f323bc8a..f0f87e58173 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -1,22 +1,28 @@ """The tests for the Home Assistant HTTP component.""" +from collections.abc import Awaitable, Callable from datetime import timedelta from http import HTTPStatus from ipaddress import ip_network +import logging from unittest.mock import Mock, patch -from aiohttp import BasicAuth, web +from aiohttp import BasicAuth, ServerDisconnectedError, web +from aiohttp.test_utils import TestClient from aiohttp.web_exceptions import HTTPUnauthorized +from aiohttp_session import get_session import jwt import pytest import yarl +from yarl import URL from homeassistant.auth.const import GROUP_ID_READ_ONLY -from homeassistant.auth.models import User +from homeassistant.auth.models import RefreshToken, User from homeassistant.auth.providers import trusted_networks from homeassistant.auth.providers.legacy_api_password import ( LegacyApiPasswordAuthProvider, ) +from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT from homeassistant.components import websocket_api from homeassistant.components.http import KEY_HASS from homeassistant.components.http.auth import ( @@ -24,11 +30,12 @@ from homeassistant.components.http.auth import ( DATA_SIGN_SECRET, SIGN_QUERY_PARAM, STORAGE_KEY, + STRICT_CONNECTION_STATIC_PAGE, async_setup_auth, async_sign_path, async_user_not_allowed_do_auth, ) -from homeassistant.components.http.const import KEY_AUTHENTICATED +from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode from homeassistant.components.http.forwarded import async_setup_forwarded from homeassistant.components.http.request_context import ( current_request, @@ -36,13 +43,15 @@ from homeassistant.components.http.request_context import ( ) from homeassistant.core import HomeAssistant, callback from homeassistant.setup import async_setup_component +from homeassistant.util.dt import utcnow from . import HTTP_HEADER_HA_AUTH -from tests.common import MockUser +from tests.common import MockUser, async_fire_time_changed from tests.test_util import mock_real_ip from tests.typing import ClientSessionGenerator, WebSocketGenerator +_LOGGER = logging.getLogger(__name__) API_PASSWORD = "test-password" # Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases @@ -54,7 +63,13 @@ TRUSTED_NETWORKS = [ ] TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"] EXTERNAL_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1"] -UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, "127.0.0.1", "::1"] +LOCALHOST_ADDRESSES = ["127.0.0.1", "::1"] +UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, *LOCALHOST_ADDRESSES] +PRIVATE_ADDRESSES = [ + "192.168.10.10", + "172.16.4.20", + "10.100.50.5", +] async def mock_handler(request): @@ -122,7 +137,7 @@ async def test_cant_access_with_password_in_header( hass: HomeAssistant, ) -> None: """Test access with password in header.""" - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) @@ -139,7 +154,7 @@ async def test_cant_access_with_password_in_query( hass: HomeAssistant, ) -> None: """Test access with password in URL.""" - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) resp = await client.get("/", params={"api_password": API_PASSWORD}) @@ -159,7 +174,7 @@ async def test_basic_auth_does_not_work( legacy_auth: LegacyApiPasswordAuthProvider, ) -> None: """Test access with basic authentication.""" - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD)) @@ -183,7 +198,7 @@ async def test_cannot_access_with_trusted_ip( hass_owner_user: MockUser, ) -> None: """Test access with an untrusted ip address.""" - await async_setup_auth(hass, app2) + await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED) set_mock_ip = mock_real_ip(app2) client = await aiohttp_client(app2) @@ -211,7 +226,7 @@ async def test_auth_active_access_with_access_token_in_header( ) -> None: """Test access with access token in header.""" token = hass_access_token - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -247,7 +262,7 @@ async def test_auth_active_access_with_trusted_ip( hass_owner_user: MockUser, ) -> None: """Test access with an untrusted ip address.""" - await async_setup_auth(hass, app2) + await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED) set_mock_ip = mock_real_ip(app2) client = await aiohttp_client(app2) @@ -274,7 +289,7 @@ async def test_auth_legacy_support_api_password_cannot_access( hass: HomeAssistant, ) -> None: """Test access using api_password if auth.support_legacy.""" - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) @@ -296,7 +311,7 @@ async def test_auth_access_signed_path_with_refresh_token( """Test access with signed url.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -341,7 +356,7 @@ async def test_auth_access_signed_path_with_query_param( """Test access with signed url and query params.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -371,7 +386,7 @@ async def test_auth_access_signed_path_with_query_param_order( """Test access with signed url and query params different order.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -412,7 +427,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param( """Test access with signed url and changing a safe param.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -451,7 +466,7 @@ async def test_auth_access_signed_path_with_query_param_tamper( """Test access with signed url and query params that have been tampered with.""" app.router.add_post("/", mock_handler) app.router.add_get("/another_path", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -520,7 +535,7 @@ async def test_auth_access_signed_path_with_http( ) app.router.add_get("/hello", mock_handler) - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -544,7 +559,7 @@ async def test_auth_access_signed_path_with_content_user( hass: HomeAssistant, app, aiohttp_client: ClientSessionGenerator ) -> None: """Test access signed url uses content user.""" - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) signed_path = async_sign_path(hass, "/", timedelta(seconds=5)) signature = yarl.URL(signed_path).query["authSig"] claims = jwt.decode( @@ -564,7 +579,7 @@ async def test_local_only_user_rejected( ) -> None: """Test access with access token in header.""" token = hass_access_token - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) set_mock_ip = mock_real_ip(app) client = await aiohttp_client(app) refresh_token = hass.auth.async_validate_access_token(hass_access_token) @@ -630,7 +645,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None: """Test that we reuse the user.""" cur_users = len(await hass.auth.async_get_users()) app = web.Application() - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) users = await hass.auth.async_get_users() assert len(users) == cur_users + 1 @@ -642,7 +657,287 @@ async def test_create_user_once(hass: HomeAssistant) -> None: assert len(user.refresh_tokens) == 1 assert user.system_generated - await async_setup_auth(hass, app) + await async_setup_auth(hass, app, StrictConnectionMode.DISABLED) # test it did not create a user assert len(await hass.auth.async_get_users()) == cur_users + 1 + + +@pytest.fixture +def app_strict_connection(hass): + """Fixture to set up a web.Application.""" + + async def handler(request): + """Return if request was authenticated.""" + return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]}) + + app = web.Application() + app[KEY_HASS] = hass + app.router.add_get("/", handler) + async_setup_forwarded(app, True, []) + return app + + +@pytest.mark.parametrize( + "strict_connection_mode", [e.value for e in StrictConnectionMode] +) +async def test_strict_connection_non_cloud_authenticated_requests( + hass: HomeAssistant, + app_strict_connection: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + strict_connection_mode: StrictConnectionMode, +) -> None: + """Test authenticated requests with strict connection.""" + token = hass_access_token + await async_setup_auth(hass, app_strict_connection, strict_connection_mode) + set_mock_ip = mock_real_ip(app_strict_connection) + client = await aiohttp_client(app_strict_connection) + refresh_token = hass.auth.async_validate_access_token(hass_access_token) + assert refresh_token + assert hass.auth.session._strict_connection_sessions == {} + + signed_path = async_sign_path( + hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id + ) + + for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES, *EXTERNAL_ADDRESSES): + set_mock_ip(remote_addr) + + # authorized requests should work normally + req = await client.get("/", headers={"Authorization": f"Bearer {token}"}) + assert req.status == HTTPStatus.OK + assert await req.json() == {"authenticated": True} + req = await client.get(signed_path) + assert req.status == HTTPStatus.OK + assert await req.json() == {"authenticated": True} + + +@pytest.mark.parametrize( + "strict_connection_mode", [e.value for e in StrictConnectionMode] +) +async def test_strict_connection_non_cloud_local_unauthenticated_requests( + hass: HomeAssistant, + app_strict_connection: web.Application, + aiohttp_client: ClientSessionGenerator, + strict_connection_mode: StrictConnectionMode, +) -> None: + """Test local unauthenticated requests with strict connection.""" + await async_setup_auth(hass, app_strict_connection, strict_connection_mode) + set_mock_ip = mock_real_ip(app_strict_connection) + client = await aiohttp_client(app_strict_connection) + assert hass.auth.session._strict_connection_sessions == {} + + for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES): + set_mock_ip(remote_addr) + # local requests should work normally + req = await client.get("/") + assert req.status == HTTPStatus.OK + assert await req.json() == {"authenticated": False} + + +def _add_set_cookie_endpoint(app: web.Application, refresh_token: RefreshToken) -> None: + """Add an endpoint to set a cookie.""" + + async def set_cookie(request: web.Request) -> web.Response: + hass = request.app[KEY_HASS] + # Clear all sessions + hass.auth.session._temp_sessions.clear() + hass.auth.session._strict_connection_sessions.clear() + + if request.query["token"] == "refresh": + await hass.auth.session.async_create_session(request, refresh_token) + else: + await hass.auth.session.async_create_temp_unauthorized_session(request) + session = await get_session(request) + return web.Response(text=session[SESSION_ID]) + + app.router.add_get("/test/cookie", set_cookie) + + +async def _test_strict_connection_non_cloud_enabled_setup( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + strict_connection_mode: StrictConnectionMode, +) -> tuple[TestClient, Callable[[str], None], RefreshToken]: + """Test external unauthenticated requests with strict connection non cloud enabled.""" + refresh_token = hass.auth.async_validate_access_token(hass_access_token) + assert refresh_token + session = hass.auth.session + assert session._strict_connection_sessions == {} + assert session._temp_sessions == {} + + _add_set_cookie_endpoint(app, refresh_token) + await async_setup_auth(hass, app, strict_connection_mode) + set_mock_ip = mock_real_ip(app) + client = await aiohttp_client(app) + return (client, set_mock_ip, refresh_token) + + +async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + perform_unauthenticated_request: Callable[ + [HomeAssistant, TestClient], Awaitable[None] + ], + strict_connection_mode: StrictConnectionMode, +) -> None: + """Test external unauthenticated requests with strict connection non cloud enabled.""" + client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup( + hass, app, aiohttp_client, hass_access_token, strict_connection_mode + ) + + for remote_addr in EXTERNAL_ADDRESSES: + set_mock_ip(remote_addr) + await perform_unauthenticated_request(hass, client) + + +async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + perform_unauthenticated_request: Callable[ + [HomeAssistant, TestClient], Awaitable[None] + ], + strict_connection_mode: StrictConnectionMode, +) -> None: + """Test external unauthenticated requests with strict connection non cloud enabled and refresh token cookie.""" + ( + client, + set_mock_ip, + refresh_token, + ) = await _test_strict_connection_non_cloud_enabled_setup( + hass, app, aiohttp_client, hass_access_token, strict_connection_mode + ) + session = hass.auth.session + + # set strict connection cookie with refresh token + set_mock_ip(LOCALHOST_ADDRESSES[0]) + session_id = await (await client.get("/test/cookie?token=refresh")).text() + assert session._strict_connection_sessions == {session_id: refresh_token.id} + for remote_addr in EXTERNAL_ADDRESSES: + set_mock_ip(remote_addr) + req = await client.get("/") + assert req.status == HTTPStatus.OK + assert await req.json() == {"authenticated": False} + + # Invalidate refresh token, which should also invalidate session + hass.auth.async_remove_refresh_token(refresh_token) + assert session._strict_connection_sessions == {} + for remote_addr in EXTERNAL_ADDRESSES: + set_mock_ip(remote_addr) + await perform_unauthenticated_request(hass, client) + + +async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + perform_unauthenticated_request: Callable[ + [HomeAssistant, TestClient], Awaitable[None] + ], + strict_connection_mode: StrictConnectionMode, +) -> None: + """Test external unauthenticated requests with strict connection non cloud enabled and temp cookie.""" + client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup( + hass, app, aiohttp_client, hass_access_token, strict_connection_mode + ) + session = hass.auth.session + + # set strict connection cookie with temp session + assert session._temp_sessions == {} + set_mock_ip(LOCALHOST_ADDRESSES[0]) + session_id = await (await client.get("/test/cookie?token=temp")).text() + assert client.session.cookie_jar.filter_cookies(URL("http://127.0.0.1")) + assert session_id in session._temp_sessions + for remote_addr in EXTERNAL_ADDRESSES: + set_mock_ip(remote_addr) + resp = await client.get("/") + assert resp.status == HTTPStatus.OK + assert await resp.json() == {"authenticated": False} + + async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1)) + await hass.async_block_till_done(wait_background_tasks=True) + + assert session._temp_sessions == {} + for remote_addr in EXTERNAL_ADDRESSES: + set_mock_ip(remote_addr) + await perform_unauthenticated_request(hass, client) + + +async def _drop_connection_unauthorized_request( + _: HomeAssistant, client: TestClient +) -> None: + with pytest.raises(ServerDisconnectedError): + # unauthorized requests should raise ServerDisconnectedError + await client.get("/") + + +async def _static_page_unauthorized_request( + hass: HomeAssistant, client: TestClient +) -> None: + req = await client.get("/") + assert req.status == HTTPStatus.IM_A_TEAPOT + + def read_static_page() -> str: + with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file: + return file.read() + + assert await req.text() == await hass.async_add_executor_job(read_static_page) + + +@pytest.mark.parametrize( + "test_func", + [ + _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests, + _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token, + _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session, + ], + ids=[ + "no cookie", + "refresh token cookie", + "temp session cookie", + ], +) +@pytest.mark.parametrize( + ("strict_connection_mode", "request_func"), + [ + (StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request), + (StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request), + ], + ids=["drop connection", "static page"], +) +async def test_strict_connection_non_cloud_external_unauthenticated_requests( + hass: HomeAssistant, + app_strict_connection: web.Application, + aiohttp_client: ClientSessionGenerator, + hass_access_token: str, + test_func: Callable[ + [ + HomeAssistant, + web.Application, + ClientSessionGenerator, + str, + Callable[[HomeAssistant, TestClient], Awaitable[None]], + StrictConnectionMode, + ], + Awaitable[None], + ], + strict_connection_mode: StrictConnectionMode, + request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]], +) -> None: + """Test external unauthenticated requests with strict connection non cloud.""" + await test_func( + hass, + app_strict_connection, + aiohttp_client, + hass_access_token, + request_func, + strict_connection_mode, + ) diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py index 9e892e2ee43..b84da595ab1 100644 --- a/tests/components/http/test_init.py +++ b/tests/components/http/test_init.py @@ -7,6 +7,7 @@ from ipaddress import ip_network import logging from pathlib import Path from unittest.mock import Mock, patch +from urllib.parse import quote_plus import pytest @@ -14,7 +15,10 @@ from homeassistant.auth.providers.legacy_api_password import ( LegacyApiPasswordAuthProvider, ) from homeassistant.components import http +from homeassistant.components.http.const import StrictConnectionMode +from homeassistant.config import async_process_ha_core_config from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers.http import KEY_HASS from homeassistant.helpers.network import NoURLAvailableError from homeassistant.setup import async_setup_component @@ -521,3 +525,78 @@ async def test_logging( response = await client.get("/api/states/logging.entity") assert response.status == HTTPStatus.OK assert "GET /api/states/logging.entity" not in caplog.text + + +async def test_service_create_temporary_strict_connection_url_strict_connection_disabled( + hass: HomeAssistant, +) -> None: + """Test service create_temporary_strict_connection_url with strict_connection not enabled.""" + assert await async_setup_component(hass, http.DOMAIN, {"http": {}}) + with pytest.raises( + ServiceValidationError, + match="Strict connection is not enabled for non-cloud requests", + ): + await hass.services.async_call( + http.DOMAIN, + "create_temporary_strict_connection_url", + blocking=True, + return_response=True, + ) + + +@pytest.mark.parametrize( + ("mode"), + [ + StrictConnectionMode.DROP_CONNECTION, + StrictConnectionMode.STATIC_PAGE, + ], +) +async def test_service_create_temporary_strict_connection( + hass: HomeAssistant, mode: StrictConnectionMode +) -> None: + """Test service create_temporary_strict_connection_url.""" + assert await async_setup_component( + hass, http.DOMAIN, {"http": {"strict_connection": mode}} + ) + + # No external url set + assert hass.config.external_url is None + assert hass.config.internal_url is None + with pytest.raises(ServiceValidationError, match="No external URL available"): + await hass.services.async_call( + http.DOMAIN, + "create_temporary_strict_connection_url", + blocking=True, + return_response=True, + ) + + # Raise if only internal url is available + hass.config.api = Mock(use_ssl=False, port=8123, local_ip="192.168.123.123") + with pytest.raises(ServiceValidationError, match="No external URL available"): + await hass.services.async_call( + http.DOMAIN, + "create_temporary_strict_connection_url", + blocking=True, + return_response=True, + ) + + # Set external url too + external_url = "https://example.com" + await async_process_ha_core_config( + hass, + {"external_url": external_url}, + ) + assert hass.config.external_url == external_url + response = await hass.services.async_call( + http.DOMAIN, + "create_temporary_strict_connection_url", + blocking=True, + return_response=True, + ) + assert isinstance(response, dict) + direct_url_prefix = f"{external_url}/auth/strict_connection/temp_token?authSig=" + assert response.pop("direct_url").startswith(direct_url_prefix) + assert response.pop("url").startswith( + f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}" + ) + assert response == {} # No more keys in response diff --git a/tests/components/http/test_session.py b/tests/components/http/test_session.py new file mode 100644 index 00000000000..ae62365749a --- /dev/null +++ b/tests/components/http/test_session.py @@ -0,0 +1,107 @@ +"""Tests for HTTP session.""" + +from collections.abc import Callable +import logging +from typing import Any +from unittest.mock import patch + +from aiohttp import web +from aiohttp.test_utils import make_mocked_request +import pytest + +from homeassistant.auth.session import SESSION_ID +from homeassistant.components.http.session import ( + COOKIE_NAME, + HomeAssistantCookieStorage, +) +from homeassistant.core import HomeAssistant + + +def fake_request_with_strict_connection_cookie(cookie_value: str) -> web.Request: + """Return a fake request with a strict connection cookie.""" + request = make_mocked_request( + "GET", "/", headers={"Cookie": f"{COOKIE_NAME}={cookie_value}"} + ) + assert COOKIE_NAME in request.cookies + return request + + +@pytest.fixture +def cookie_storage(hass: HomeAssistant) -> HomeAssistantCookieStorage: + """Fixture for the cookie storage.""" + return HomeAssistantCookieStorage(hass) + + +def _encrypt_cookie_data(cookie_storage: HomeAssistantCookieStorage, data: Any) -> str: + """Encrypt cookie data.""" + cookie_data = cookie_storage._encoder(data).encode("utf-8") + return cookie_storage._fernet.encrypt(cookie_data).decode("utf-8") + + +@pytest.mark.parametrize( + "func", + [ + lambda _: "invalid", + lambda storage: _encrypt_cookie_data(storage, "bla"), + lambda storage: _encrypt_cookie_data(storage, None), + ], +) +async def test_load_session_modified_cookies( + cookie_storage: HomeAssistantCookieStorage, + caplog: pytest.LogCaptureFixture, + func: Callable[[HomeAssistantCookieStorage], str], +) -> None: + """Test that on modified cookies the session is empty and the request will be logged for ban.""" + request = fake_request_with_strict_connection_cookie(func(cookie_storage)) + with patch( + "homeassistant.components.http.session.process_wrong_login", + ) as mock_process_wrong_login: + session = await cookie_storage.load_session(request) + assert session.empty + assert ( + "homeassistant.components.http.session", + logging.WARNING, + "Cannot decrypt/parse cookie value", + ) in caplog.record_tuples + mock_process_wrong_login.assert_called() + + +async def test_load_session_validate_session( + hass: HomeAssistant, + cookie_storage: HomeAssistantCookieStorage, +) -> None: + """Test load session validates the session.""" + session = await cookie_storage.new_session() + session[SESSION_ID] = "bla" + request = fake_request_with_strict_connection_cookie( + _encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session)) + ) + + with patch.object( + hass.auth.session, "async_validate_strict_connection_session", return_value=True + ) as mock_validate: + session = await cookie_storage.load_session(request) + assert not session.empty + assert session[SESSION_ID] == "bla" + mock_validate.assert_called_with(session) + + # verify lru_cache is working + mock_validate.reset_mock() + await cookie_storage.load_session(request) + mock_validate.assert_not_called() + + session = await cookie_storage.new_session() + session[SESSION_ID] = "something" + request = fake_request_with_strict_connection_cookie( + _encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session)) + ) + + with patch.object( + hass.auth.session, + "async_validate_strict_connection_session", + return_value=False, + ): + session = await cookie_storage.load_session(request) + assert session.empty + assert SESSION_ID not in session + assert session._changed diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py index 9ce23d99152..280d15cd1ef 100644 --- a/tests/components/stream/conftest.py +++ b/tests/components/stream/conftest.py @@ -14,7 +14,6 @@ from __future__ import annotations import asyncio from collections.abc import Generator -from http import HTTPStatus import logging import threading from unittest.mock import Mock, patch @@ -87,6 +86,17 @@ class HLSSync: self._num_recvs = 0 self._num_finished = 0 + def on_resp(): + self._num_finished += 1 + self.check_requests_ready() + + class SyncResponse(web.Response): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + on_resp() + + self.response = SyncResponse + def reset_request_pool(self, num_requests: int, reset_finished=True): """Use to reset the request counter between segments.""" self._num_recvs = 0 @@ -120,12 +130,6 @@ class HLSSync: self.check_requests_ready() return self._original_not_found() - def response(self, body, headers=None, status=HTTPStatus.OK): - """Intercept the Response call so we know when the web handler is finished.""" - self._num_finished += 1 - self.check_requests_ready() - return self._original_response(body=body, headers=headers, status=status) - async def recv(self, output: StreamOutput, **kw): """Intercept the recv call so we know when the response is blocking on recv.""" self._num_recvs += 1 @@ -164,7 +168,7 @@ def hls_sync(): ), patch( "homeassistant.components.stream.hls.web.Response", - side_effect=sync.response, + new=sync.response, ), ): yield sync diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index e96f1c4f903..2bd76accfdd 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -701,7 +701,7 @@ async def test_get_services( assert msg["id"] == id_ assert msg["type"] == const.TYPE_RESULT assert msg["success"] - assert msg["result"] == hass.services.async_services() + assert msg["result"].keys() == hass.services.async_services().keys() async def test_get_config( diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 74b8a86ce7c..b5e71f4c9d8 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -7,6 +7,7 @@ from typing import Any from unittest.mock import AsyncMock, Mock, patch import pytest +from pytest_unordered import unordered import voluptuous as vol # To prevent circular import when running just this file @@ -16,6 +17,7 @@ import homeassistant.components # noqa: F401 from homeassistant.components.group import DOMAIN as DOMAIN_GROUP, Group from homeassistant.components.logger import DOMAIN as DOMAIN_LOGGER from homeassistant.components.shell_command import DOMAIN as DOMAIN_SHELL_COMMAND +from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH from homeassistant.const import ( ATTR_ENTITY_ID, ENTITY_MATCH_ALL, @@ -785,7 +787,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: """Test async_get_all_descriptions.""" group_config = {DOMAIN_GROUP: {}} assert await async_setup_component(hass, DOMAIN_GROUP, group_config) - assert await async_setup_component(hass, "system_health", {}) + assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {}) with patch( "homeassistant.helpers.service._load_services_files", @@ -795,17 +797,20 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: # Test we only load services.yaml for integrations with services.yaml # And system_health has no services - assert proxy_load_services_files.mock_calls[0][1][1] == [ - await async_get_integration(hass, "group") - ] + assert proxy_load_services_files.mock_calls[0][1][1] == unordered( + [ + await async_get_integration(hass, DOMAIN_GROUP), + await async_get_integration(hass, "http"), # system_health requires http + ] + ) - assert len(descriptions) == 1 - - assert "description" in descriptions["group"]["reload"] - assert "fields" in descriptions["group"]["reload"] + assert len(descriptions) == 2 + assert DOMAIN_GROUP in descriptions + assert "description" in descriptions[DOMAIN_GROUP]["reload"] + assert "fields" in descriptions[DOMAIN_GROUP]["reload"] # Does not have services - assert "system_health" not in descriptions + assert DOMAIN_SYSTEM_HEALTH not in descriptions logger_config = {DOMAIN_LOGGER: {}} @@ -833,8 +838,8 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: await async_setup_component(hass, DOMAIN_LOGGER, logger_config) descriptions = await service.async_get_all_descriptions(hass) - assert len(descriptions) == 2 - + assert len(descriptions) == 3 + assert DOMAIN_LOGGER in descriptions assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name" assert ( descriptions[DOMAIN_LOGGER]["set_default_level"]["description"] diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py index 79c64259f8b..76acb2ff678 100644 --- a/tests/scripts/test_check_config.py +++ b/tests/scripts/test_check_config.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest +from homeassistant.components.http.const import StrictConnectionMode from homeassistant.config import YAML_CONFIG_FILE from homeassistant.scripts import check_config @@ -134,6 +135,7 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None: "login_attempts_threshold": -1, "server_port": 8123, "ssl_profile": "modern", + "strict_connection": StrictConnectionMode.DISABLED, "use_x_frame_options": True, "server_host": ["0.0.0.0", "::"], }