mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Remove strict connection (#117933)
This commit is contained in:
parent
6f81852eb4
commit
cb62f4242e
@ -28,7 +28,6 @@ from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRA
|
||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .models import AuthFlowResult
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
from .session import SessionManager
|
||||
|
||||
EVENT_USER_ADDED = "user_added"
|
||||
EVENT_USER_UPDATED = "user_updated"
|
||||
@ -181,7 +180,6 @@ class AuthManager:
|
||||
self._remove_expired_job = HassJob(
|
||||
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
|
||||
)
|
||||
self.session = SessionManager(hass, self)
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up the auth manager."""
|
||||
@ -192,7 +190,6 @@ class AuthManager:
|
||||
)
|
||||
)
|
||||
self._async_track_next_refresh_token_expiration()
|
||||
await self.session.async_setup()
|
||||
|
||||
@property
|
||||
def auth_providers(self) -> list[AuthProvider]:
|
||||
|
@ -1,205 +0,0 @@
|
||||
"""Session auth module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
from typing import TYPE_CHECKING, Final, TypedDict
|
||||
|
||||
from aiohttp.web import Request
|
||||
from aiohttp_session import Session, get_session, new_session
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .models import RefreshToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import AuthManager
|
||||
|
||||
|
||||
TEMP_TIMEOUT = timedelta(minutes=5)
|
||||
TEMP_TIMEOUT_SECONDS = TEMP_TIMEOUT.total_seconds()
|
||||
|
||||
SESSION_ID = "id"
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth.session"
|
||||
|
||||
|
||||
class StrictConnectionTempSessionData:
|
||||
"""Data for accessing unauthorized resources for a short period of time."""
|
||||
|
||||
__slots__ = ("cancel_remove", "absolute_expiry")
|
||||
|
||||
def __init__(self, cancel_remove: CALLBACK_TYPE) -> None:
|
||||
"""Initialize the temp session data."""
|
||||
self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove
|
||||
self.absolute_expiry: Final[datetime] = dt_util.utcnow() + TEMP_TIMEOUT
|
||||
|
||||
|
||||
class StoreData(TypedDict):
|
||||
"""Data to store."""
|
||||
|
||||
unauthorized_sessions: dict[str, str]
|
||||
key: str
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Session manager."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None:
|
||||
"""Initialize the strict connection manager."""
|
||||
self._auth = auth
|
||||
self._hass = hass
|
||||
self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {}
|
||||
self._strict_connection_sessions: dict[str, str] = {}
|
||||
self._store = Store[StoreData](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._key: str | None = None
|
||||
self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {}
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return the encryption key."""
|
||||
if self._key is None:
|
||||
self._key = Fernet.generate_key().decode()
|
||||
self._async_schedule_save()
|
||||
return self._key
|
||||
|
||||
async def async_validate_request_for_strict_connection_session(
|
||||
self,
|
||||
request: Request,
|
||||
) -> bool:
|
||||
"""Check if a request has a valid strict connection session."""
|
||||
session = await get_session(request)
|
||||
if session.new or session.empty:
|
||||
return False
|
||||
result = self.async_validate_strict_connection_session(session)
|
||||
if result is False:
|
||||
session.invalidate()
|
||||
return result
|
||||
|
||||
@callback
|
||||
def async_validate_strict_connection_session(
|
||||
self,
|
||||
session: Session,
|
||||
) -> bool:
|
||||
"""Validate a strict connection session."""
|
||||
if not (session_id := session.get(SESSION_ID)):
|
||||
return False
|
||||
|
||||
if token_id := self._strict_connection_sessions.get(session_id):
|
||||
if self._auth.async_get_refresh_token(token_id):
|
||||
return True
|
||||
# refresh token is invalid, delete entry
|
||||
self._strict_connection_sessions.pop(session_id)
|
||||
self._async_schedule_save()
|
||||
|
||||
if data := self._temp_sessions.get(session_id):
|
||||
if dt_util.utcnow() <= data.absolute_expiry:
|
||||
return True
|
||||
# session expired, delete entry
|
||||
self._temp_sessions.pop(session_id).cancel_remove()
|
||||
|
||||
return False
|
||||
|
||||
@callback
|
||||
def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None:
|
||||
"""Register a callback to revoke all sessions for a refresh token."""
|
||||
if refresh_token_id in self._refresh_token_revoke_callbacks:
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_invalidate_auth_sessions() -> None:
|
||||
"""Invalidate all sessions for a refresh token."""
|
||||
self._strict_connection_sessions = {
|
||||
session_id: token_id
|
||||
for session_id, token_id in self._strict_connection_sessions.items()
|
||||
if token_id != refresh_token_id
|
||||
}
|
||||
self._async_schedule_save()
|
||||
|
||||
self._refresh_token_revoke_callbacks[refresh_token_id] = (
|
||||
self._auth.async_register_revoke_token_callback(
|
||||
refresh_token_id, async_invalidate_auth_sessions
|
||||
)
|
||||
)
|
||||
|
||||
async def async_create_session(
|
||||
self,
|
||||
request: Request,
|
||||
refresh_token: RefreshToken,
|
||||
) -> None:
|
||||
"""Create new session for given refresh token.
|
||||
|
||||
Caller needs to make sure that the refresh token is valid.
|
||||
By creating a session, we are implicitly revoking all other
|
||||
sessions for the given refresh token as there is one refresh
|
||||
token per device/user case.
|
||||
"""
|
||||
self._strict_connection_sessions = {
|
||||
session_id: token_id
|
||||
for session_id, token_id in self._strict_connection_sessions.items()
|
||||
if token_id != refresh_token.id
|
||||
}
|
||||
|
||||
self._async_register_revoke_token_callback(refresh_token.id)
|
||||
session_id = await self._async_create_new_session(request)
|
||||
self._strict_connection_sessions[session_id] = refresh_token.id
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_create_temp_unauthorized_session(self, request: Request) -> None:
|
||||
"""Create a temporary unauthorized session."""
|
||||
session_id = await self._async_create_new_session(
|
||||
request, max_age=int(TEMP_TIMEOUT_SECONDS)
|
||||
)
|
||||
|
||||
@callback
|
||||
def remove(_: datetime) -> None:
|
||||
self._temp_sessions.pop(session_id, None)
|
||||
|
||||
self._temp_sessions[session_id] = StrictConnectionTempSessionData(
|
||||
async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove)
|
||||
)
|
||||
|
||||
async def _async_create_new_session(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
max_age: int | None = None,
|
||||
) -> str:
|
||||
session_id = secrets.token_hex(64)
|
||||
|
||||
session = await new_session(request)
|
||||
session[SESSION_ID] = session_id
|
||||
if max_age is not None:
|
||||
session.max_age = max_age
|
||||
return session_id
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self, delay: float = 1) -> None:
|
||||
"""Save sessions."""
|
||||
self._store.async_delay_save(self._data_to_save, delay)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> StoreData:
|
||||
"""Return the data to store."""
|
||||
return StoreData(
|
||||
unauthorized_sessions=self._strict_connection_sessions,
|
||||
key=self.key,
|
||||
)
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up session manager."""
|
||||
data = await self._store.async_load()
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self._key = data["key"]
|
||||
self._strict_connection_sessions = data["unauthorized_sessions"]
|
||||
for token_id in self._strict_connection_sessions.values():
|
||||
self._async_register_revoke_token_callback(token_id)
|
@ -162,7 +162,6 @@ from homeassistant.util import dt as dt_util
|
||||
from . import indieauth, login_flow, mfa_setup_flow
|
||||
|
||||
DOMAIN = "auth"
|
||||
STRICT_CONNECTION_URL = "/auth/strict_connection/temp_token"
|
||||
|
||||
type StoreResultType = Callable[[str, Credentials], str]
|
||||
type RetrieveResultType = Callable[[str, str], Credentials | None]
|
||||
@ -188,7 +187,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
hass.http.register_view(RevokeTokenView())
|
||||
hass.http.register_view(LinkUserView(retrieve_result))
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView())
|
||||
hass.http.register_view(StrictConnectionTempTokenView())
|
||||
|
||||
websocket_api.async_register_command(hass, websocket_current_user)
|
||||
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
|
||||
@ -323,7 +321,6 @@ class TokenView(HomeAssistantView):
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
return self.json(
|
||||
{
|
||||
"access_token": access_token,
|
||||
@ -392,7 +389,6 @@ class TokenView(HomeAssistantView):
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
return self.json(
|
||||
{
|
||||
"access_token": access_token,
|
||||
@ -441,20 +437,6 @@ class LinkUserView(HomeAssistantView):
|
||||
return self.json_message("User linked")
|
||||
|
||||
|
||||
class StrictConnectionTempTokenView(HomeAssistantView):
|
||||
"""View to get temporary strict connection token."""
|
||||
|
||||
url = STRICT_CONNECTION_URL
|
||||
name = "api:auth:strict_connection:temp_token"
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Get a temporary token and redirect to main page."""
|
||||
hass = request.app[KEY_HASS]
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
raise web.HTTPSeeOther(location="/")
|
||||
|
||||
|
||||
@callback
|
||||
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
|
||||
"""Create an in memory store."""
|
||||
|
@ -7,14 +7,11 @@ from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from urllib.parse import quote_plus, urljoin
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import alexa, google_assistant, http
|
||||
from homeassistant.components.auth import STRICT_CONNECTION_URL
|
||||
from homeassistant.components.http.auth import async_sign_path
|
||||
from homeassistant.components import alexa, google_assistant
|
||||
from homeassistant.config_entries import SOURCE_SYSTEM, ConfigEntry
|
||||
from homeassistant.const import (
|
||||
CONF_DESCRIPTION,
|
||||
@ -24,21 +21,8 @@ from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.core import Event, HassJob, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, entityfilter
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
@ -47,7 +31,6 @@ from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import NoURLAvailableError, get_url
|
||||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -418,50 +401,3 @@ def _setup_services(hass: HomeAssistant, prefs: CloudPreferences) -> None:
|
||||
async_register_admin_service(
|
||||
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
|
||||
)
|
||||
|
||||
async def create_temporary_strict_connection_url(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Create a strict connection url and return it."""
|
||||
# Copied form homeassistant/helpers/service.py#_async_admin_handler
|
||||
# as the helper supports no responses yet
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
if not user.is_admin:
|
||||
raise Unauthorized(context=call.context)
|
||||
|
||||
if prefs.strict_connection is http.const.StrictConnectionMode.DISABLED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="strict_connection_not_enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
url = get_url(hass, require_cloud=True)
|
||||
except NoURLAvailableError as ex:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="no_url_available",
|
||||
) from ex
|
||||
|
||||
path = async_sign_path(
|
||||
hass,
|
||||
STRICT_CONNECTION_URL,
|
||||
timedelta(hours=1),
|
||||
use_content_user=True,
|
||||
)
|
||||
url = urljoin(url, path)
|
||||
|
||||
return {
|
||||
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
|
||||
"direct_url": url,
|
||||
}
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
create_temporary_strict_connection_url,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
@ -250,7 +250,6 @@ class CloudClient(Interface):
|
||||
"enabled": self._prefs.remote_enabled,
|
||||
"instance_domain": self.cloud.remote.instance_domain,
|
||||
"alias": self.cloud.remote.alias,
|
||||
"strict_connection": self._prefs.strict_connection,
|
||||
},
|
||||
"version": HA_VERSION,
|
||||
"instance_id": self.prefs.instance_id,
|
||||
|
@ -33,7 +33,6 @@ PREF_GOOGLE_SETTINGS_VERSION = "google_settings_version"
|
||||
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
|
||||
PREF_GOOGLE_CONNECTED = "google_connected"
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE = "remote_allow_remote_enable"
|
||||
PREF_STRICT_CONNECTION = "strict_connection"
|
||||
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "JennyNeural")
|
||||
DEFAULT_DISABLE_2FA = False
|
||||
DEFAULT_ALEXA_REPORT_STATE = True
|
||||
|
@ -19,7 +19,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
|
||||
from hass_nabucasa.voice import TTS_VOICES
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.alexa import (
|
||||
entities as alexa_entities,
|
||||
errors as alexa_errors,
|
||||
@ -46,7 +46,6 @@ from .const import (
|
||||
PREF_GOOGLE_REPORT_STATE,
|
||||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
REQUEST_TIMEOUT,
|
||||
)
|
||||
@ -449,9 +448,6 @@ def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
|
||||
vol.Coerce(tuple), validate_language_voice
|
||||
),
|
||||
vol.Optional(PREF_REMOTE_ALLOW_REMOTE_ENABLE): bool,
|
||||
vol.Optional(PREF_STRICT_CONNECTION): vol.Coerce(
|
||||
http.const.StrictConnectionMode
|
||||
),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -1,6 +1,5 @@
|
||||
{
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": "mdi:login-variant",
|
||||
"remote_connect": "mdi:cloud",
|
||||
"remote_disconnect": "mdi:cloud-off"
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ from hass_nabucasa.voice import MAP_VOICE
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import http, webhook
|
||||
from homeassistant.components import webhook
|
||||
from homeassistant.components.google_assistant.http import (
|
||||
async_get_users as async_get_google_assistant_users,
|
||||
)
|
||||
@ -44,7 +44,6 @@ from .const import (
|
||||
PREF_INSTANCE_ID,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||
PREF_REMOTE_DOMAIN,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
PREF_USERNAME,
|
||||
)
|
||||
@ -177,7 +176,6 @@ class CloudPreferences:
|
||||
google_settings_version: int | UndefinedType = UNDEFINED,
|
||||
google_connected: bool | UndefinedType = UNDEFINED,
|
||||
remote_allow_remote_enable: bool | UndefinedType = UNDEFINED,
|
||||
strict_connection: http.const.StrictConnectionMode | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Update user preferences."""
|
||||
prefs = {**self._prefs}
|
||||
@ -197,7 +195,6 @@ class CloudPreferences:
|
||||
(PREF_REMOTE_DOMAIN, remote_domain),
|
||||
(PREF_GOOGLE_CONNECTED, google_connected),
|
||||
(PREF_REMOTE_ALLOW_REMOTE_ENABLE, remote_allow_remote_enable),
|
||||
(PREF_STRICT_CONNECTION, strict_connection),
|
||||
):
|
||||
if value is not UNDEFINED:
|
||||
prefs[key] = value
|
||||
@ -245,7 +242,6 @@ class CloudPreferences:
|
||||
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE: self.remote_allow_remote_enable,
|
||||
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
|
||||
PREF_STRICT_CONNECTION: self.strict_connection,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -362,20 +358,6 @@ class CloudPreferences:
|
||||
"""
|
||||
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE) # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def strict_connection(self) -> http.const.StrictConnectionMode:
|
||||
"""Return the strict connection mode."""
|
||||
mode = self._prefs.get(PREF_STRICT_CONNECTION)
|
||||
|
||||
if mode is None:
|
||||
# Set to default value
|
||||
# We store None in the store as the default value to detect if the user has changed the
|
||||
# value or not.
|
||||
mode = http.const.StrictConnectionMode.DISABLED
|
||||
elif not isinstance(mode, http.const.StrictConnectionMode):
|
||||
mode = http.const.StrictConnectionMode(mode)
|
||||
return mode
|
||||
|
||||
async def get_cloud_user(self) -> str:
|
||||
"""Return ID of Home Assistant Cloud system user."""
|
||||
user = await self._load_cloud_user()
|
||||
@ -433,5 +415,4 @@ class CloudPreferences:
|
||||
PREF_REMOTE_DOMAIN: None,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE: True,
|
||||
PREF_USERNAME: username,
|
||||
PREF_STRICT_CONNECTION: None,
|
||||
}
|
||||
|
@ -5,14 +5,6 @@
|
||||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"strict_connection_not_enabled": {
|
||||
"message": "Strict connection is not enabled for cloud requests"
|
||||
},
|
||||
"no_url_available": {
|
||||
"message": "No cloud URL available.\nPlease mark sure you have a working Remote UI."
|
||||
}
|
||||
},
|
||||
"system_health": {
|
||||
"info": {
|
||||
"can_reach_cert_server": "Reach Certificate Server",
|
||||
@ -81,10 +73,6 @@
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": {
|
||||
"name": "Create a temporary strict connection URL",
|
||||
"description": "Create a temporary strict connection URL, which can be used to login on another device."
|
||||
},
|
||||
"remote_connect": {
|
||||
"name": "Remote connect",
|
||||
"description": "Makes the instance UI accessible from outside of the local network by using Home Assistant Cloud."
|
||||
|
@ -1,15 +0,0 @@
|
||||
"""Cloud util functions."""
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
|
||||
from homeassistant.components import http
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .client import CloudClient
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
def get_strict_connection_mode(hass: HomeAssistant) -> http.const.StrictConnectionMode:
|
||||
"""Get the strict connection mode."""
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
return cloud.client.prefs.strict_connection
|
@ -10,8 +10,7 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Final, Required, TypedDict, cast
|
||||
from urllib.parse import quote_plus, urljoin
|
||||
from typing import Any, Final, TypedDict, cast
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.abc import AbstractStreamWriter
|
||||
@ -30,20 +29,8 @@ from yarl import URL
|
||||
|
||||
from homeassistant.components.network import async_get_source_ip
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import storage
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.http import (
|
||||
@ -66,14 +53,9 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .auth import async_setup_auth, async_sign_path
|
||||
from .auth import async_setup_auth
|
||||
from .ban import setup_bans
|
||||
from .const import ( # noqa: F401
|
||||
DOMAIN,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
StrictConnectionMode,
|
||||
)
|
||||
from .const import DOMAIN, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
|
||||
from .cors import setup_cors
|
||||
from .decorators import require_admin # noqa: F401
|
||||
from .forwarded import async_setup_forwarded
|
||||
@ -96,7 +78,6 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
|
||||
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
|
||||
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
|
||||
CONF_SSL_PROFILE: Final = "ssl_profile"
|
||||
CONF_STRICT_CONNECTION: Final = "strict_connection"
|
||||
|
||||
SSL_MODERN: Final = "modern"
|
||||
SSL_INTERMEDIATE: Final = "intermediate"
|
||||
@ -146,9 +127,6 @@ HTTP_SCHEMA: Final = vol.All(
|
||||
[SSL_INTERMEDIATE, SSL_MODERN]
|
||||
),
|
||||
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
|
||||
vol.Optional(
|
||||
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
|
||||
): vol.Coerce(StrictConnectionMode),
|
||||
}
|
||||
),
|
||||
)
|
||||
@ -172,7 +150,6 @@ class ConfData(TypedDict, total=False):
|
||||
login_attempts_threshold: int
|
||||
ip_ban_enabled: bool
|
||||
ssl_profile: str
|
||||
strict_connection: Required[StrictConnectionMode]
|
||||
|
||||
|
||||
@bind_hass
|
||||
@ -241,7 +218,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
login_threshold=login_threshold,
|
||||
is_ban_enabled=is_ban_enabled,
|
||||
use_x_frame_options=use_x_frame_options,
|
||||
strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION],
|
||||
)
|
||||
|
||||
async def stop_server(event: Event) -> None:
|
||||
@ -271,7 +247,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
local_ip, host, server_port, ssl_certificate is not None
|
||||
)
|
||||
|
||||
_setup_services(hass, conf)
|
||||
return True
|
||||
|
||||
|
||||
@ -356,7 +331,6 @@ class HomeAssistantHTTP:
|
||||
login_threshold: int,
|
||||
is_ban_enabled: bool,
|
||||
use_x_frame_options: bool,
|
||||
strict_connection_non_cloud: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Initialize the server."""
|
||||
self.app[KEY_HASS] = self.hass
|
||||
@ -373,7 +347,7 @@ class HomeAssistantHTTP:
|
||||
if is_ban_enabled:
|
||||
setup_bans(self.hass, self.app, login_threshold)
|
||||
|
||||
await async_setup_auth(self.hass, self.app, strict_connection_non_cloud)
|
||||
await async_setup_auth(self.hass, self.app)
|
||||
|
||||
setup_headers(self.app, use_x_frame_options)
|
||||
setup_cors(self.app, cors_origins)
|
||||
@ -602,61 +576,3 @@ async def start_http_server_and_save_config(
|
||||
]
|
||||
|
||||
store.async_delay_save(lambda: conf, SAVE_DELAY)
|
||||
|
||||
|
||||
@callback
|
||||
def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
|
||||
"""Set up services for HTTP component."""
|
||||
|
||||
async def create_temporary_strict_connection_url(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Create a strict connection url and return it."""
|
||||
# Copied form homeassistant/helpers/service.py#_async_admin_handler
|
||||
# as the helper supports no responses yet
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
if not user.is_admin:
|
||||
raise Unauthorized(context=call.context)
|
||||
|
||||
if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="strict_connection_not_enabled_non_cloud",
|
||||
)
|
||||
|
||||
try:
|
||||
url = get_url(
|
||||
hass, prefer_external=True, allow_internal=False, allow_cloud=False
|
||||
)
|
||||
except NoURLAvailableError as ex:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="no_external_url_available",
|
||||
) from ex
|
||||
|
||||
# to avoid circular import
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components.auth import STRICT_CONNECTION_URL
|
||||
|
||||
path = async_sign_path(
|
||||
hass,
|
||||
STRICT_CONNECTION_URL,
|
||||
datetime.timedelta(hours=1),
|
||||
use_content_user=True,
|
||||
)
|
||||
url = urljoin(url, path)
|
||||
|
||||
return {
|
||||
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
|
||||
"direct_url": url,
|
||||
}
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
create_temporary_strict_connection_url,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
@ -4,18 +4,14 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
|
||||
from aiohttp.web_exceptions import HTTPBadRequest
|
||||
from aiohttp_session import session_middleware
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
import jwt
|
||||
from jwt import api_jws
|
||||
from yarl import URL
|
||||
@ -25,21 +21,13 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import singleton
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
StrictConnectionMode,
|
||||
)
|
||||
from .session import HomeAssistantCookieStorage
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -51,11 +39,6 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"]
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "http.auth"
|
||||
CONTENT_USER_NAME = "Home Assistant Content"
|
||||
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
|
||||
STRICT_CONNECTION_GUARD_PAGE_NAME = "strict_connection_guard_page.html"
|
||||
STRICT_CONNECTION_GUARD_PAGE = os.path.join(
|
||||
os.path.dirname(__file__), STRICT_CONNECTION_GUARD_PAGE_NAME
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@ -137,7 +120,6 @@ def async_user_not_allowed_do_auth(
|
||||
async def async_setup_auth(
|
||||
hass: HomeAssistant,
|
||||
app: Application,
|
||||
strict_connection_mode_non_cloud: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Create auth middleware for the app."""
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
@ -160,10 +142,6 @@ async def async_setup_auth(
|
||||
|
||||
hass.data[STORAGE_KEY] = refresh_token.id
|
||||
|
||||
if strict_connection_mode_non_cloud is StrictConnectionMode.GUARD_PAGE:
|
||||
# Load the guard page content on setup
|
||||
await _read_strict_connection_guard_page(hass)
|
||||
|
||||
@callback
|
||||
def async_validate_auth_header(request: Request) -> bool:
|
||||
"""Test authorization header against access token.
|
||||
@ -252,37 +230,6 @@ async def async_setup_auth(
|
||||
authenticated = True
|
||||
auth_type = "signed request"
|
||||
|
||||
if not authenticated and not request.path.startswith(
|
||||
STRICT_CONNECTION_EXCLUDED_PATH
|
||||
):
|
||||
strict_connection_mode = strict_connection_mode_non_cloud
|
||||
strict_connection_func = (
|
||||
_async_perform_strict_connection_action_on_non_local
|
||||
)
|
||||
if is_cloud_connection(hass):
|
||||
from homeassistant.components.cloud.util import ( # pylint: disable=import-outside-toplevel
|
||||
get_strict_connection_mode,
|
||||
)
|
||||
|
||||
strict_connection_mode = get_strict_connection_mode(hass)
|
||||
strict_connection_func = _async_perform_strict_connection_action
|
||||
|
||||
if (
|
||||
strict_connection_mode is not StrictConnectionMode.DISABLED
|
||||
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
|
||||
request
|
||||
)
|
||||
and (
|
||||
resp := await strict_connection_func(
|
||||
hass,
|
||||
request,
|
||||
strict_connection_mode is StrictConnectionMode.GUARD_PAGE,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
return resp
|
||||
|
||||
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Authenticated %s for %s using %s",
|
||||
@ -294,69 +241,4 @@ async def async_setup_auth(
|
||||
request[KEY_AUTHENTICATED] = authenticated
|
||||
return await handler(request)
|
||||
|
||||
app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass)))
|
||||
app.middlewares.append(auth_middleware)
|
||||
|
||||
|
||||
async def _async_perform_strict_connection_action_on_non_local(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
guard_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action if the request is not local.
|
||||
|
||||
The function does the following:
|
||||
- Try to get the IP address of the request. If it fails, assume it's not local
|
||||
- If the request is local, return None (allow the request to continue)
|
||||
- If guard_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
try:
|
||||
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
_LOGGER.debug("Invalid IP address: %s", request.remote)
|
||||
ip_address_ = None
|
||||
|
||||
if ip_address_ and is_local(ip_address_):
|
||||
return None
|
||||
|
||||
return await _async_perform_strict_connection_action(hass, request, guard_page)
|
||||
|
||||
|
||||
async def _async_perform_strict_connection_action(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
guard_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action.
|
||||
|
||||
The function does the following:
|
||||
- If guard_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
|
||||
_LOGGER.debug("Perform strict connection action for %s", request.remote)
|
||||
if guard_page:
|
||||
return Response(
|
||||
text=await _read_strict_connection_guard_page(hass),
|
||||
content_type="text/html",
|
||||
status=HTTPStatus.IM_A_TEAPOT,
|
||||
)
|
||||
|
||||
if transport := request.transport:
|
||||
# it should never happen that we don't have a transport
|
||||
transport.close()
|
||||
|
||||
# We need to raise an exception to stop processing the request
|
||||
raise HTTPBadRequest
|
||||
|
||||
|
||||
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_GUARD_PAGE_NAME}")
|
||||
async def _read_strict_connection_guard_page(hass: HomeAssistant) -> str:
|
||||
"""Read the strict connection guard page from disk via executor."""
|
||||
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
return await hass.async_add_executor_job(read_guard_page)
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""HTTP specific constants."""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
|
||||
@ -9,11 +8,3 @@ DOMAIN: Final = "http"
|
||||
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
||||
|
||||
class StrictConnectionMode(StrEnum):
|
||||
"""Enum for strict connection mode."""
|
||||
|
||||
DISABLED = "disabled"
|
||||
GUARD_PAGE = "guard_page"
|
||||
DROP_CONNECTION = "drop_connection"
|
||||
|
@ -1,5 +0,0 @@
|
||||
{
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": "mdi:login-variant"
|
||||
}
|
||||
}
|
@ -1 +0,0 @@
|
||||
create_temporary_strict_connection_url: ~
|
@ -1,160 +0,0 @@
|
||||
"""Session http module."""
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
|
||||
from aiohttp.web import Request, StreamResponse
|
||||
from aiohttp_session import Session, SessionData
|
||||
from aiohttp_session.cookie_storage import EncryptedCookieStorage
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||
|
||||
from .ban import process_wrong_login
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
COOKIE_NAME = "SC"
|
||||
PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}"
|
||||
SESSION_CACHE_SIZE = 16
|
||||
|
||||
|
||||
def _get_cookie_name(is_secure: bool) -> str:
|
||||
"""Return the cookie name."""
|
||||
return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME
|
||||
|
||||
|
||||
class HomeAssistantCookieStorage(EncryptedCookieStorage):
|
||||
"""Home Assistant cookie storage.
|
||||
|
||||
Own class is required:
|
||||
- to set the secure flag based on the connection type
|
||||
- to use a LRU cache for session decryption
|
||||
"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the cookie storage."""
|
||||
super().__init__(
|
||||
hass.auth.session.key,
|
||||
cookie_name=PREFIXED_COOKIE_NAME,
|
||||
max_age=int(REFRESH_TOKEN_EXPIRATION),
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
secure=True,
|
||||
encoder=json_dumps,
|
||||
decoder=json_loads,
|
||||
)
|
||||
self._hass = hass
|
||||
|
||||
def _secure_connection(self, request: Request) -> bool:
|
||||
"""Return if the connection is secure (https)."""
|
||||
return is_cloud_connection(self._hass) or request.secure
|
||||
|
||||
def load_cookie(self, request: Request) -> str | None:
|
||||
"""Load cookie."""
|
||||
is_secure = self._secure_connection(request)
|
||||
cookie_name = _get_cookie_name(is_secure)
|
||||
return request.cookies.get(cookie_name)
|
||||
|
||||
@lru_cache(maxsize=SESSION_CACHE_SIZE)
|
||||
def _decrypt_cookie(self, cookie: str) -> Session | None:
|
||||
"""Decrypt and validate cookie."""
|
||||
try:
|
||||
data = SessionData( # type: ignore[misc]
|
||||
self._decoder(
|
||||
self._fernet.decrypt(
|
||||
cookie.encode("utf-8"), ttl=self.max_age
|
||||
).decode("utf-8")
|
||||
)
|
||||
)
|
||||
except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS):
|
||||
_LOGGER.warning("Cannot decrypt/parse cookie value")
|
||||
return None
|
||||
|
||||
session = Session(None, data=data, new=data is None, max_age=self.max_age)
|
||||
|
||||
# Validate session if not empty
|
||||
if (
|
||||
not session.empty
|
||||
and not self._hass.auth.session.async_validate_strict_connection_session(
|
||||
session
|
||||
)
|
||||
):
|
||||
# Invalidate session as it is not valid
|
||||
session.invalidate()
|
||||
|
||||
return session
|
||||
|
||||
async def new_session(self) -> Session:
|
||||
"""Create a new session and mark it as changed."""
|
||||
session = Session(None, data=None, new=True, max_age=self.max_age)
|
||||
session.changed()
|
||||
return session
|
||||
|
||||
async def load_session(self, request: Request) -> Session:
|
||||
"""Load session."""
|
||||
# Split parent function to use lru_cache
|
||||
if (cookie := self.load_cookie(request)) is None:
|
||||
return await self.new_session()
|
||||
|
||||
if (session := self._decrypt_cookie(cookie)) is None:
|
||||
# Decrypting/parsing failed, log wrong login and create a new session
|
||||
await process_wrong_login(request)
|
||||
session = await self.new_session()
|
||||
|
||||
return session
|
||||
|
||||
async def save_session(
|
||||
self, request: Request, response: StreamResponse, session: Session
|
||||
) -> None:
|
||||
"""Save session."""
|
||||
|
||||
is_secure = self._secure_connection(request)
|
||||
cookie_name = _get_cookie_name(is_secure)
|
||||
|
||||
if session.empty:
|
||||
response.del_cookie(cookie_name)
|
||||
else:
|
||||
params = self.cookie_params.copy()
|
||||
params["secure"] = is_secure
|
||||
params["max_age"] = session.max_age
|
||||
|
||||
cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8")
|
||||
response.set_cookie(
|
||||
cookie_name,
|
||||
self._fernet.encrypt(cookie_data).decode("utf-8"),
|
||||
**params,
|
||||
)
|
||||
# Add Cache-Control header to not cache the cookie as it
|
||||
# is used for session management
|
||||
self._add_cache_control_header(response)
|
||||
|
||||
@staticmethod
|
||||
def _add_cache_control_header(response: StreamResponse) -> None:
|
||||
"""Add/set cache control header to no-cache="Set-Cookie"."""
|
||||
# Structure of the Cache-Control header defined in
|
||||
# https://datatracker.ietf.org/doc/html/rfc2068#section-14.9
|
||||
if header := response.headers.get("Cache-Control"):
|
||||
directives = []
|
||||
for directive in header.split(","):
|
||||
directive = directive.strip()
|
||||
directive_lowered = directive.lower()
|
||||
if directive_lowered.startswith("no-cache"):
|
||||
if "set-cookie" in directive_lowered or directive.find("=") == -1:
|
||||
# Set-Cookie is already in the no-cache directive or
|
||||
# the whole request should not be cached -> Nothing to do
|
||||
return
|
||||
|
||||
# Add Set-Cookie to the no-cache
|
||||
# [:-1] to remove the " at the end of the directive
|
||||
directive = f"{directive[:-1]}, Set-Cookie"
|
||||
|
||||
directives.append(directive)
|
||||
header = ", ".join(directives)
|
||||
else:
|
||||
header = 'no-cache="Set-Cookie"'
|
||||
response.headers["Cache-Control"] = header
|
File diff suppressed because one or more lines are too long
@ -1,16 +0,0 @@
|
||||
{
|
||||
"exceptions": {
|
||||
"strict_connection_not_enabled_non_cloud": {
|
||||
"message": "Strict connection is not enabled for non-cloud requests"
|
||||
},
|
||||
"no_external_url_available": {
|
||||
"message": "No external URL available"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": {
|
||||
"name": "Create a temporary strict connection URL",
|
||||
"description": "Create a temporary strict connection URL, which can be used to login on another device."
|
||||
}
|
||||
}
|
||||
}
|
@ -7,7 +7,6 @@ aiohttp-fast-url-dispatcher==0.3.0
|
||||
aiohttp-fast-zlib==0.1.0
|
||||
aiohttp==3.9.5
|
||||
aiohttp_cors==0.7.0
|
||||
aiohttp_session==2.12.0
|
||||
aiozoneinfo==0.1.0
|
||||
astral==2.2
|
||||
async-interrupt==1.1.1
|
||||
|
@ -26,7 +26,6 @@ dependencies = [
|
||||
"aiodns==3.2.0",
|
||||
"aiohttp==3.9.5",
|
||||
"aiohttp_cors==0.7.0",
|
||||
"aiohttp_session==2.12.0",
|
||||
"aiohttp-fast-url-dispatcher==0.3.0",
|
||||
"aiohttp-fast-zlib==0.1.0",
|
||||
"aiozoneinfo==0.1.0",
|
||||
|
@ -6,7 +6,6 @@
|
||||
aiodns==3.2.0
|
||||
aiohttp==3.9.5
|
||||
aiohttp_cors==0.7.0
|
||||
aiohttp_session==2.12.0
|
||||
aiohttp-fast-url-dispatcher==0.3.0
|
||||
aiohttp-fast-zlib==0.1.0
|
||||
aiozoneinfo==0.1.0
|
||||
|
@ -24,7 +24,6 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
||||
ExposedEntities,
|
||||
async_expose_entity,
|
||||
)
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.const import CONTENT_TYPE_JSON, __version__ as HA_VERSION
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
@ -388,7 +387,6 @@ async def test_cloud_connection_info(hass: HomeAssistant) -> None:
|
||||
"connected": False,
|
||||
"enabled": False,
|
||||
"instance_domain": None,
|
||||
"strict_connection": StrictConnectionMode.DISABLED,
|
||||
},
|
||||
"version": HA_VERSION,
|
||||
}
|
||||
|
@ -19,7 +19,6 @@ from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
||||
from homeassistant.components.google_assistant.helpers import GoogleEntity
|
||||
from homeassistant.components.homeassistant import exposed_entities
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
@ -783,7 +782,6 @@ async def test_websocket_status(
|
||||
"google_report_state": True,
|
||||
"remote_allow_remote_enable": True,
|
||||
"remote_enabled": False,
|
||||
"strict_connection": "disabled",
|
||||
"tts_default_voice": ["en-US", "JennyNeural"],
|
||||
},
|
||||
"alexa_entities": {
|
||||
@ -903,7 +901,6 @@ async def test_websocket_update_preferences(
|
||||
assert cloud.client.prefs.alexa_enabled
|
||||
assert cloud.client.prefs.google_secure_devices_pin is None
|
||||
assert cloud.client.prefs.remote_allow_remote_enable is True
|
||||
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DISABLED
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
@ -915,7 +912,6 @@ async def test_websocket_update_preferences(
|
||||
"google_secure_devices_pin": "1234",
|
||||
"tts_default_voice": ["en-GB", "RyanNeural"],
|
||||
"remote_allow_remote_enable": False,
|
||||
"strict_connection": StrictConnectionMode.DROP_CONNECTION,
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
@ -926,7 +922,6 @@ async def test_websocket_update_preferences(
|
||||
assert cloud.client.prefs.google_secure_devices_pin == "1234"
|
||||
assert cloud.client.prefs.remote_allow_remote_enable is False
|
||||
assert cloud.client.prefs.tts_default_voice == ("en-GB", "RyanNeural")
|
||||
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DROP_CONNECTION
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -3,7 +3,6 @@
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
import pytest
|
||||
@ -14,16 +13,11 @@ from homeassistant.components.cloud import (
|
||||
CloudNotConnected,
|
||||
async_get_or_create_cloudhook,
|
||||
)
|
||||
from homeassistant.components.cloud.const import (
|
||||
DOMAIN,
|
||||
PREF_CLOUDHOOKS,
|
||||
PREF_STRICT_CONNECTION,
|
||||
)
|
||||
from homeassistant.components.cloud.const import DOMAIN, PREF_CLOUDHOOKS
|
||||
from homeassistant.components.cloud.prefs import STORAGE_KEY
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError, Unauthorized
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, MockUser
|
||||
@ -301,77 +295,3 @@ async def test_cloud_logout(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.is_logged_in is False
|
||||
|
||||
|
||||
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
|
||||
mock_config_entry = MockConfigEntry(domain=DOMAIN)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match="Strict connection is not enabled for cloud requests",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.GUARD_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
hass: HomeAssistant,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url."""
|
||||
mock_config_entry = MockConfigEntry(domain=DOMAIN)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: mode,
|
||||
}
|
||||
)
|
||||
|
||||
# No cloud url set
|
||||
with pytest.raises(ServiceValidationError, match="No cloud URL available"):
|
||||
await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Patch cloud url
|
||||
url = "https://example.com"
|
||||
with patch(
|
||||
"homeassistant.helpers.network._get_cloud_url",
|
||||
return_value=url,
|
||||
):
|
||||
response = await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
direct_url_prefix = f"{url}/auth/strict_connection/temp_token?authSig="
|
||||
assert response.pop("direct_url").startswith(direct_url_prefix)
|
||||
assert response.pop("url").startswith(
|
||||
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
|
||||
)
|
||||
assert response == {} # No more keys in response
|
||||
|
@ -6,13 +6,8 @@ from unittest.mock import ANY, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.components.cloud.const import (
|
||||
DOMAIN,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
)
|
||||
from homeassistant.components.cloud.const import DOMAIN, PREF_TTS_DEFAULT_VOICE
|
||||
from homeassistant.components.cloud.prefs import STORAGE_KEY, CloudPreferences
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
@ -179,39 +174,3 @@ async def test_tts_default_voice_legacy_gender(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.client.prefs.tts_default_voice == (expected_language, voice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", list(StrictConnectionMode))
|
||||
async def test_strict_connection_convertion(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_storage: dict[str, Any],
|
||||
mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test strict connection string value will be converted to the enum."""
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {PREF_STRICT_CONNECTION: mode.value},
|
||||
}
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.client.prefs.strict_connection is mode
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage_data", [{}, {PREF_STRICT_CONNECTION: None}])
|
||||
async def test_strict_connection_default(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_storage: dict[str, Any],
|
||||
storage_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test strict connection default values."""
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": storage_data,
|
||||
}
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DISABLED
|
||||
|
@ -1,294 +0,0 @@
|
||||
"""Test strict connection mode for cloud."""
|
||||
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from aiohttp import ServerDisconnectedError, web
|
||||
from aiohttp.test_utils import TestClient
|
||||
from aiohttp_session import get_session
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.models import RefreshToken
|
||||
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
||||
from homeassistant.components.cloud.const import PREF_STRICT_CONNECTION
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import (
|
||||
STRICT_CONNECTION_GUARD_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
)
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
|
||||
from homeassistant.components.http.session import COOKIE_NAME, PREFIXED_COOKIE_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def refresh_token(hass: HomeAssistant, hass_access_token: str) -> RefreshToken:
|
||||
"""Return a refresh token."""
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
session = hass.auth.session
|
||||
assert session._strict_connection_sessions == {}
|
||||
assert session._temp_sessions == {}
|
||||
return refresh_token
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simulate_cloud_request() -> Generator[None, None, None]:
|
||||
"""Simulate a cloud request."""
|
||||
with patch(
|
||||
"hass_nabucasa.remote.is_cloud_request", Mock(get=Mock(return_value=True))
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_strict_connection(
|
||||
hass: HomeAssistant, refresh_token: RefreshToken
|
||||
) -> web.Application:
|
||||
"""Fixture to set up a web.Application."""
|
||||
|
||||
async def handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
|
||||
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
app.router.add_get("/", handler)
|
||||
|
||||
async def set_cookie(request: web.Request) -> web.Response:
|
||||
hass = request.app[KEY_HASS]
|
||||
# Clear all sessions
|
||||
hass.auth.session._temp_sessions.clear()
|
||||
hass.auth.session._strict_connection_sessions.clear()
|
||||
|
||||
if request.query["token"] == "refresh":
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
else:
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
session = await get_session(request)
|
||||
return web.Response(text=session[SESSION_ID])
|
||||
|
||||
app.router.add_get("/test/cookie", set_cookie)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
async def set_up_fixture(
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
app_strict_connection: web.Application,
|
||||
cloud: MagicMock,
|
||||
socket_enabled: None,
|
||||
) -> TestClient:
|
||||
"""Set up the fixture."""
|
||||
|
||||
await async_setup_auth(hass, app_strict_connection, StrictConnectionMode.DISABLED)
|
||||
assert await async_setup_component(hass, "cloud", {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
return await aiohttp_client(app_strict_connection)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_cloud_authenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
hass_access_token: str,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
refresh_token: RefreshToken,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test authenticated requests with strict connection."""
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
)
|
||||
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: strict_connection_mode,
|
||||
}
|
||||
)
|
||||
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
req = await client.get(
|
||||
"/", headers={"Authorization": f"Bearer {hass_access_token}"}
|
||||
)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
req = await client.get(signed_path)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
_: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled."""
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
refresh_token: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled and refresh token cookie."""
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with refresh token
|
||||
session_id = await _modify_cookie_for_cloud(client, "refresh")
|
||||
assert session._strict_connection_sessions == {session_id: refresh_token.id}
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
# Invalidate refresh token, which should also invalidate session
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
assert session._strict_connection_sessions == {}
|
||||
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
_: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled and temp cookie."""
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with temp session
|
||||
assert session._temp_sessions == {}
|
||||
session_id = await _modify_cookie_for_cloud(client, "temp")
|
||||
assert session_id in session._temp_sessions
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == {"authenticated": False}
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert session._temp_sessions == {}
|
||||
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _drop_connection_unauthorized_request(
|
||||
_: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
with pytest.raises(ServerDisconnectedError):
|
||||
# unauthorized requests should raise ServerDisconnectedError
|
||||
await client.get("/")
|
||||
|
||||
|
||||
async def _guard_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_guard_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_func",
|
||||
[
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests,
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token,
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session,
|
||||
],
|
||||
ids=[
|
||||
"no cookie",
|
||||
"refresh token cookie",
|
||||
"temp session cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.GUARD_PAGE, _guard_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
async def test_strict_connection_cloud_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
refresh_token: RefreshToken,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
test_func: Callable[
|
||||
[
|
||||
HomeAssistant,
|
||||
TestClient,
|
||||
Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
RefreshToken,
|
||||
],
|
||||
Awaitable[None],
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud."""
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: strict_connection_mode,
|
||||
}
|
||||
)
|
||||
|
||||
await test_func(
|
||||
hass,
|
||||
client,
|
||||
request_func,
|
||||
refresh_token,
|
||||
)
|
||||
|
||||
|
||||
async def _modify_cookie_for_cloud(client: TestClient, token_type: str) -> str:
|
||||
"""Modify cookie for cloud."""
|
||||
# Cloud cookie has set secure=true and will not set on insecure connection
|
||||
# As we test with insecure connection, we need to set it manually
|
||||
# We get the session via http and modify the cookie name to the secure one
|
||||
session_id = await (await client.get(f"/test/cookie?token={token_type}")).text()
|
||||
cookie_jar = client.session.cookie_jar
|
||||
localhost = URL("http://127.0.0.1")
|
||||
cookie = cookie_jar.filter_cookies(localhost)[COOKIE_NAME].value
|
||||
assert cookie
|
||||
cookie_jar.clear()
|
||||
cookie_jar.update_cookies({PREFIXED_COOKIE_NAME: cookie}, localhost)
|
||||
return session_id
|
@ -1,28 +1,23 @@
|
||||
"""The tests for the Home Assistant HTTP component."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_network
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from aiohttp import BasicAuth, ServerDisconnectedError, web
|
||||
from aiohttp.test_utils import TestClient
|
||||
from aiohttp import BasicAuth, web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
from aiohttp_session import get_session
|
||||
import jwt
|
||||
import pytest
|
||||
import yarl
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.auth.providers import trusted_networks
|
||||
from homeassistant.auth.providers.legacy_api_password import (
|
||||
LegacyApiPasswordAuthProvider,
|
||||
)
|
||||
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import (
|
||||
@ -30,12 +25,11 @@ from homeassistant.components.http.auth import (
|
||||
DATA_SIGN_SECRET,
|
||||
SIGN_QUERY_PARAM,
|
||||
STORAGE_KEY,
|
||||
STRICT_CONNECTION_GUARD_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
async_user_not_allowed_do_auth,
|
||||
)
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.forwarded import async_setup_forwarded
|
||||
from homeassistant.components.http.request_context import (
|
||||
current_request,
|
||||
@ -43,11 +37,10 @@ from homeassistant.components.http.request_context import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import HTTP_HEADER_HA_AUTH
|
||||
|
||||
from tests.common import MockUser, async_fire_time_changed
|
||||
from tests.common import MockUser
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
@ -137,7 +130,7 @@ async def test_cant_access_with_password_in_header(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access with password in header."""
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
@ -154,7 +147,7 @@ async def test_cant_access_with_password_in_query(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access with password in URL."""
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.get("/", params={"api_password": API_PASSWORD})
|
||||
@ -174,7 +167,7 @@ async def test_basic_auth_does_not_work(
|
||||
legacy_auth: LegacyApiPasswordAuthProvider,
|
||||
) -> None:
|
||||
"""Test access with basic authentication."""
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD))
|
||||
@ -198,7 +191,7 @@ async def test_cannot_access_with_trusted_ip(
|
||||
hass_owner_user: MockUser,
|
||||
) -> None:
|
||||
"""Test access with an untrusted ip address."""
|
||||
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app2)
|
||||
|
||||
set_mock_ip = mock_real_ip(app2)
|
||||
client = await aiohttp_client(app2)
|
||||
@ -226,7 +219,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
||||
) -> None:
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
@ -262,7 +255,7 @@ async def test_auth_active_access_with_trusted_ip(
|
||||
hass_owner_user: MockUser,
|
||||
) -> None:
|
||||
"""Test access with an untrusted ip address."""
|
||||
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app2)
|
||||
|
||||
set_mock_ip = mock_real_ip(app2)
|
||||
client = await aiohttp_client(app2)
|
||||
@ -289,7 +282,7 @@ async def test_auth_legacy_support_api_password_cannot_access(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access using api_password if auth.support_legacy."""
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
@ -311,7 +304,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
||||
"""Test access with signed url."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -356,7 +349,7 @@ async def test_auth_access_signed_path_with_query_param(
|
||||
"""Test access with signed url and query params."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -386,7 +379,7 @@ async def test_auth_access_signed_path_with_query_param_order(
|
||||
"""Test access with signed url and query params different order."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -427,7 +420,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
|
||||
"""Test access with signed url and changing a safe param."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -466,7 +459,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
|
||||
"""Test access with signed url and query params that have been tampered with."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -535,7 +528,7 @@ async def test_auth_access_signed_path_with_http(
|
||||
)
|
||||
|
||||
app.router.add_get("/hello", mock_handler)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -559,7 +552,7 @@ async def test_auth_access_signed_path_with_content_user(
|
||||
hass: HomeAssistant, app, aiohttp_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test access signed url uses content user."""
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
signed_path = async_sign_path(hass, "/", timedelta(seconds=5))
|
||||
signature = yarl.URL(signed_path).query["authSig"]
|
||||
claims = jwt.decode(
|
||||
@ -579,7 +572,7 @@ async def test_local_only_user_rejected(
|
||||
) -> None:
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
@ -645,7 +638,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
||||
"""Test that we reuse the user."""
|
||||
cur_users = len(await hass.auth.async_get_users())
|
||||
app = web.Application()
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
users = await hass.auth.async_get_users()
|
||||
assert len(users) == cur_users + 1
|
||||
|
||||
@ -657,287 +650,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
||||
assert len(user.refresh_tokens) == 1
|
||||
assert user.system_generated
|
||||
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
await async_setup_auth(hass, app)
|
||||
|
||||
# test it did not create a user
|
||||
assert len(await hass.auth.async_get_users()) == cur_users + 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_strict_connection(hass):
|
||||
"""Fixture to set up a web.Application."""
|
||||
|
||||
async def handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
|
||||
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
app.router.add_get("/", handler)
|
||||
async_setup_forwarded(app, True, [])
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_non_cloud_authenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test authenticated requests with strict connection."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app_strict_connection)
|
||||
client = await aiohttp_client(app_strict_connection)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
)
|
||||
|
||||
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES, *EXTERNAL_ADDRESSES):
|
||||
set_mock_ip(remote_addr)
|
||||
|
||||
# authorized requests should work normally
|
||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
req = await client.get(signed_path)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_non_cloud_local_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test local unauthenticated requests with strict connection."""
|
||||
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app_strict_connection)
|
||||
client = await aiohttp_client(app_strict_connection)
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES):
|
||||
set_mock_ip(remote_addr)
|
||||
# local requests should work normally
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
|
||||
def _add_set_cookie_endpoint(app: web.Application, refresh_token: RefreshToken) -> None:
|
||||
"""Add an endpoint to set a cookie."""
|
||||
|
||||
async def set_cookie(request: web.Request) -> web.Response:
|
||||
hass = request.app[KEY_HASS]
|
||||
# Clear all sessions
|
||||
hass.auth.session._temp_sessions.clear()
|
||||
hass.auth.session._strict_connection_sessions.clear()
|
||||
|
||||
if request.query["token"] == "refresh":
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
else:
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
session = await get_session(request)
|
||||
return web.Response(text=session[SESSION_ID])
|
||||
|
||||
app.router.add_get("/test/cookie", set_cookie)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> tuple[TestClient, Callable[[str], None], RefreshToken]:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled."""
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
session = hass.auth.session
|
||||
assert session._strict_connection_sessions == {}
|
||||
assert session._temp_sessions == {}
|
||||
|
||||
_add_set_cookie_endpoint(app, refresh_token)
|
||||
await async_setup_auth(hass, app, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = await aiohttp_client(app)
|
||||
return (client, set_mock_ip, refresh_token)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled."""
|
||||
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled and refresh token cookie."""
|
||||
(
|
||||
client,
|
||||
set_mock_ip,
|
||||
refresh_token,
|
||||
) = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with refresh token
|
||||
set_mock_ip(LOCALHOST_ADDRESSES[0])
|
||||
session_id = await (await client.get("/test/cookie?token=refresh")).text()
|
||||
assert session._strict_connection_sessions == {session_id: refresh_token.id}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
# Invalidate refresh token, which should also invalidate session
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
assert session._strict_connection_sessions == {}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled and temp cookie."""
|
||||
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with temp session
|
||||
assert session._temp_sessions == {}
|
||||
set_mock_ip(LOCALHOST_ADDRESSES[0])
|
||||
session_id = await (await client.get("/test/cookie?token=temp")).text()
|
||||
assert client.session.cookie_jar.filter_cookies(URL("http://127.0.0.1"))
|
||||
assert session_id in session._temp_sessions
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == {"authenticated": False}
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
assert session._temp_sessions == {}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _drop_connection_unauthorized_request(
|
||||
_: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
with pytest.raises(ServerDisconnectedError):
|
||||
# unauthorized requests should raise ServerDisconnectedError
|
||||
await client.get("/")
|
||||
|
||||
|
||||
async def _guard_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_guard_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_func",
|
||||
[
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests,
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token,
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session,
|
||||
],
|
||||
ids=[
|
||||
"no cookie",
|
||||
"refresh token cookie",
|
||||
"temp session cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.GUARD_PAGE, _guard_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
async def test_strict_connection_non_cloud_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
test_func: Callable[
|
||||
[
|
||||
HomeAssistant,
|
||||
web.Application,
|
||||
ClientSessionGenerator,
|
||||
str,
|
||||
Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
StrictConnectionMode,
|
||||
],
|
||||
Awaitable[None],
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud."""
|
||||
await test_func(
|
||||
hass,
|
||||
app_strict_connection,
|
||||
aiohttp_client,
|
||||
hass_access_token,
|
||||
request_func,
|
||||
strict_connection_mode,
|
||||
)
|
||||
|
@ -7,7 +7,6 @@ from ipaddress import ip_network
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,10 +14,7 @@ from homeassistant.auth.providers.legacy_api_password import (
|
||||
LegacyApiPasswordAuthProvider,
|
||||
)
|
||||
from homeassistant.components import http
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers.http import KEY_HASS
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
from homeassistant.setup import async_setup_component
|
||||
@ -525,78 +521,3 @@ async def test_logging(
|
||||
response = await client.get("/api/states/logging.entity")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert "GET /api/states/logging.entity" not in caplog.text
|
||||
|
||||
|
||||
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
|
||||
assert await async_setup_component(hass, http.DOMAIN, {"http": {}})
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match="Strict connection is not enabled for non-cloud requests",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.GUARD_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
hass: HomeAssistant, mode: StrictConnectionMode
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url."""
|
||||
assert await async_setup_component(
|
||||
hass, http.DOMAIN, {"http": {"strict_connection": mode}}
|
||||
)
|
||||
|
||||
# No external url set
|
||||
assert hass.config.external_url is None
|
||||
assert hass.config.internal_url is None
|
||||
with pytest.raises(ServiceValidationError, match="No external URL available"):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Raise if only internal url is available
|
||||
hass.config.api = Mock(use_ssl=False, port=8123, local_ip="192.168.123.123")
|
||||
with pytest.raises(ServiceValidationError, match="No external URL available"):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Set external url too
|
||||
external_url = "https://example.com"
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": external_url},
|
||||
)
|
||||
assert hass.config.external_url == external_url
|
||||
response = await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
direct_url_prefix = f"{external_url}/auth/strict_connection/temp_token?authSig="
|
||||
assert response.pop("direct_url").startswith(direct_url_prefix)
|
||||
assert response.pop("url").startswith(
|
||||
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
|
||||
)
|
||||
assert response == {} # No more keys in response
|
||||
|
@ -1,107 +0,0 @@
|
||||
"""Tests for HTTP session."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.session import SESSION_ID
|
||||
from homeassistant.components.http.session import (
|
||||
COOKIE_NAME,
|
||||
HomeAssistantCookieStorage,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
def fake_request_with_strict_connection_cookie(cookie_value: str) -> web.Request:
|
||||
"""Return a fake request with a strict connection cookie."""
|
||||
request = make_mocked_request(
|
||||
"GET", "/", headers={"Cookie": f"{COOKIE_NAME}={cookie_value}"}
|
||||
)
|
||||
assert COOKIE_NAME in request.cookies
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cookie_storage(hass: HomeAssistant) -> HomeAssistantCookieStorage:
|
||||
"""Fixture for the cookie storage."""
|
||||
return HomeAssistantCookieStorage(hass)
|
||||
|
||||
|
||||
def _encrypt_cookie_data(cookie_storage: HomeAssistantCookieStorage, data: Any) -> str:
|
||||
"""Encrypt cookie data."""
|
||||
cookie_data = cookie_storage._encoder(data).encode("utf-8")
|
||||
return cookie_storage._fernet.encrypt(cookie_data).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
lambda _: "invalid",
|
||||
lambda storage: _encrypt_cookie_data(storage, "bla"),
|
||||
lambda storage: _encrypt_cookie_data(storage, None),
|
||||
],
|
||||
)
|
||||
async def test_load_session_modified_cookies(
|
||||
cookie_storage: HomeAssistantCookieStorage,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
func: Callable[[HomeAssistantCookieStorage], str],
|
||||
) -> None:
|
||||
"""Test that on modified cookies the session is empty and the request will be logged for ban."""
|
||||
request = fake_request_with_strict_connection_cookie(func(cookie_storage))
|
||||
with patch(
|
||||
"homeassistant.components.http.session.process_wrong_login",
|
||||
) as mock_process_wrong_login:
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert session.empty
|
||||
assert (
|
||||
"homeassistant.components.http.session",
|
||||
logging.WARNING,
|
||||
"Cannot decrypt/parse cookie value",
|
||||
) in caplog.record_tuples
|
||||
mock_process_wrong_login.assert_called()
|
||||
|
||||
|
||||
async def test_load_session_validate_session(
|
||||
hass: HomeAssistant,
|
||||
cookie_storage: HomeAssistantCookieStorage,
|
||||
) -> None:
|
||||
"""Test load session validates the session."""
|
||||
session = await cookie_storage.new_session()
|
||||
session[SESSION_ID] = "bla"
|
||||
request = fake_request_with_strict_connection_cookie(
|
||||
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
hass.auth.session, "async_validate_strict_connection_session", return_value=True
|
||||
) as mock_validate:
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert not session.empty
|
||||
assert session[SESSION_ID] == "bla"
|
||||
mock_validate.assert_called_with(session)
|
||||
|
||||
# verify lru_cache is working
|
||||
mock_validate.reset_mock()
|
||||
await cookie_storage.load_session(request)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
session = await cookie_storage.new_session()
|
||||
session[SESSION_ID] = "something"
|
||||
request = fake_request_with_strict_connection_cookie(
|
||||
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
hass.auth.session,
|
||||
"async_validate_strict_connection_session",
|
||||
return_value=False,
|
||||
):
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert session.empty
|
||||
assert SESSION_ID not in session
|
||||
assert session._changed
|
@ -800,11 +800,10 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
||||
assert proxy_load_services_files.mock_calls[0][1][1] == unordered(
|
||||
[
|
||||
await async_get_integration(hass, DOMAIN_GROUP),
|
||||
await async_get_integration(hass, "http"), # system_health requires http
|
||||
]
|
||||
)
|
||||
|
||||
assert len(descriptions) == 2
|
||||
assert len(descriptions) == 1
|
||||
assert DOMAIN_GROUP in descriptions
|
||||
assert "description" in descriptions[DOMAIN_GROUP]["reload"]
|
||||
assert "fields" in descriptions[DOMAIN_GROUP]["reload"]
|
||||
@ -838,7 +837,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
||||
await async_setup_component(hass, DOMAIN_LOGGER, logger_config)
|
||||
descriptions = await service.async_get_all_descriptions(hass)
|
||||
|
||||
assert len(descriptions) == 3
|
||||
assert len(descriptions) == 2
|
||||
assert DOMAIN_LOGGER in descriptions
|
||||
assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name"
|
||||
assert (
|
||||
|
@ -5,7 +5,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.config import YAML_CONFIG_FILE
|
||||
from homeassistant.scripts import check_config
|
||||
|
||||
@ -135,7 +134,6 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None:
|
||||
"login_attempts_threshold": -1,
|
||||
"server_port": 8123,
|
||||
"ssl_profile": "modern",
|
||||
"strict_connection": StrictConnectionMode.DISABLED,
|
||||
"use_x_frame_options": True,
|
||||
"server_host": ["0.0.0.0", "::"],
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user