Fix getting the host for the current request (#126882)

This commit is contained in:
J. Nick Koston 2024-09-27 03:36:05 -05:00 committed by Franck Nijhof
parent 3d1bd626b0
commit b079a94bef
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
2 changed files with 36 additions and 6 deletions

View File

@ -6,6 +6,7 @@ from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from ipaddress import ip_address from ipaddress import ip_address
from aiohttp import hdrs
from hass_nabucasa import remote from hass_nabucasa import remote
import yarl import yarl
@ -216,7 +217,10 @@ def _get_request_host() -> str | None:
"""Get the host address of the current request.""" """Get the host address of the current request."""
if (request := http.current_request.get()) is None: if (request := http.current_request.get()) is None:
raise NoURLAvailableError raise NoURLAvailableError
return request.url.host # partition the host to remove the port
# because the raw host header can contain the port
host = request.headers.get(hdrs.HOST)
return None if host is None else host.partition(":")[0]
@bind_hass @bind_hass

View File

@ -2,6 +2,8 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from aiohttp import hdrs
from multidict import CIMultiDict, CIMultiDictProxy
import pytest import pytest
from yarl import URL from yarl import URL
@ -592,7 +594,11 @@ async def test_get_request_host(hass: HomeAssistant) -> None:
with patch("homeassistant.components.http.current_request") as mock_request_context: with patch("homeassistant.components.http.current_request") as mock_request_context:
mock_request = Mock() mock_request = Mock()
mock_request.headers = CIMultiDictProxy(
CIMultiDict({hdrs.HOST: "example.com:8123"})
)
mock_request.url = URL("http://example.com:8123/test/request") mock_request.url = URL("http://example.com:8123/test/request")
mock_request.host = "example.com:8123"
mock_request_context.get = Mock(return_value=mock_request) mock_request_context.get = Mock(return_value=mock_request)
assert _get_request_host() == "example.com" assert _get_request_host() == "example.com"
@ -683,11 +689,19 @@ async def test_is_internal_request(hass: HomeAssistant, mock_current_request) ->
mock_current_request.return_value = None mock_current_request.return_value = None
assert not is_internal_request(hass) assert not is_internal_request(hass)
mock_current_request.return_value = Mock(url=URL("http://example.local:8123")) mock_current_request.return_value = Mock(
headers=CIMultiDictProxy(CIMultiDict({hdrs.HOST: "example.local:8123"})),
host="example.local:8123",
url=URL("http://example.local:8123"),
)
assert is_internal_request(hass) assert is_internal_request(hass)
mock_current_request.return_value = Mock( mock_current_request.return_value = Mock(
url=URL("http://no_match.example.local:8123") headers=CIMultiDictProxy(
CIMultiDict({hdrs.HOST: "no_match.example.local:8123"})
),
host="no_match.example.local:8123",
url=URL("http://no_match.example.local:8123"),
) )
assert not is_internal_request(hass) assert not is_internal_request(hass)
@ -700,18 +714,30 @@ async def test_is_internal_request(hass: HomeAssistant, mock_current_request) ->
assert hass.config.internal_url == "http://192.168.0.1:8123" assert hass.config.internal_url == "http://192.168.0.1:8123"
assert not is_internal_request(hass) assert not is_internal_request(hass)
mock_current_request.return_value = Mock(url=URL("http://192.168.0.1:8123")) mock_current_request.return_value = Mock(
headers=CIMultiDictProxy(CIMultiDict({hdrs.HOST: "192.168.0.1:8123"})),
host="192.168.0.1:8123",
url=URL("http://192.168.0.1:8123"),
)
assert is_internal_request(hass) assert is_internal_request(hass)
# Test for matching against local IP # Test for matching against local IP
hass.config.api = Mock(use_ssl=False, local_ip="192.168.123.123", port=8123) hass.config.api = Mock(use_ssl=False, local_ip="192.168.123.123", port=8123)
for allowed in ("127.0.0.1", "192.168.123.123"): for allowed in ("127.0.0.1", "192.168.123.123"):
mock_current_request.return_value = Mock(url=URL(f"http://{allowed}:8123")) mock_current_request.return_value = Mock(
headers=CIMultiDictProxy(CIMultiDict({hdrs.HOST: f"{allowed}:8123"})),
host=f"{allowed}:8123",
url=URL(f"http://{allowed}:8123"),
)
assert is_internal_request(hass), mock_current_request.return_value.url assert is_internal_request(hass), mock_current_request.return_value.url
# Test for matching against HassOS hostname # Test for matching against HassOS hostname
for allowed in ("hellohost", "hellohost.local"): for allowed in ("hellohost", "hellohost.local"):
mock_current_request.return_value = Mock(url=URL(f"http://{allowed}:8123")) mock_current_request.return_value = Mock(
headers=CIMultiDictProxy(CIMultiDict({hdrs.HOST: f"{allowed}:8123"})),
host=f"{allowed}:8123",
url=URL(f"http://{allowed}:8123"),
)
assert is_internal_request(hass), mock_current_request.return_value.url assert is_internal_request(hass), mock_current_request.return_value.url