From 630ddd6a8c8f14d9a7558f3722fdd3c2b8daf8f1 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Mon, 29 Apr 2024 21:26:40 +0200 Subject: [PATCH] Revert "Remove strict connection" (#116416) --- homeassistant/components/cloud/__init__.py | 8 ++++++++ homeassistant/components/cloud/prefs.py | 11 ++++++++++- homeassistant/components/http/__init__.py | 18 +++++++++++++++--- tests/components/cloud/test_http_api.py | 2 ++ tests/components/cloud/test_init.py | 2 -- tests/components/cloud/test_prefs.py | 1 - .../components/cloud/test_strict_connection.py | 1 - tests/components/http/test_init.py | 2 -- tests/helpers/test_service.py | 5 +++-- tests/scripts/test_check_config.py | 2 ++ 10 files changed, 40 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index 13f1d34b5cd..2552fe4bf5c 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -30,6 +30,7 @@ from homeassistant.core import ( HomeAssistant, ServiceCall, ServiceResponse, + SupportsResponse, callback, ) from homeassistant.exceptions import ( @@ -457,3 +458,10 @@ def _setup_services(hass: HomeAssistant, prefs: CloudPreferences) -> None: "url": f"https://login.home-assistant.io?u={quote_plus(url)}", "direct_url": url, } + + hass.services.async_register( + DOMAIN, + "create_temporary_strict_connection_url", + create_temporary_strict_connection_url, + supports_response=SupportsResponse.ONLY, + ) diff --git a/homeassistant/components/cloud/prefs.py b/homeassistant/components/cloud/prefs.py index b4e692d02c4..72207513ca9 100644 --- a/homeassistant/components/cloud/prefs.py +++ b/homeassistant/components/cloud/prefs.py @@ -365,7 +365,16 @@ class CloudPreferences: @property def strict_connection(self) -> http.const.StrictConnectionMode: """Return the strict connection mode.""" - return http.const.StrictConnectionMode.DISABLED + mode = self._prefs.get(PREF_STRICT_CONNECTION) + + if mode is None: + # Set to default value + # We store None in the store as the default value to detect if the user has changed the + # value or not. + mode = http.const.StrictConnectionMode.DISABLED + elif not isinstance(mode, http.const.StrictConnectionMode): + mode = http.const.StrictConnectionMode(mode) + return mode async def get_cloud_user(self) -> str: """Return ID of Home Assistant Cloud system user.""" diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index c783d2f0b71..83601599d88 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -10,7 +10,7 @@ import os import socket import ssl from tempfile import NamedTemporaryFile -from typing import Any, Final, TypedDict, cast +from typing import Any, Final, Required, TypedDict, cast from urllib.parse import quote_plus, urljoin from aiohttp import web @@ -36,6 +36,7 @@ from homeassistant.core import ( HomeAssistant, ServiceCall, ServiceResponse, + SupportsResponse, callback, ) from homeassistant.exceptions import ( @@ -145,6 +146,9 @@ HTTP_SCHEMA: Final = vol.All( [SSL_INTERMEDIATE, SSL_MODERN] ), vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean, + vol.Optional( + CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED + ): vol.Coerce(StrictConnectionMode), } ), ) @@ -168,6 +172,7 @@ class ConfData(TypedDict, total=False): login_attempts_threshold: int ip_ban_enabled: bool ssl_profile: str + strict_connection: Required[StrictConnectionMode] @bind_hass @@ -234,7 +239,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: login_threshold=login_threshold, is_ban_enabled=is_ban_enabled, use_x_frame_options=use_x_frame_options, - strict_connection_non_cloud=StrictConnectionMode.DISABLED, + strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION], ) async def stop_server(event: Event) -> None: @@ -615,7 +620,7 @@ def _setup_services(hass: HomeAssistant, conf: ConfData) -> None: if not user.is_admin: raise Unauthorized(context=call.context) - if StrictConnectionMode.DISABLED is StrictConnectionMode.DISABLED: + if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED: raise ServiceValidationError( translation_domain=DOMAIN, translation_key="strict_connection_not_enabled_non_cloud", @@ -647,3 +652,10 @@ def _setup_services(hass: HomeAssistant, conf: ConfData) -> None: "url": f"https://login.home-assistant.io?u={quote_plus(url)}", "direct_url": url, } + + hass.services.async_register( + DOMAIN, + "create_temporary_strict_connection_url", + create_temporary_strict_connection_url, + supports_response=SupportsResponse.ONLY, + ) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 1e4dc3173e2..d9d2b5c6742 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -915,6 +915,7 @@ async def test_websocket_update_preferences( "google_secure_devices_pin": "1234", "tts_default_voice": ["en-GB", "RyanNeural"], "remote_allow_remote_enable": False, + "strict_connection": StrictConnectionMode.DROP_CONNECTION, } ) response = await client.receive_json() @@ -925,6 +926,7 @@ async def test_websocket_update_preferences( assert cloud.client.prefs.google_secure_devices_pin == "1234" assert cloud.client.prefs.remote_allow_remote_enable is False assert cloud.client.prefs.tts_default_voice == ("en-GB", "RyanNeural") + assert cloud.client.prefs.strict_connection is StrictConnectionMode.DROP_CONNECTION @pytest.mark.parametrize( diff --git a/tests/components/cloud/test_init.py b/tests/components/cloud/test_init.py index d917dc12a7c..bc4526975da 100644 --- a/tests/components/cloud/test_init.py +++ b/tests/components/cloud/test_init.py @@ -303,7 +303,6 @@ async def test_cloud_logout( assert cloud.is_logged_in is False -@pytest.mark.skip(reason="Remove strict connection config option") async def test_service_create_temporary_strict_connection_url_strict_connection_disabled( hass: HomeAssistant, ) -> None: @@ -324,7 +323,6 @@ async def test_service_create_temporary_strict_connection_url_strict_connection_ ) -@pytest.mark.skip(reason="Remove strict connection config option") @pytest.mark.parametrize( ("mode"), [ diff --git a/tests/components/cloud/test_prefs.py b/tests/components/cloud/test_prefs.py index a8ce88f5700..57715fe2bdf 100644 --- a/tests/components/cloud/test_prefs.py +++ b/tests/components/cloud/test_prefs.py @@ -181,7 +181,6 @@ async def test_tts_default_voice_legacy_gender( assert cloud.client.prefs.tts_default_voice == (expected_language, voice) -@pytest.mark.skip(reason="Remove strict connection config option") @pytest.mark.parametrize("mode", list(StrictConnectionMode)) async def test_strict_connection_convertion( hass: HomeAssistant, diff --git a/tests/components/cloud/test_strict_connection.py b/tests/components/cloud/test_strict_connection.py index c3329740207..f275bc4d2dd 100644 --- a/tests/components/cloud/test_strict_connection.py +++ b/tests/components/cloud/test_strict_connection.py @@ -226,7 +226,6 @@ async def _guard_page_unauthorized_request( assert await req.text() == await hass.async_add_executor_job(read_guard_page) -@pytest.mark.skip(reason="Remove strict connection config option") @pytest.mark.parametrize( "test_func", [ diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py index 9e576e10f4d..b554737e7b3 100644 --- a/tests/components/http/test_init.py +++ b/tests/components/http/test_init.py @@ -527,7 +527,6 @@ async def test_logging( assert "GET /api/states/logging.entity" not in caplog.text -@pytest.mark.skip(reason="Remove strict connection config option") async def test_service_create_temporary_strict_connection_url_strict_connection_disabled( hass: HomeAssistant, ) -> None: @@ -545,7 +544,6 @@ async def test_service_create_temporary_strict_connection_url_strict_connection_ ) -@pytest.mark.skip(reason="Remove strict connection config option") @pytest.mark.parametrize( ("mode"), [ diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index c9d92c2f25a..e32768ee33e 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -800,10 +800,11 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: assert proxy_load_services_files.mock_calls[0][1][1] == unordered( [ await async_get_integration(hass, DOMAIN_GROUP), + await async_get_integration(hass, "http"), # system_health requires http ] ) - assert len(descriptions) == 1 + assert len(descriptions) == 2 assert DOMAIN_GROUP in descriptions assert "description" in descriptions[DOMAIN_GROUP]["reload"] assert "fields" in descriptions[DOMAIN_GROUP]["reload"] @@ -837,7 +838,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None: await async_setup_component(hass, DOMAIN_LOGGER, logger_config) descriptions = await service.async_get_all_descriptions(hass) - assert len(descriptions) == 2 + assert len(descriptions) == 3 assert DOMAIN_LOGGER in descriptions assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name" assert ( diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py index 79c64259f8b..76acb2ff678 100644 --- a/tests/scripts/test_check_config.py +++ b/tests/scripts/test_check_config.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest +from homeassistant.components.http.const import StrictConnectionMode from homeassistant.config import YAML_CONFIG_FILE from homeassistant.scripts import check_config @@ -134,6 +135,7 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None: "login_attempts_threshold": -1, "server_port": 8123, "ssl_profile": "modern", + "strict_connection": StrictConnectionMode.DISABLED, "use_x_frame_options": True, "server_host": ["0.0.0.0", "::"], }