diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py
index 969fcc3529e..2a9525181f6 100644
--- a/homeassistant/auth/__init__.py
+++ b/homeassistant/auth/__init__.py
@@ -28,6 +28,7 @@ from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRA
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
+from .session import SessionManager
EVENT_USER_ADDED = "user_added"
EVENT_USER_UPDATED = "user_updated"
@@ -85,7 +86,7 @@ async def auth_manager_from_config(
module_hash[module.id] = module
manager = AuthManager(hass, store, provider_hash, module_hash)
- manager.async_setup()
+ await manager.async_setup()
return manager
@@ -180,9 +181,9 @@ class AuthManager:
self._remove_expired_job = HassJob(
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
)
+ self.session = SessionManager(hass, self)
- @callback
- def async_setup(self) -> None:
+ async def async_setup(self) -> None:
"""Set up the auth manager."""
hass = self.hass
hass.async_add_shutdown_job(
@@ -191,6 +192,7 @@ class AuthManager:
)
)
self._async_track_next_refresh_token_expiration()
+ await self.session.async_setup()
@property
def auth_providers(self) -> list[AuthProvider]:
diff --git a/homeassistant/auth/session.py b/homeassistant/auth/session.py
new file mode 100644
index 00000000000..88297b50d90
--- /dev/null
+++ b/homeassistant/auth/session.py
@@ -0,0 +1,205 @@
+"""Session auth module."""
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+import secrets
+from typing import TYPE_CHECKING, Final, TypedDict
+
+from aiohttp.web import Request
+from aiohttp_session import Session, get_session, new_session
+from cryptography.fernet import Fernet
+
+from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
+from homeassistant.helpers.event import async_call_later
+from homeassistant.helpers.storage import Store
+from homeassistant.util import dt as dt_util
+
+from .models import RefreshToken
+
+if TYPE_CHECKING:
+ from . import AuthManager
+
+
+TEMP_TIMEOUT = timedelta(minutes=5)
+TEMP_TIMEOUT_SECONDS = TEMP_TIMEOUT.total_seconds()
+
+SESSION_ID = "id"
+STORAGE_VERSION = 1
+STORAGE_KEY = "auth.session"
+
+
+class StrictConnectionTempSessionData:
+ """Data for accessing unauthorized resources for a short period of time."""
+
+ __slots__ = ("cancel_remove", "absolute_expiry")
+
+ def __init__(self, cancel_remove: CALLBACK_TYPE) -> None:
+ """Initialize the temp session data."""
+ self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove
+ self.absolute_expiry: Final[datetime] = dt_util.utcnow() + TEMP_TIMEOUT
+
+
+class StoreData(TypedDict):
+ """Data to store."""
+
+ unauthorized_sessions: dict[str, str]
+ key: str
+
+
+class SessionManager:
+ """Session manager."""
+
+ def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None:
+ """Initialize the strict connection manager."""
+ self._auth = auth
+ self._hass = hass
+ self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {}
+ self._strict_connection_sessions: dict[str, str] = {}
+ self._store = Store[StoreData](
+ hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
+ )
+ self._key: str | None = None
+ self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {}
+
+ @property
+ def key(self) -> str:
+ """Return the encryption key."""
+ if self._key is None:
+ self._key = Fernet.generate_key().decode()
+ self._async_schedule_save()
+ return self._key
+
+ async def async_validate_request_for_strict_connection_session(
+ self,
+ request: Request,
+ ) -> bool:
+ """Check if a request has a valid strict connection session."""
+ session = await get_session(request)
+ if session.new or session.empty:
+ return False
+ result = self.async_validate_strict_connection_session(session)
+ if result is False:
+ session.invalidate()
+ return result
+
+ @callback
+ def async_validate_strict_connection_session(
+ self,
+ session: Session,
+ ) -> bool:
+ """Validate a strict connection session."""
+ if not (session_id := session.get(SESSION_ID)):
+ return False
+
+ if token_id := self._strict_connection_sessions.get(session_id):
+ if self._auth.async_get_refresh_token(token_id):
+ return True
+ # refresh token is invalid, delete entry
+ self._strict_connection_sessions.pop(session_id)
+ self._async_schedule_save()
+
+ if data := self._temp_sessions.get(session_id):
+ if dt_util.utcnow() <= data.absolute_expiry:
+ return True
+ # session expired, delete entry
+ self._temp_sessions.pop(session_id).cancel_remove()
+
+ return False
+
+ @callback
+ def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None:
+ """Register a callback to revoke all sessions for a refresh token."""
+ if refresh_token_id in self._refresh_token_revoke_callbacks:
+ return
+
+ @callback
+ def async_invalidate_auth_sessions() -> None:
+ """Invalidate all sessions for a refresh token."""
+ self._strict_connection_sessions = {
+ session_id: token_id
+ for session_id, token_id in self._strict_connection_sessions.items()
+ if token_id != refresh_token_id
+ }
+ self._async_schedule_save()
+
+ self._refresh_token_revoke_callbacks[refresh_token_id] = (
+ self._auth.async_register_revoke_token_callback(
+ refresh_token_id, async_invalidate_auth_sessions
+ )
+ )
+
+ async def async_create_session(
+ self,
+ request: Request,
+ refresh_token: RefreshToken,
+ ) -> None:
+ """Create new session for given refresh token.
+
+ Caller needs to make sure that the refresh token is valid.
+ By creating a session, we are implicitly revoking all other
+ sessions for the given refresh token as there is one refresh
+ token per device/user case.
+ """
+ self._strict_connection_sessions = {
+ session_id: token_id
+ for session_id, token_id in self._strict_connection_sessions.items()
+ if token_id != refresh_token.id
+ }
+
+ self._async_register_revoke_token_callback(refresh_token.id)
+ session_id = await self._async_create_new_session(request)
+ self._strict_connection_sessions[session_id] = refresh_token.id
+ self._async_schedule_save()
+
+ async def async_create_temp_unauthorized_session(self, request: Request) -> None:
+ """Create a temporary unauthorized session."""
+ session_id = await self._async_create_new_session(
+ request, max_age=int(TEMP_TIMEOUT_SECONDS)
+ )
+
+ @callback
+ def remove(_: datetime) -> None:
+ self._temp_sessions.pop(session_id, None)
+
+ self._temp_sessions[session_id] = StrictConnectionTempSessionData(
+ async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove)
+ )
+
+ async def _async_create_new_session(
+ self,
+ request: Request,
+ *,
+ max_age: int | None = None,
+ ) -> str:
+ session_id = secrets.token_hex(64)
+
+ session = await new_session(request)
+ session[SESSION_ID] = session_id
+ if max_age is not None:
+ session.max_age = max_age
+ return session_id
+
+ @callback
+ def _async_schedule_save(self, delay: float = 1) -> None:
+ """Save sessions."""
+ self._store.async_delay_save(self._data_to_save, delay)
+
+ @callback
+ def _data_to_save(self) -> StoreData:
+ """Return the data to store."""
+ return StoreData(
+ unauthorized_sessions=self._strict_connection_sessions,
+ key=self.key,
+ )
+
+ async def async_setup(self) -> None:
+ """Set up session manager."""
+ data = await self._store.async_load()
+ if data is None:
+ return
+
+ self._key = data["key"]
+ self._strict_connection_sessions = data["unauthorized_sessions"]
+ for token_id in self._strict_connection_sessions.values():
+ self._async_register_revoke_token_callback(token_id)
diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py
index ff54971eb64..3d825cd99b5 100644
--- a/homeassistant/components/auth/__init__.py
+++ b/homeassistant/components/auth/__init__.py
@@ -162,6 +162,7 @@ from homeassistant.util import dt as dt_util
from . import indieauth, login_flow, mfa_setup_flow
DOMAIN = "auth"
+STRICT_CONNECTION_URL = "/auth/strict_connection/temp_token"
StoreResultType = Callable[[str, Credentials], str]
RetrieveResultType = Callable[[str, str], Credentials | None]
@@ -187,6 +188,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.http.register_view(RevokeTokenView())
hass.http.register_view(LinkUserView(retrieve_result))
hass.http.register_view(OAuth2AuthorizeCallbackView())
+ hass.http.register_view(StrictConnectionTempTokenView())
websocket_api.async_register_command(hass, websocket_current_user)
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
@@ -260,10 +262,10 @@ class TokenView(HomeAssistantView):
return await RevokeTokenView.post(self, request) # type: ignore[arg-type]
if grant_type == "authorization_code":
- return await self._async_handle_auth_code(hass, data, request.remote)
+ return await self._async_handle_auth_code(hass, data, request)
if grant_type == "refresh_token":
- return await self._async_handle_refresh_token(hass, data, request.remote)
+ return await self._async_handle_refresh_token(hass, data, request)
return self.json(
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
@@ -273,7 +275,7 @@ class TokenView(HomeAssistantView):
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
- remote_addr: str | None,
+ request: web.Request,
) -> web.Response:
"""Handle authorization code request."""
client_id = data.get("client_id")
@@ -313,7 +315,7 @@ class TokenView(HomeAssistantView):
)
try:
access_token = hass.auth.async_create_access_token(
- refresh_token, remote_addr
+ refresh_token, request.remote
)
except InvalidAuthError as exc:
return self.json(
@@ -321,6 +323,7 @@ class TokenView(HomeAssistantView):
status_code=HTTPStatus.FORBIDDEN,
)
+ await hass.auth.session.async_create_session(request, refresh_token)
return self.json(
{
"access_token": access_token,
@@ -341,9 +344,9 @@ class TokenView(HomeAssistantView):
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
- remote_addr: str | None,
+ request: web.Request,
) -> web.Response:
- """Handle authorization code request."""
+ """Handle refresh token request."""
client_id = data.get("client_id")
if client_id is not None and not indieauth.verify_client_id(client_id):
return self.json(
@@ -381,7 +384,7 @@ class TokenView(HomeAssistantView):
try:
access_token = hass.auth.async_create_access_token(
- refresh_token, remote_addr
+ refresh_token, request.remote
)
except InvalidAuthError as exc:
return self.json(
@@ -389,6 +392,7 @@ class TokenView(HomeAssistantView):
status_code=HTTPStatus.FORBIDDEN,
)
+ await hass.auth.session.async_create_session(request, refresh_token)
return self.json(
{
"access_token": access_token,
@@ -437,6 +441,20 @@ class LinkUserView(HomeAssistantView):
return self.json_message("User linked")
+class StrictConnectionTempTokenView(HomeAssistantView):
+ """View to get temporary strict connection token."""
+
+ url = STRICT_CONNECTION_URL
+ name = "api:auth:strict_connection:temp_token"
+ requires_auth = False
+
+ async def get(self, request: web.Request) -> web.Response:
+ """Get a temporary token and redirect to main page."""
+ hass = request.app[KEY_HASS]
+ await hass.auth.session.async_create_temp_unauthorized_session(request)
+ raise web.HTTPSeeOther(location="/")
+
+
@callback
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
"""Create an in memory store."""
diff --git a/homeassistant/components/hassio/ingress.py b/homeassistant/components/hassio/ingress.py
index 6d6faa6fe75..ed6e47145dd 100644
--- a/homeassistant/components/hassio/ingress.py
+++ b/homeassistant/components/hassio/ingress.py
@@ -197,7 +197,6 @@ class HassIOIngress(HomeAssistantView):
content_type or simple_response.content_type
):
simple_response.enable_compression()
- await simple_response.prepare(request)
return simple_response
# Stream response
diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py
index e89031cb265..3e5f7333cbc 100644
--- a/homeassistant/components/http/__init__.py
+++ b/homeassistant/components/http/__init__.py
@@ -10,7 +10,8 @@ import os
import socket
import ssl
from tempfile import NamedTemporaryFile
-from typing import Any, Final, TypedDict, cast
+from typing import Any, Final, Required, TypedDict, cast
+from urllib.parse import quote_plus, urljoin
from aiohttp import web
from aiohttp.abc import AbstractStreamWriter
@@ -30,8 +31,20 @@ from yarl import URL
from homeassistant.components.network import async_get_source_ip
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
-from homeassistant.core import Event, HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
+from homeassistant.core import (
+ Event,
+ HomeAssistant,
+ ServiceCall,
+ ServiceResponse,
+ SupportsResponse,
+ callback,
+)
+from homeassistant.exceptions import (
+ HomeAssistantError,
+ ServiceValidationError,
+ Unauthorized,
+ UnknownUser,
+)
from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.http import (
@@ -53,9 +66,13 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.json import json_loads
-from .auth import async_setup_auth
+from .auth import async_setup_auth, async_sign_path
from .ban import setup_bans
-from .const import KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
+from .const import ( # noqa: F401
+ KEY_HASS_REFRESH_TOKEN_ID,
+ KEY_HASS_USER,
+ StrictConnectionMode,
+)
from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded
@@ -80,6 +97,7 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
CONF_SSL_PROFILE: Final = "ssl_profile"
+CONF_STRICT_CONNECTION: Final = "strict_connection"
SSL_MODERN: Final = "modern"
SSL_INTERMEDIATE: Final = "intermediate"
@@ -129,6 +147,9 @@ HTTP_SCHEMA: Final = vol.All(
[SSL_INTERMEDIATE, SSL_MODERN]
),
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
+ vol.Optional(
+ CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
+ ): vol.In([e.value for e in StrictConnectionMode]),
}
),
)
@@ -152,6 +173,7 @@ class ConfData(TypedDict, total=False):
login_attempts_threshold: int
ip_ban_enabled: bool
ssl_profile: str
+ strict_connection: Required[StrictConnectionMode]
@bind_hass
@@ -218,6 +240,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled,
use_x_frame_options=use_x_frame_options,
+ strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION],
)
async def stop_server(event: Event) -> None:
@@ -247,6 +270,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
local_ip, host, server_port, ssl_certificate is not None
)
+ _setup_services(hass, conf)
return True
@@ -331,6 +355,7 @@ class HomeAssistantHTTP:
login_threshold: int,
is_ban_enabled: bool,
use_x_frame_options: bool,
+ strict_connection_non_cloud: StrictConnectionMode,
) -> None:
"""Initialize the server."""
self.app[KEY_HASS] = self.hass
@@ -347,7 +372,7 @@ class HomeAssistantHTTP:
if is_ban_enabled:
setup_bans(self.hass, self.app, login_threshold)
- await async_setup_auth(self.hass, self.app)
+ await async_setup_auth(self.hass, self.app, strict_connection_non_cloud)
setup_headers(self.app, use_x_frame_options)
setup_cors(self.app, cors_origins)
@@ -577,3 +602,59 @@ async def start_http_server_and_save_config(
]
store.async_delay_save(lambda: conf, SAVE_DELAY)
+
+
+@callback
+def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
+ """Set up services for HTTP component."""
+
+ async def create_temporary_strict_connection_url(
+ call: ServiceCall,
+ ) -> ServiceResponse:
+ """Create a strict connection url and return it."""
+ # Copied form homeassistant/helpers/service.py#_async_admin_handler
+ # as the helper supports no responses yet
+ if call.context.user_id:
+ user = await hass.auth.async_get_user(call.context.user_id)
+ if user is None:
+ raise UnknownUser(context=call.context)
+ if not user.is_admin:
+ raise Unauthorized(context=call.context)
+
+ if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED:
+ raise ServiceValidationError(
+ translation_domain=DOMAIN,
+ translation_key="strict_connection_not_enabled_non_cloud",
+ )
+
+ try:
+ url = get_url(hass, prefer_external=True, allow_internal=False)
+ except NoURLAvailableError as ex:
+ raise ServiceValidationError(
+ translation_domain=DOMAIN,
+ translation_key="no_external_url_available",
+ ) from ex
+
+ # to avoid circular import
+ # pylint: disable-next=import-outside-toplevel
+ from homeassistant.components.auth import STRICT_CONNECTION_URL
+
+ path = async_sign_path(
+ hass,
+ STRICT_CONNECTION_URL,
+ datetime.timedelta(hours=1),
+ use_content_user=True,
+ )
+ url = urljoin(url, path)
+
+ return {
+ "url": f"https://login.home-assistant.io?u={quote_plus(url)}",
+ "direct_url": url,
+ }
+
+ hass.services.async_register(
+ DOMAIN,
+ "create_temporary_strict_connection_url",
+ create_temporary_strict_connection_url,
+ supports_response=SupportsResponse.ONLY,
+ )
diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py
index 2073c998384..1eb74289089 100644
--- a/homeassistant/components/http/auth.py
+++ b/homeassistant/components/http/auth.py
@@ -4,14 +4,18 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import timedelta
+from http import HTTPStatus
from ipaddress import ip_address
import logging
+import os
import secrets
import time
from typing import Any, Final
from aiohttp import hdrs
-from aiohttp.web import Application, Request, StreamResponse, middleware
+from aiohttp.web import Application, Request, Response, StreamResponse, middleware
+from aiohttp.web_exceptions import HTTPBadRequest
+from aiohttp_session import session_middleware
import jwt
from jwt import api_jws
from yarl import URL
@@ -27,7 +31,13 @@ from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local
-from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
+from .const import (
+ KEY_AUTHENTICATED,
+ KEY_HASS_REFRESH_TOKEN_ID,
+ KEY_HASS_USER,
+ StrictConnectionMode,
+)
+from .session import HomeAssistantCookieStorage
_LOGGER = logging.getLogger(__name__)
@@ -39,6 +49,10 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"]
STORAGE_VERSION = 1
STORAGE_KEY = "http.auth"
CONTENT_USER_NAME = "Home Assistant Content"
+STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
+STRICT_CONNECTION_STATIC_PAGE = os.path.join(
+ os.path.dirname(__file__), "strict_connection_static_page.html"
+)
@callback
@@ -48,13 +62,16 @@ def async_sign_path(
expiration: timedelta,
*,
refresh_token_id: str | None = None,
+ use_content_user: bool = False,
) -> str:
"""Sign a path for temporary access without auth header."""
if (secret := hass.data.get(DATA_SIGN_SECRET)) is None:
secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex()
if refresh_token_id is None:
- if connection := websocket_api.current_connection.get():
+ if use_content_user:
+ refresh_token_id = hass.data[STORAGE_KEY]
+ elif connection := websocket_api.current_connection.get():
refresh_token_id = connection.refresh_token_id
elif (
request := current_request.get()
@@ -114,7 +131,11 @@ def async_user_not_allowed_do_auth(
return "User cannot authenticate remotely"
-async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
+async def async_setup_auth(
+ hass: HomeAssistant,
+ app: Application,
+ strict_connection_mode_non_cloud: StrictConnectionMode,
+) -> None:
"""Create auth middleware for the app."""
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
if (data := await store.async_load()) is None:
@@ -135,6 +156,16 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
await store.async_save(data)
hass.data[STORAGE_KEY] = refresh_token.id
+ strict_connection_static_file_content = None
+ if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
+
+ def read_static_page() -> str:
+ with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
+ return file.read()
+
+ strict_connection_static_file_content = await hass.async_add_executor_job(
+ read_static_page
+ )
@callback
def async_validate_auth_header(request: Request) -> bool:
@@ -224,6 +255,22 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
authenticated = True
auth_type = "signed request"
+ if (
+ not authenticated
+ and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED
+ and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH)
+ and not await hass.auth.session.async_validate_request_for_strict_connection_session(
+ request
+ )
+ and (
+ resp := _async_perform_action_on_non_local(
+ request, strict_connection_static_file_content
+ )
+ )
+ is not None
+ ):
+ return resp
+
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Authenticated %s for %s using %s",
@@ -235,4 +282,43 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
request[KEY_AUTHENTICATED] = authenticated
return await handler(request)
+ app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass)))
app.middlewares.append(auth_middleware)
+
+
+@callback
+def _async_perform_action_on_non_local(
+ request: Request,
+ strict_connection_static_file_content: str | None,
+) -> StreamResponse | None:
+ """Perform strict connection mode action if the request is not local.
+
+ The function does the following:
+ - Try to get the IP address of the request. If it fails, assume it's not local
+ - If the request is local, return None (allow the request to continue)
+ - If strict_connection_static_file_content is set, return a response with the content
+ - Otherwise close the connection and raise an exception
+ """
+ try:
+ ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
+ except ValueError:
+ _LOGGER.debug("Invalid IP address: %s", request.remote)
+ ip_address_ = None
+
+ if ip_address_ and is_local(ip_address_):
+ return None
+
+ _LOGGER.debug("Perform strict connection action for %s", ip_address_)
+ if strict_connection_static_file_content:
+ return Response(
+ text=strict_connection_static_file_content,
+ content_type="text/html",
+ status=HTTPStatus.IM_A_TEAPOT,
+ )
+
+ if transport := request.transport:
+ # it should never happen that we don't have a transport
+ transport.close()
+
+ # We need to raise an exception to stop processing the request
+ raise HTTPBadRequest
diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py
index 1254744f258..d02416c531b 100644
--- a/homeassistant/components/http/const.py
+++ b/homeassistant/components/http/const.py
@@ -1,8 +1,17 @@
"""HTTP specific constants."""
+from enum import StrEnum
from typing import Final
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
+
+
+class StrictConnectionMode(StrEnum):
+ """Enum for strict connection mode."""
+
+ DISABLED = "disabled"
+ STATIC_PAGE = "static_page"
+ DROP_CONNECTION = "drop_connection"
diff --git a/homeassistant/components/http/icons.json b/homeassistant/components/http/icons.json
new file mode 100644
index 00000000000..8e8b6285db7
--- /dev/null
+++ b/homeassistant/components/http/icons.json
@@ -0,0 +1,5 @@
+{
+ "services": {
+ "create_temporary_strict_connection_url": "mdi:login-variant"
+ }
+}
diff --git a/homeassistant/components/http/services.yaml b/homeassistant/components/http/services.yaml
new file mode 100644
index 00000000000..16b0debb144
--- /dev/null
+++ b/homeassistant/components/http/services.yaml
@@ -0,0 +1 @@
+create_temporary_strict_connection_url: ~
diff --git a/homeassistant/components/http/session.py b/homeassistant/components/http/session.py
new file mode 100644
index 00000000000..81668ec2ccc
--- /dev/null
+++ b/homeassistant/components/http/session.py
@@ -0,0 +1,160 @@
+"""Session http module."""
+
+from functools import lru_cache
+import logging
+
+from aiohttp.web import Request, StreamResponse
+from aiohttp_session import Session, SessionData
+from aiohttp_session.cookie_storage import EncryptedCookieStorage
+from cryptography.fernet import InvalidToken
+
+from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION
+from homeassistant.core import HomeAssistant
+from homeassistant.helpers.json import json_dumps
+from homeassistant.helpers.network import is_cloud_connection
+from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
+
+from .ban import process_wrong_login
+
+_LOGGER = logging.getLogger(__name__)
+
+COOKIE_NAME = "SC"
+PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}"
+SESSION_CACHE_SIZE = 16
+
+
+def _get_cookie_name(is_secure: bool) -> str:
+ """Return the cookie name."""
+ return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME
+
+
+class HomeAssistantCookieStorage(EncryptedCookieStorage):
+ """Home Assistant cookie storage.
+
+ Own class is required:
+ - to set the secure flag based on the connection type
+ - to use a LRU cache for session decryption
+ """
+
+ def __init__(self, hass: HomeAssistant) -> None:
+ """Initialize the cookie storage."""
+ super().__init__(
+ hass.auth.session.key,
+ cookie_name=PREFIXED_COOKIE_NAME,
+ max_age=int(REFRESH_TOKEN_EXPIRATION),
+ httponly=True,
+ samesite="Lax",
+ secure=True,
+ encoder=json_dumps,
+ decoder=json_loads,
+ )
+ self._hass = hass
+
+ def _secure_connection(self, request: Request) -> bool:
+ """Return if the connection is secure (https)."""
+ return is_cloud_connection(self._hass) or request.secure
+
+ def load_cookie(self, request: Request) -> str | None:
+ """Load cookie."""
+ is_secure = self._secure_connection(request)
+ cookie_name = _get_cookie_name(is_secure)
+ return request.cookies.get(cookie_name)
+
+ @lru_cache(maxsize=SESSION_CACHE_SIZE)
+ def _decrypt_cookie(self, cookie: str) -> Session | None:
+ """Decrypt and validate cookie."""
+ try:
+ data = SessionData( # type: ignore[misc]
+ self._decoder(
+ self._fernet.decrypt(
+ cookie.encode("utf-8"), ttl=self.max_age
+ ).decode("utf-8")
+ )
+ )
+ except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS):
+ _LOGGER.warning("Cannot decrypt/parse cookie value")
+ return None
+
+ session = Session(None, data=data, new=data is None, max_age=self.max_age)
+
+ # Validate session if not empty
+ if (
+ not session.empty
+ and not self._hass.auth.session.async_validate_strict_connection_session(
+ session
+ )
+ ):
+ # Invalidate session as it is not valid
+ session.invalidate()
+
+ return session
+
+ async def new_session(self) -> Session:
+ """Create a new session and mark it as changed."""
+ session = Session(None, data=None, new=True, max_age=self.max_age)
+ session.changed()
+ return session
+
+ async def load_session(self, request: Request) -> Session:
+ """Load session."""
+ # Split parent function to use lru_cache
+ if (cookie := self.load_cookie(request)) is None:
+ return await self.new_session()
+
+ if (session := self._decrypt_cookie(cookie)) is None:
+ # Decrypting/parsing failed, log wrong login and create a new session
+ await process_wrong_login(request)
+ session = await self.new_session()
+
+ return session
+
+ async def save_session(
+ self, request: Request, response: StreamResponse, session: Session
+ ) -> None:
+ """Save session."""
+
+ is_secure = self._secure_connection(request)
+ cookie_name = _get_cookie_name(is_secure)
+
+ if session.empty:
+ response.del_cookie(cookie_name)
+ else:
+ params = self.cookie_params.copy()
+ params["secure"] = is_secure
+ params["max_age"] = session.max_age
+
+ cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8")
+ response.set_cookie(
+ cookie_name,
+ self._fernet.encrypt(cookie_data).decode("utf-8"),
+ **params,
+ )
+ # Add Cache-Control header to not cache the cookie as it
+ # is used for session management
+ self._add_cache_control_header(response)
+
+ @staticmethod
+ def _add_cache_control_header(response: StreamResponse) -> None:
+ """Add/set cache control header to no-cache="Set-Cookie"."""
+ # Structure of the Cache-Control header defined in
+ # https://datatracker.ietf.org/doc/html/rfc2068#section-14.9
+ if header := response.headers.get("Cache-Control"):
+ directives = []
+ for directive in header.split(","):
+ directive = directive.strip()
+ directive_lowered = directive.lower()
+ if directive_lowered.startswith("no-cache"):
+ if "set-cookie" in directive_lowered or directive.find("=") == -1:
+ # Set-Cookie is already in the no-cache directive or
+ # the whole request should not be cached -> Nothing to do
+ return
+
+ # Add Set-Cookie to the no-cache
+ # [:-1] to remove the " at the end of the directive
+ directive = f"{directive[:-1]}, Set-Cookie"
+
+ directives.append(directive)
+ header = ", ".join(directives)
+ else:
+ header = 'no-cache="Set-Cookie"'
+ response.headers["Cache-Control"] = header
diff --git a/homeassistant/components/http/strict_connection_static_page.html b/homeassistant/components/http/strict_connection_static_page.html
new file mode 100644
index 00000000000..24049d9a0eb
--- /dev/null
+++ b/homeassistant/components/http/strict_connection_static_page.html
@@ -0,0 +1,46 @@
+
+
+
+
+
+ I'm a Teapot
+
+
+
+
+
Error 418: I'm a Teapot
+
+ Oops! Looks like the server is taking a coffee break.
+ Don't worry, it'll be back to brewing your requests in no time!
+
+
☕
+
+
+
diff --git a/homeassistant/components/http/strings.json b/homeassistant/components/http/strings.json
new file mode 100644
index 00000000000..7cd64f5f297
--- /dev/null
+++ b/homeassistant/components/http/strings.json
@@ -0,0 +1,16 @@
+{
+ "exceptions": {
+ "strict_connection_not_enabled_non_cloud": {
+ "message": "Strict connection is not enabled for non-cloud requests"
+ },
+ "no_external_url_available": {
+ "message": "No external URL available"
+ }
+ },
+ "services": {
+ "create_temporary_strict_connection_url": {
+ "name": "Create a temporary strict connection URL",
+ "description": "Create a temporary strict connection URL, which can be used to login on another device."
+ }
+ }
+}
diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt
index 090271e028e..b253d600a2d 100644
--- a/homeassistant/package_constraints.txt
+++ b/homeassistant/package_constraints.txt
@@ -7,6 +7,7 @@ aiohttp-fast-url-dispatcher==0.3.0
aiohttp-zlib-ng==0.3.1
aiohttp==3.9.4
aiohttp_cors==0.7.0
+aiohttp_session==2.12.0
astral==2.2
async-interrupt==1.1.1
async-upnp-client==0.38.3
diff --git a/pyproject.toml b/pyproject.toml
index 5ea335115ca..79a66cc7d82 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"aiodns==3.2.0",
"aiohttp==3.9.4",
"aiohttp_cors==0.7.0",
+ "aiohttp_session==2.12.0",
"aiohttp-fast-url-dispatcher==0.3.0",
"aiohttp-zlib-ng==0.3.1",
"astral==2.2",
diff --git a/requirements.txt b/requirements.txt
index 3cd1e8edfa5..f2f26f9bb54 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@
aiodns==3.2.0
aiohttp==3.9.4
aiohttp_cors==0.7.0
+aiohttp_session==2.12.0
aiohttp-fast-url-dispatcher==0.3.0
aiohttp-zlib-ng==0.3.1
astral==2.2
diff --git a/tests/components/api/test_init.py b/tests/components/api/test_init.py
index 0ac2e5973fe..5443d48452f 100644
--- a/tests/components/api/test_init.py
+++ b/tests/components/api/test_init.py
@@ -306,7 +306,7 @@ async def test_api_get_services(
for serv_domain in data:
local = local_services.pop(serv_domain["domain"])
- assert serv_domain["services"] == local
+ assert serv_domain["services"].keys() == local.keys()
async def test_api_call_service_no_data(
diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py
index de6f323bc8a..f0f87e58173 100644
--- a/tests/components/http/test_auth.py
+++ b/tests/components/http/test_auth.py
@@ -1,22 +1,28 @@
"""The tests for the Home Assistant HTTP component."""
+from collections.abc import Awaitable, Callable
from datetime import timedelta
from http import HTTPStatus
from ipaddress import ip_network
+import logging
from unittest.mock import Mock, patch
-from aiohttp import BasicAuth, web
+from aiohttp import BasicAuth, ServerDisconnectedError, web
+from aiohttp.test_utils import TestClient
from aiohttp.web_exceptions import HTTPUnauthorized
+from aiohttp_session import get_session
import jwt
import pytest
import yarl
+from yarl import URL
from homeassistant.auth.const import GROUP_ID_READ_ONLY
-from homeassistant.auth.models import User
+from homeassistant.auth.models import RefreshToken, User
from homeassistant.auth.providers import trusted_networks
from homeassistant.auth.providers.legacy_api_password import (
LegacyApiPasswordAuthProvider,
)
+from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import (
@@ -24,11 +30,12 @@ from homeassistant.components.http.auth import (
DATA_SIGN_SECRET,
SIGN_QUERY_PARAM,
STORAGE_KEY,
+ STRICT_CONNECTION_STATIC_PAGE,
async_setup_auth,
async_sign_path,
async_user_not_allowed_do_auth,
)
-from homeassistant.components.http.const import KEY_AUTHENTICATED
+from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
from homeassistant.components.http.forwarded import async_setup_forwarded
from homeassistant.components.http.request_context import (
current_request,
@@ -36,13 +43,15 @@ from homeassistant.components.http.request_context import (
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.setup import async_setup_component
+from homeassistant.util.dt import utcnow
from . import HTTP_HEADER_HA_AUTH
-from tests.common import MockUser
+from tests.common import MockUser, async_fire_time_changed
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator, WebSocketGenerator
+_LOGGER = logging.getLogger(__name__)
API_PASSWORD = "test-password"
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
@@ -54,7 +63,13 @@ TRUSTED_NETWORKS = [
]
TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"]
EXTERNAL_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1"]
-UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, "127.0.0.1", "::1"]
+LOCALHOST_ADDRESSES = ["127.0.0.1", "::1"]
+UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, *LOCALHOST_ADDRESSES]
+PRIVATE_ADDRESSES = [
+ "192.168.10.10",
+ "172.16.4.20",
+ "10.100.50.5",
+]
async def mock_handler(request):
@@ -122,7 +137,7 @@ async def test_cant_access_with_password_in_header(
hass: HomeAssistant,
) -> None:
"""Test access with password in header."""
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@@ -139,7 +154,7 @@ async def test_cant_access_with_password_in_query(
hass: HomeAssistant,
) -> None:
"""Test access with password in URL."""
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
resp = await client.get("/", params={"api_password": API_PASSWORD})
@@ -159,7 +174,7 @@ async def test_basic_auth_does_not_work(
legacy_auth: LegacyApiPasswordAuthProvider,
) -> None:
"""Test access with basic authentication."""
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD))
@@ -183,7 +198,7 @@ async def test_cannot_access_with_trusted_ip(
hass_owner_user: MockUser,
) -> None:
"""Test access with an untrusted ip address."""
- await async_setup_auth(hass, app2)
+ await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2)
@@ -211,7 +226,7 @@ async def test_auth_active_access_with_access_token_in_header(
) -> None:
"""Test access with access token in header."""
token = hass_access_token
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -247,7 +262,7 @@ async def test_auth_active_access_with_trusted_ip(
hass_owner_user: MockUser,
) -> None:
"""Test access with an untrusted ip address."""
- await async_setup_auth(hass, app2)
+ await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2)
@@ -274,7 +289,7 @@ async def test_auth_legacy_support_api_password_cannot_access(
hass: HomeAssistant,
) -> None:
"""Test access using api_password if auth.support_legacy."""
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@@ -296,7 +311,7 @@ async def test_auth_access_signed_path_with_refresh_token(
"""Test access with signed url."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -341,7 +356,7 @@ async def test_auth_access_signed_path_with_query_param(
"""Test access with signed url and query params."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -371,7 +386,7 @@ async def test_auth_access_signed_path_with_query_param_order(
"""Test access with signed url and query params different order."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -412,7 +427,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
"""Test access with signed url and changing a safe param."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -451,7 +466,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
"""Test access with signed url and query params that have been tampered with."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -520,7 +535,7 @@ async def test_auth_access_signed_path_with_http(
)
app.router.add_get("/hello", mock_handler)
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -544,7 +559,7 @@ async def test_auth_access_signed_path_with_content_user(
hass: HomeAssistant, app, aiohttp_client: ClientSessionGenerator
) -> None:
"""Test access signed url uses content user."""
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
signed_path = async_sign_path(hass, "/", timedelta(seconds=5))
signature = yarl.URL(signed_path).query["authSig"]
claims = jwt.decode(
@@ -564,7 +579,7 @@ async def test_local_only_user_rejected(
) -> None:
"""Test access with access token in header."""
token = hass_access_token
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@@ -630,7 +645,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
"""Test that we reuse the user."""
cur_users = len(await hass.auth.async_get_users())
app = web.Application()
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
users = await hass.auth.async_get_users()
assert len(users) == cur_users + 1
@@ -642,7 +657,287 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
assert len(user.refresh_tokens) == 1
assert user.system_generated
- await async_setup_auth(hass, app)
+ await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
# test it did not create a user
assert len(await hass.auth.async_get_users()) == cur_users + 1
+
+
+@pytest.fixture
+def app_strict_connection(hass):
+ """Fixture to set up a web.Application."""
+
+ async def handler(request):
+ """Return if request was authenticated."""
+ return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
+
+ app = web.Application()
+ app[KEY_HASS] = hass
+ app.router.add_get("/", handler)
+ async_setup_forwarded(app, True, [])
+ return app
+
+
+@pytest.mark.parametrize(
+ "strict_connection_mode", [e.value for e in StrictConnectionMode]
+)
+async def test_strict_connection_non_cloud_authenticated_requests(
+ hass: HomeAssistant,
+ app_strict_connection: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ strict_connection_mode: StrictConnectionMode,
+) -> None:
+ """Test authenticated requests with strict connection."""
+ token = hass_access_token
+ await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
+ set_mock_ip = mock_real_ip(app_strict_connection)
+ client = await aiohttp_client(app_strict_connection)
+ refresh_token = hass.auth.async_validate_access_token(hass_access_token)
+ assert refresh_token
+ assert hass.auth.session._strict_connection_sessions == {}
+
+ signed_path = async_sign_path(
+ hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
+ )
+
+ for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES, *EXTERNAL_ADDRESSES):
+ set_mock_ip(remote_addr)
+
+ # authorized requests should work normally
+ req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
+ assert req.status == HTTPStatus.OK
+ assert await req.json() == {"authenticated": True}
+ req = await client.get(signed_path)
+ assert req.status == HTTPStatus.OK
+ assert await req.json() == {"authenticated": True}
+
+
+@pytest.mark.parametrize(
+ "strict_connection_mode", [e.value for e in StrictConnectionMode]
+)
+async def test_strict_connection_non_cloud_local_unauthenticated_requests(
+ hass: HomeAssistant,
+ app_strict_connection: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ strict_connection_mode: StrictConnectionMode,
+) -> None:
+ """Test local unauthenticated requests with strict connection."""
+ await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
+ set_mock_ip = mock_real_ip(app_strict_connection)
+ client = await aiohttp_client(app_strict_connection)
+ assert hass.auth.session._strict_connection_sessions == {}
+
+ for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES):
+ set_mock_ip(remote_addr)
+ # local requests should work normally
+ req = await client.get("/")
+ assert req.status == HTTPStatus.OK
+ assert await req.json() == {"authenticated": False}
+
+
+def _add_set_cookie_endpoint(app: web.Application, refresh_token: RefreshToken) -> None:
+ """Add an endpoint to set a cookie."""
+
+ async def set_cookie(request: web.Request) -> web.Response:
+ hass = request.app[KEY_HASS]
+ # Clear all sessions
+ hass.auth.session._temp_sessions.clear()
+ hass.auth.session._strict_connection_sessions.clear()
+
+ if request.query["token"] == "refresh":
+ await hass.auth.session.async_create_session(request, refresh_token)
+ else:
+ await hass.auth.session.async_create_temp_unauthorized_session(request)
+ session = await get_session(request)
+ return web.Response(text=session[SESSION_ID])
+
+ app.router.add_get("/test/cookie", set_cookie)
+
+
+async def _test_strict_connection_non_cloud_enabled_setup(
+ hass: HomeAssistant,
+ app: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ strict_connection_mode: StrictConnectionMode,
+) -> tuple[TestClient, Callable[[str], None], RefreshToken]:
+ """Test external unauthenticated requests with strict connection non cloud enabled."""
+ refresh_token = hass.auth.async_validate_access_token(hass_access_token)
+ assert refresh_token
+ session = hass.auth.session
+ assert session._strict_connection_sessions == {}
+ assert session._temp_sessions == {}
+
+ _add_set_cookie_endpoint(app, refresh_token)
+ await async_setup_auth(hass, app, strict_connection_mode)
+ set_mock_ip = mock_real_ip(app)
+ client = await aiohttp_client(app)
+ return (client, set_mock_ip, refresh_token)
+
+
+async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests(
+ hass: HomeAssistant,
+ app: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ perform_unauthenticated_request: Callable[
+ [HomeAssistant, TestClient], Awaitable[None]
+ ],
+ strict_connection_mode: StrictConnectionMode,
+) -> None:
+ """Test external unauthenticated requests with strict connection non cloud enabled."""
+ client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
+ hass, app, aiohttp_client, hass_access_token, strict_connection_mode
+ )
+
+ for remote_addr in EXTERNAL_ADDRESSES:
+ set_mock_ip(remote_addr)
+ await perform_unauthenticated_request(hass, client)
+
+
+async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token(
+ hass: HomeAssistant,
+ app: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ perform_unauthenticated_request: Callable[
+ [HomeAssistant, TestClient], Awaitable[None]
+ ],
+ strict_connection_mode: StrictConnectionMode,
+) -> None:
+ """Test external unauthenticated requests with strict connection non cloud enabled and refresh token cookie."""
+ (
+ client,
+ set_mock_ip,
+ refresh_token,
+ ) = await _test_strict_connection_non_cloud_enabled_setup(
+ hass, app, aiohttp_client, hass_access_token, strict_connection_mode
+ )
+ session = hass.auth.session
+
+ # set strict connection cookie with refresh token
+ set_mock_ip(LOCALHOST_ADDRESSES[0])
+ session_id = await (await client.get("/test/cookie?token=refresh")).text()
+ assert session._strict_connection_sessions == {session_id: refresh_token.id}
+ for remote_addr in EXTERNAL_ADDRESSES:
+ set_mock_ip(remote_addr)
+ req = await client.get("/")
+ assert req.status == HTTPStatus.OK
+ assert await req.json() == {"authenticated": False}
+
+ # Invalidate refresh token, which should also invalidate session
+ hass.auth.async_remove_refresh_token(refresh_token)
+ assert session._strict_connection_sessions == {}
+ for remote_addr in EXTERNAL_ADDRESSES:
+ set_mock_ip(remote_addr)
+ await perform_unauthenticated_request(hass, client)
+
+
+async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session(
+ hass: HomeAssistant,
+ app: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ perform_unauthenticated_request: Callable[
+ [HomeAssistant, TestClient], Awaitable[None]
+ ],
+ strict_connection_mode: StrictConnectionMode,
+) -> None:
+ """Test external unauthenticated requests with strict connection non cloud enabled and temp cookie."""
+ client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
+ hass, app, aiohttp_client, hass_access_token, strict_connection_mode
+ )
+ session = hass.auth.session
+
+ # set strict connection cookie with temp session
+ assert session._temp_sessions == {}
+ set_mock_ip(LOCALHOST_ADDRESSES[0])
+ session_id = await (await client.get("/test/cookie?token=temp")).text()
+ assert client.session.cookie_jar.filter_cookies(URL("http://127.0.0.1"))
+ assert session_id in session._temp_sessions
+ for remote_addr in EXTERNAL_ADDRESSES:
+ set_mock_ip(remote_addr)
+ resp = await client.get("/")
+ assert resp.status == HTTPStatus.OK
+ assert await resp.json() == {"authenticated": False}
+
+ async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
+ await hass.async_block_till_done(wait_background_tasks=True)
+
+ assert session._temp_sessions == {}
+ for remote_addr in EXTERNAL_ADDRESSES:
+ set_mock_ip(remote_addr)
+ await perform_unauthenticated_request(hass, client)
+
+
+async def _drop_connection_unauthorized_request(
+ _: HomeAssistant, client: TestClient
+) -> None:
+ with pytest.raises(ServerDisconnectedError):
+ # unauthorized requests should raise ServerDisconnectedError
+ await client.get("/")
+
+
+async def _static_page_unauthorized_request(
+ hass: HomeAssistant, client: TestClient
+) -> None:
+ req = await client.get("/")
+ assert req.status == HTTPStatus.IM_A_TEAPOT
+
+ def read_static_page() -> str:
+ with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
+ return file.read()
+
+ assert await req.text() == await hass.async_add_executor_job(read_static_page)
+
+
+@pytest.mark.parametrize(
+ "test_func",
+ [
+ _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests,
+ _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token,
+ _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session,
+ ],
+ ids=[
+ "no cookie",
+ "refresh token cookie",
+ "temp session cookie",
+ ],
+)
+@pytest.mark.parametrize(
+ ("strict_connection_mode", "request_func"),
+ [
+ (StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
+ (StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
+ ],
+ ids=["drop connection", "static page"],
+)
+async def test_strict_connection_non_cloud_external_unauthenticated_requests(
+ hass: HomeAssistant,
+ app_strict_connection: web.Application,
+ aiohttp_client: ClientSessionGenerator,
+ hass_access_token: str,
+ test_func: Callable[
+ [
+ HomeAssistant,
+ web.Application,
+ ClientSessionGenerator,
+ str,
+ Callable[[HomeAssistant, TestClient], Awaitable[None]],
+ StrictConnectionMode,
+ ],
+ Awaitable[None],
+ ],
+ strict_connection_mode: StrictConnectionMode,
+ request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
+) -> None:
+ """Test external unauthenticated requests with strict connection non cloud."""
+ await test_func(
+ hass,
+ app_strict_connection,
+ aiohttp_client,
+ hass_access_token,
+ request_func,
+ strict_connection_mode,
+ )
diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py
index 9e892e2ee43..b84da595ab1 100644
--- a/tests/components/http/test_init.py
+++ b/tests/components/http/test_init.py
@@ -7,6 +7,7 @@ from ipaddress import ip_network
import logging
from pathlib import Path
from unittest.mock import Mock, patch
+from urllib.parse import quote_plus
import pytest
@@ -14,7 +15,10 @@ from homeassistant.auth.providers.legacy_api_password import (
LegacyApiPasswordAuthProvider,
)
from homeassistant.components import http
+from homeassistant.components.http.const import StrictConnectionMode
+from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant
+from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers.http import KEY_HASS
from homeassistant.helpers.network import NoURLAvailableError
from homeassistant.setup import async_setup_component
@@ -521,3 +525,78 @@ async def test_logging(
response = await client.get("/api/states/logging.entity")
assert response.status == HTTPStatus.OK
assert "GET /api/states/logging.entity" not in caplog.text
+
+
+async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
+ hass: HomeAssistant,
+) -> None:
+ """Test service create_temporary_strict_connection_url with strict_connection not enabled."""
+ assert await async_setup_component(hass, http.DOMAIN, {"http": {}})
+ with pytest.raises(
+ ServiceValidationError,
+ match="Strict connection is not enabled for non-cloud requests",
+ ):
+ await hass.services.async_call(
+ http.DOMAIN,
+ "create_temporary_strict_connection_url",
+ blocking=True,
+ return_response=True,
+ )
+
+
+@pytest.mark.parametrize(
+ ("mode"),
+ [
+ StrictConnectionMode.DROP_CONNECTION,
+ StrictConnectionMode.STATIC_PAGE,
+ ],
+)
+async def test_service_create_temporary_strict_connection(
+ hass: HomeAssistant, mode: StrictConnectionMode
+) -> None:
+ """Test service create_temporary_strict_connection_url."""
+ assert await async_setup_component(
+ hass, http.DOMAIN, {"http": {"strict_connection": mode}}
+ )
+
+ # No external url set
+ assert hass.config.external_url is None
+ assert hass.config.internal_url is None
+ with pytest.raises(ServiceValidationError, match="No external URL available"):
+ await hass.services.async_call(
+ http.DOMAIN,
+ "create_temporary_strict_connection_url",
+ blocking=True,
+ return_response=True,
+ )
+
+ # Raise if only internal url is available
+ hass.config.api = Mock(use_ssl=False, port=8123, local_ip="192.168.123.123")
+ with pytest.raises(ServiceValidationError, match="No external URL available"):
+ await hass.services.async_call(
+ http.DOMAIN,
+ "create_temporary_strict_connection_url",
+ blocking=True,
+ return_response=True,
+ )
+
+ # Set external url too
+ external_url = "https://example.com"
+ await async_process_ha_core_config(
+ hass,
+ {"external_url": external_url},
+ )
+ assert hass.config.external_url == external_url
+ response = await hass.services.async_call(
+ http.DOMAIN,
+ "create_temporary_strict_connection_url",
+ blocking=True,
+ return_response=True,
+ )
+ assert isinstance(response, dict)
+ direct_url_prefix = f"{external_url}/auth/strict_connection/temp_token?authSig="
+ assert response.pop("direct_url").startswith(direct_url_prefix)
+ assert response.pop("url").startswith(
+ f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
+ )
+ assert response == {} # No more keys in response
diff --git a/tests/components/http/test_session.py b/tests/components/http/test_session.py
new file mode 100644
index 00000000000..ae62365749a
--- /dev/null
+++ b/tests/components/http/test_session.py
@@ -0,0 +1,107 @@
+"""Tests for HTTP session."""
+
+from collections.abc import Callable
+import logging
+from typing import Any
+from unittest.mock import patch
+
+from aiohttp import web
+from aiohttp.test_utils import make_mocked_request
+import pytest
+
+from homeassistant.auth.session import SESSION_ID
+from homeassistant.components.http.session import (
+ COOKIE_NAME,
+ HomeAssistantCookieStorage,
+)
+from homeassistant.core import HomeAssistant
+
+
+def fake_request_with_strict_connection_cookie(cookie_value: str) -> web.Request:
+ """Return a fake request with a strict connection cookie."""
+ request = make_mocked_request(
+ "GET", "/", headers={"Cookie": f"{COOKIE_NAME}={cookie_value}"}
+ )
+ assert COOKIE_NAME in request.cookies
+ return request
+
+
+@pytest.fixture
+def cookie_storage(hass: HomeAssistant) -> HomeAssistantCookieStorage:
+ """Fixture for the cookie storage."""
+ return HomeAssistantCookieStorage(hass)
+
+
+def _encrypt_cookie_data(cookie_storage: HomeAssistantCookieStorage, data: Any) -> str:
+ """Encrypt cookie data."""
+ cookie_data = cookie_storage._encoder(data).encode("utf-8")
+ return cookie_storage._fernet.encrypt(cookie_data).decode("utf-8")
+
+
+@pytest.mark.parametrize(
+ "func",
+ [
+ lambda _: "invalid",
+ lambda storage: _encrypt_cookie_data(storage, "bla"),
+ lambda storage: _encrypt_cookie_data(storage, None),
+ ],
+)
+async def test_load_session_modified_cookies(
+ cookie_storage: HomeAssistantCookieStorage,
+ caplog: pytest.LogCaptureFixture,
+ func: Callable[[HomeAssistantCookieStorage], str],
+) -> None:
+ """Test that on modified cookies the session is empty and the request will be logged for ban."""
+ request = fake_request_with_strict_connection_cookie(func(cookie_storage))
+ with patch(
+ "homeassistant.components.http.session.process_wrong_login",
+ ) as mock_process_wrong_login:
+ session = await cookie_storage.load_session(request)
+ assert session.empty
+ assert (
+ "homeassistant.components.http.session",
+ logging.WARNING,
+ "Cannot decrypt/parse cookie value",
+ ) in caplog.record_tuples
+ mock_process_wrong_login.assert_called()
+
+
+async def test_load_session_validate_session(
+ hass: HomeAssistant,
+ cookie_storage: HomeAssistantCookieStorage,
+) -> None:
+ """Test load session validates the session."""
+ session = await cookie_storage.new_session()
+ session[SESSION_ID] = "bla"
+ request = fake_request_with_strict_connection_cookie(
+ _encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
+ )
+
+ with patch.object(
+ hass.auth.session, "async_validate_strict_connection_session", return_value=True
+ ) as mock_validate:
+ session = await cookie_storage.load_session(request)
+ assert not session.empty
+ assert session[SESSION_ID] == "bla"
+ mock_validate.assert_called_with(session)
+
+ # verify lru_cache is working
+ mock_validate.reset_mock()
+ await cookie_storage.load_session(request)
+ mock_validate.assert_not_called()
+
+ session = await cookie_storage.new_session()
+ session[SESSION_ID] = "something"
+ request = fake_request_with_strict_connection_cookie(
+ _encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
+ )
+
+ with patch.object(
+ hass.auth.session,
+ "async_validate_strict_connection_session",
+ return_value=False,
+ ):
+ session = await cookie_storage.load_session(request)
+ assert session.empty
+ assert SESSION_ID not in session
+ assert session._changed
diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py
index 9ce23d99152..280d15cd1ef 100644
--- a/tests/components/stream/conftest.py
+++ b/tests/components/stream/conftest.py
@@ -14,7 +14,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Generator
-from http import HTTPStatus
import logging
import threading
from unittest.mock import Mock, patch
@@ -87,6 +86,17 @@ class HLSSync:
self._num_recvs = 0
self._num_finished = 0
+ def on_resp():
+ self._num_finished += 1
+ self.check_requests_ready()
+
+ class SyncResponse(web.Response):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ on_resp()
+
+ self.response = SyncResponse
+
def reset_request_pool(self, num_requests: int, reset_finished=True):
"""Use to reset the request counter between segments."""
self._num_recvs = 0
@@ -120,12 +130,6 @@ class HLSSync:
self.check_requests_ready()
return self._original_not_found()
- def response(self, body, headers=None, status=HTTPStatus.OK):
- """Intercept the Response call so we know when the web handler is finished."""
- self._num_finished += 1
- self.check_requests_ready()
- return self._original_response(body=body, headers=headers, status=status)
-
async def recv(self, output: StreamOutput, **kw):
"""Intercept the recv call so we know when the response is blocking on recv."""
self._num_recvs += 1
@@ -164,7 +168,7 @@ def hls_sync():
),
patch(
"homeassistant.components.stream.hls.web.Response",
- side_effect=sync.response,
+ new=sync.response,
),
):
yield sync
diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py
index e96f1c4f903..2bd76accfdd 100644
--- a/tests/components/websocket_api/test_commands.py
+++ b/tests/components/websocket_api/test_commands.py
@@ -701,7 +701,7 @@ async def test_get_services(
assert msg["id"] == id_
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
- assert msg["result"] == hass.services.async_services()
+ assert msg["result"].keys() == hass.services.async_services().keys()
async def test_get_config(
diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py
index 74b8a86ce7c..b5e71f4c9d8 100644
--- a/tests/helpers/test_service.py
+++ b/tests/helpers/test_service.py
@@ -7,6 +7,7 @@ from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
+from pytest_unordered import unordered
import voluptuous as vol
# To prevent circular import when running just this file
@@ -16,6 +17,7 @@ import homeassistant.components # noqa: F401
from homeassistant.components.group import DOMAIN as DOMAIN_GROUP, Group
from homeassistant.components.logger import DOMAIN as DOMAIN_LOGGER
from homeassistant.components.shell_command import DOMAIN as DOMAIN_SHELL_COMMAND
+from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.const import (
ATTR_ENTITY_ID,
ENTITY_MATCH_ALL,
@@ -785,7 +787,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
"""Test async_get_all_descriptions."""
group_config = {DOMAIN_GROUP: {}}
assert await async_setup_component(hass, DOMAIN_GROUP, group_config)
- assert await async_setup_component(hass, "system_health", {})
+ assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {})
with patch(
"homeassistant.helpers.service._load_services_files",
@@ -795,17 +797,20 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
# Test we only load services.yaml for integrations with services.yaml
# And system_health has no services
- assert proxy_load_services_files.mock_calls[0][1][1] == [
- await async_get_integration(hass, "group")
- ]
+ assert proxy_load_services_files.mock_calls[0][1][1] == unordered(
+ [
+ await async_get_integration(hass, DOMAIN_GROUP),
+ await async_get_integration(hass, "http"), # system_health requires http
+ ]
+ )
- assert len(descriptions) == 1
-
- assert "description" in descriptions["group"]["reload"]
- assert "fields" in descriptions["group"]["reload"]
+ assert len(descriptions) == 2
+ assert DOMAIN_GROUP in descriptions
+ assert "description" in descriptions[DOMAIN_GROUP]["reload"]
+ assert "fields" in descriptions[DOMAIN_GROUP]["reload"]
# Does not have services
- assert "system_health" not in descriptions
+ assert DOMAIN_SYSTEM_HEALTH not in descriptions
logger_config = {DOMAIN_LOGGER: {}}
@@ -833,8 +838,8 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
await async_setup_component(hass, DOMAIN_LOGGER, logger_config)
descriptions = await service.async_get_all_descriptions(hass)
- assert len(descriptions) == 2
-
+ assert len(descriptions) == 3
+ assert DOMAIN_LOGGER in descriptions
assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name"
assert (
descriptions[DOMAIN_LOGGER]["set_default_level"]["description"]
diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py
index 79c64259f8b..76acb2ff678 100644
--- a/tests/scripts/test_check_config.py
+++ b/tests/scripts/test_check_config.py
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
+from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.config import YAML_CONFIG_FILE
from homeassistant.scripts import check_config
@@ -134,6 +135,7 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None:
"login_attempts_threshold": -1,
"server_port": 8123,
"ssl_profile": "modern",
+ "strict_connection": StrictConnectionMode.DISABLED,
"use_x_frame_options": True,
"server_host": ["0.0.0.0", "::"],
}