diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 75f13caa6aa..36e44508928 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -1,4 +1,5 @@ """Support to serve the Home Assistant API as WSGI application.""" +from contextvars import ContextVar from ipaddress import ip_network import logging import os @@ -28,6 +29,7 @@ from .ban import setup_bans from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER, KEY_REAL_IP # noqa: F401 from .cors import setup_cors from .real_ip import setup_real_ip +from .request_context import setup_request_context from .static import CACHE_HEADERS, CachingStaticResource from .view import HomeAssistantView # noqa: F401 from .web_runner import HomeAssistantTCPSite @@ -295,6 +297,7 @@ class HomeAssistantHTTP: app[KEY_HASS] = hass # This order matters + setup_request_context(app, current_request) setup_real_ip(app, use_x_forwarded_for, trusted_proxies) if is_ban_enabled: @@ -447,3 +450,8 @@ async def start_http_server_and_save_config( ] await store.async_save(conf) + + +current_request: ContextVar[Optional[web.Request]] = ContextVar( + "current_request", default=None +) diff --git a/homeassistant/components/http/request_context.py b/homeassistant/components/http/request_context.py new file mode 100644 index 00000000000..23a85214c3f --- /dev/null +++ b/homeassistant/components/http/request_context.py @@ -0,0 +1,20 @@ +"""Middleware to set the request context.""" + +from aiohttp.web import middleware + +from homeassistant.core import callback + +# mypy: allow-untyped-defs + + +@callback +def setup_request_context(app, context): + """Create request context middleware for the app.""" + + @middleware + async def request_context_middleware(request, handler): + """Request context middleware.""" + context.set(request) + return await handler(request) + + app.middlewares.append(request_context_middleware) diff --git a/homeassistant/helpers/network.py b/homeassistant/helpers/network.py index cebe0318496..471cabd0032 100644 --- a/homeassistant/helpers/network.py +++ b/homeassistant/helpers/network.py @@ -1,9 +1,10 @@ """Network helpers.""" from ipaddress import ip_address -from typing import cast +from typing import Optional, cast import yarl +from homeassistant.components.http import current_request from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import bind_hass @@ -27,6 +28,7 @@ class NoURLAvailableError(HomeAssistantError): def get_url( hass: HomeAssistant, *, + require_current_request: bool = False, require_ssl: bool = False, require_standard_port: bool = False, allow_internal: bool = True, @@ -37,6 +39,9 @@ def get_url( prefer_cloud: bool = False, ) -> str: """Get a URL to this instance.""" + if require_current_request and current_request.get() is None: + raise NoURLAvailableError + order = [TYPE_URL_INTERNAL, TYPE_URL_EXTERNAL] if prefer_external: order.reverse() @@ -49,6 +54,7 @@ def get_url( return _get_internal_url( hass, allow_ip=allow_ip, + require_current_request=require_current_request, require_ssl=require_ssl, require_standard_port=require_standard_port, ) @@ -62,6 +68,7 @@ def get_url( allow_cloud=allow_cloud, allow_ip=allow_ip, prefer_cloud=prefer_cloud, + require_current_request=require_current_request, require_ssl=require_ssl, require_standard_port=require_standard_port, ) @@ -72,11 +79,20 @@ def get_url( raise NoURLAvailableError +def _get_request_host() -> Optional[str]: + """Get the host address of the current request.""" + request = current_request.get() + if request is None: + raise NoURLAvailableError + return yarl.URL(request.url).host + + @bind_hass def _get_internal_url( hass: HomeAssistant, *, allow_ip: bool = True, + require_current_request: bool = False, require_ssl: bool = False, require_standard_port: bool = False, ) -> str: @@ -84,7 +100,8 @@ def _get_internal_url( if hass.config.internal_url: internal_url = yarl.URL(hass.config.internal_url) if ( - (not require_ssl or internal_url.scheme == "https") + (not require_current_request or internal_url.host == _get_request_host()) + and (not require_ssl or internal_url.scheme == "https") and (not require_standard_port or internal_url.is_default_port()) and (allow_ip or not is_ip_address(str(internal_url.host))) ): @@ -96,6 +113,7 @@ def _get_internal_url( hass, internal=True, allow_ip=allow_ip, + require_current_request=require_current_request, require_ssl=require_ssl, require_standard_port=require_standard_port, ) @@ -109,8 +127,10 @@ def _get_internal_url( ip_url = yarl.URL.build( scheme="http", host=hass.config.api.local_ip, port=hass.config.api.port ) - if not is_loopback(ip_address(ip_url.host)) and ( - not require_standard_port or ip_url.is_default_port() + if ( + not is_loopback(ip_address(ip_url.host)) + and (not require_current_request or ip_url.host == _get_request_host()) + and (not require_standard_port or ip_url.is_default_port()) ): return normalize_url(str(ip_url)) @@ -124,6 +144,7 @@ def _get_external_url( allow_cloud: bool = True, allow_ip: bool = True, prefer_cloud: bool = False, + require_current_request: bool = False, require_ssl: bool = False, require_standard_port: bool = False, ) -> str: @@ -138,6 +159,9 @@ def _get_external_url( external_url = yarl.URL(hass.config.external_url) if ( (allow_ip or not is_ip_address(str(external_url.host))) + and ( + not require_current_request or external_url.host == _get_request_host() + ) and (not require_standard_port or external_url.is_default_port()) and ( not require_ssl @@ -153,6 +177,7 @@ def _get_external_url( return _get_deprecated_base_url( hass, allow_ip=allow_ip, + require_current_request=require_current_request, require_ssl=require_ssl, require_standard_port=require_standard_port, ) @@ -161,7 +186,7 @@ def _get_external_url( if allow_cloud: try: - return _get_cloud_url(hass) + return _get_cloud_url(hass, require_current_request=require_current_request) except NoURLAvailableError: pass @@ -169,13 +194,16 @@ def _get_external_url( @bind_hass -def _get_cloud_url(hass: HomeAssistant) -> str: +def _get_cloud_url(hass: HomeAssistant, require_current_request: bool = False) -> str: """Get external Home Assistant Cloud URL of this instance.""" if "cloud" in hass.config.components: try: - return cast(str, hass.components.cloud.async_remote_ui_url()) + cloud_url = yarl.URL(cast(str, hass.components.cloud.async_remote_ui_url())) except hass.components.cloud.CloudNotAvailable: - pass + raise NoURLAvailableError + + if not require_current_request or cloud_url.host == _get_request_host(): + return normalize_url(str(cloud_url)) raise NoURLAvailableError @@ -186,6 +214,7 @@ def _get_deprecated_base_url( *, internal: bool = False, allow_ip: bool = True, + require_current_request: bool = False, require_ssl: bool = False, require_standard_port: bool = False, ) -> str: @@ -197,6 +226,7 @@ def _get_deprecated_base_url( # Rules that apply to both internal and external if ( (allow_ip or not is_ip_address(str(base_url.host))) + and (not require_current_request or base_url.host == _get_request_host()) and (not require_ssl or base_url.scheme == "https") and (not require_standard_port or base_url.is_default_port()) ): diff --git a/tests/components/http/test_request_context.py b/tests/components/http/test_request_context.py new file mode 100644 index 00000000000..f511b860dca --- /dev/null +++ b/tests/components/http/test_request_context.py @@ -0,0 +1,33 @@ +"""Test request context middleware.""" +from contextvars import ContextVar + +from aiohttp import web + +from homeassistant.components.http.request_context import setup_request_context + + +async def test_request_context_middleware(aiohttp_client): + """Test that request context is set from middleware.""" + context = ContextVar("request", default=None) + app = web.Application() + + async def mock_handler(request): + """Return the real IP as text.""" + request_context = context.get() + assert request_context + assert request_context == request + + return web.Response(text="hi!") + + app.router.add_get("/", mock_handler) + setup_request_context(app, context) + mock_api_client = await aiohttp_client(app) + + resp = await mock_api_client.get("/") + assert resp.status == 200 + + text = await resp.text() + assert text == "hi!" + + # We are outside of the context here, should be None + assert context.get() is None diff --git a/tests/helpers/test_network.py b/tests/helpers/test_network.py index f6665b054e7..1754511d95c 100644 --- a/tests/helpers/test_network.py +++ b/tests/helpers/test_network.py @@ -10,6 +10,7 @@ from homeassistant.helpers.network import ( _get_deprecated_base_url, _get_external_url, _get_internal_url, + _get_request_host, get_url, ) @@ -20,6 +21,9 @@ async def test_get_url_internal(hass: HomeAssistant): """Test getting an instance URL when the user has set an internal URL.""" assert hass.config.internal_url is None + with pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True) + # Test with internal URL: http://example.local:8123 await async_process_ha_core_config( hass, {"internal_url": "http://example.local:8123"}, @@ -35,6 +39,31 @@ async def test_get_url_internal(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): _get_internal_url(hass, require_ssl=True) + with pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True) + + with patch( + "homeassistant.helpers.network._get_request_host", return_value="example.local" + ): + assert ( + _get_internal_url(hass, require_current_request=True) + == "http://example.local:8123" + ) + + with pytest.raises(NoURLAvailableError): + _get_internal_url( + hass, require_current_request=True, require_standard_port=True + ) + + with pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True, require_ssl=True) + + with patch( + "homeassistant.helpers.network._get_request_host", + return_value="no_match.example.local", + ), pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True) + # Test with internal URL: https://example.local:8123 await async_process_ha_core_config( hass, {"internal_url": "https://example.local:8123"}, @@ -104,6 +133,25 @@ async def test_get_url_internal(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): _get_internal_url(hass, allow_ip=False) + with patch( + "homeassistant.helpers.network._get_request_host", return_value="192.168.0.1" + ): + assert ( + _get_internal_url(hass, require_current_request=True) + == "http://192.168.0.1:8123" + ) + + with pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True, allow_ip=False) + + with pytest.raises(NoURLAvailableError): + _get_internal_url( + hass, require_current_request=True, require_standard_port=True + ) + + with pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True, require_ssl=True) + async def test_get_url_internal_fallback(hass: HomeAssistant): """Test getting an instance URL when the user has not set an internal URL.""" @@ -171,6 +219,9 @@ async def test_get_url_external(hass: HomeAssistant): """Test getting an instance URL when the user has set an external URL.""" assert hass.config.external_url is None + with pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True) + # Test with external URL: http://example.com:8123 await async_process_ha_core_config( hass, {"external_url": "http://example.com:8123"}, @@ -188,6 +239,31 @@ async def test_get_url_external(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): _get_external_url(hass, require_ssl=True) + with pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True) + + with patch( + "homeassistant.helpers.network._get_request_host", return_value="example.com" + ): + assert ( + _get_external_url(hass, require_current_request=True) + == "http://example.com:8123" + ) + + with pytest.raises(NoURLAvailableError): + _get_external_url( + hass, require_current_request=True, require_standard_port=True + ) + + with pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True, require_ssl=True) + + with patch( + "homeassistant.helpers.network._get_request_host", + return_value="no_match.example.com", + ), pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True) + # Test with external URL: http://example.com:80/ await async_process_ha_core_config( hass, {"external_url": "http://example.com:80/"}, @@ -245,6 +321,20 @@ async def test_get_url_external(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): _get_external_url(hass, require_ssl=True) + with patch( + "homeassistant.helpers.network._get_request_host", return_value="192.168.0.1" + ): + assert ( + _get_external_url(hass, require_current_request=True) + == "https://192.168.0.1" + ) + + with pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True, allow_ip=False) + + with pytest.raises(NoURLAvailableError): + _get_external_url(hass, require_current_request=True, require_ssl=True) + async def test_get_cloud_url(hass: HomeAssistant): """Test getting an instance URL when the user has set an external URL.""" @@ -258,6 +348,24 @@ async def test_get_cloud_url(hass: HomeAssistant): ): assert _get_cloud_url(hass) == "https://example.nabu.casa" + with pytest.raises(NoURLAvailableError): + _get_cloud_url(hass, require_current_request=True) + + with patch( + "homeassistant.helpers.network._get_request_host", + return_value="example.nabu.casa", + ): + assert ( + _get_cloud_url(hass, require_current_request=True) + == "https://example.nabu.casa" + ) + + with patch( + "homeassistant.helpers.network._get_request_host", + return_value="no_match.nabu.casa", + ), pytest.raises(NoURLAvailableError): + _get_cloud_url(hass, require_current_request=True) + with patch.object( hass.components.cloud, "async_remote_ui_url", @@ -372,6 +480,51 @@ async def test_get_url(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): get_url(hass, allow_external=False, allow_internal=False) + with pytest.raises(NoURLAvailableError): + get_url(hass, require_current_request=True) + + with patch( + "homeassistant.helpers.network._get_request_host", return_value="example.com" + ), patch("homeassistant.helpers.network.current_request"): + assert get_url(hass, require_current_request=True) == "https://example.com" + assert ( + get_url(hass, require_current_request=True, require_ssl=True) + == "https://example.com" + ) + + with pytest.raises(NoURLAvailableError): + get_url(hass, require_current_request=True, allow_external=False) + + with patch( + "homeassistant.helpers.network._get_request_host", return_value="example.local" + ), patch("homeassistant.helpers.network.current_request"): + assert get_url(hass, require_current_request=True) == "http://example.local" + + with pytest.raises(NoURLAvailableError): + get_url(hass, require_current_request=True, allow_internal=False) + + with pytest.raises(NoURLAvailableError): + get_url(hass, require_current_request=True, require_ssl=True) + + with patch( + "homeassistant.helpers.network._get_request_host", + return_value="no_match.example.com", + ), pytest.raises(NoURLAvailableError): + _get_internal_url(hass, require_current_request=True) + + +async def test_get_request_host(hass: HomeAssistant): + """Test getting the host of the current web request from the request context.""" + with pytest.raises(NoURLAvailableError): + _get_request_host() + + with patch("homeassistant.helpers.network.current_request") as mock_request_context: + mock_request = Mock() + mock_request.url = "http://example.com:8123/test/request" + mock_request_context.get = Mock(return_value=mock_request) + + assert _get_request_host() == "example.com" + async def test_get_deprecated_base_url_internal(hass: HomeAssistant): """Test getting an internal instance URL from the deprecated base_url."""