OAuth2 to use current request header (#43668)

This commit is contained in:
Paulus Schoutsen 2020-11-27 08:55:34 +01:00 committed by GitHub
parent 69c2818c56
commit f9fa24950b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 258 additions and 90 deletions

View File

@ -90,8 +90,8 @@ class ToonLocalOAuth2Implementation(config_entry_oauth2_flow.LocalOAuth2Implemen
"""Initialize local Toon auth implementation."""
data = {
"grant_type": "authorization_code",
"code": external_data,
"redirect_uri": self.redirect_uri,
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
"tenant_id": self.tenant_id,
}

View File

@ -19,9 +19,9 @@ import voluptuous as vol
from yarl import URL
from homeassistant import config_entries
from homeassistant.components.http import HomeAssistantView
from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.network import NoURLAvailableError
from .aiohttp_client import async_get_clientsession
@ -32,6 +32,7 @@ DATA_VIEW_REGISTERED = "oauth2_view_reg"
DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
@ -64,7 +65,7 @@ class AbstractOAuth2Implementation(ABC):
Pass external data in with:
await hass.config_entries.flow.async_configure(
flow_id=flow_id, user_input=external_data
flow_id=flow_id, user_input={'code': 'abcd', 'state': { }
)
"""
@ -124,7 +125,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
@property
def redirect_uri(self) -> str:
"""Return the redirect uri."""
return f"{get_url(self.hass, require_current_request=True)}{AUTH_CALLBACK_PATH}"
req = http.current_request.get()
if req is None:
raise RuntimeError("No current request in context")
ha_host = req.headers.get(HEADER_FRONTEND_BASE)
if ha_host is None:
raise RuntimeError("No header in request")
return f"{ha_host}{AUTH_CALLBACK_PATH}"
@property
def extra_authorize_data(self) -> dict:
@ -133,14 +144,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
async def async_generate_authorize_url(self, flow_id: str) -> str:
"""Generate a url for the user to authorize."""
redirect_uri = self.redirect_uri
return str(
URL(self.authorize_url)
.with_query(
{
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": _encode_jwt(self.hass, {"flow_id": flow_id}),
"redirect_uri": redirect_uri,
"state": _encode_jwt(
self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri}
),
}
)
.update_query(self.extra_authorize_data)
@ -151,8 +165,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
return await self._token_request(
{
"grant_type": "authorization_code",
"code": external_data,
"redirect_uri": self.redirect_uri,
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
}
)
@ -384,7 +398,7 @@ def async_add_implementation_provider(
] = async_provide_implementation
class OAuth2AuthorizeCallbackView(HomeAssistantView):
class OAuth2AuthorizeCallbackView(http.HomeAssistantView):
"""OAuth2 Authorization Callback View."""
requires_auth = False
@ -406,7 +420,8 @@ class OAuth2AuthorizeCallbackView(HomeAssistantView):
return web.Response(text="Invalid state")
await hass.config_entries.flow.async_configure(
flow_id=state["flow_id"], user_input=request.query["code"]
flow_id=state["flow_id"],
user_input={"state": state, "code": request.query["code"]},
)
return web.Response(

View File

@ -4,7 +4,7 @@ from typing import Optional, cast
import yarl
from homeassistant.components.http import current_request
from homeassistant.components import http
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
@ -49,7 +49,7 @@ def get_url(
prefer_cloud: bool = False,
) -> str:
"""Get a URL to this instance."""
if require_current_request and current_request.get() is None:
if require_current_request and http.current_request.get() is None:
raise NoURLAvailableError
order = [TYPE_URL_INTERNAL, TYPE_URL_EXTERNAL]
@ -125,7 +125,7 @@ def get_url(
def _get_request_host() -> Optional[str]:
"""Get the host address of the current request."""
request = current_request.get()
request = http.current_request.get()
if request is None:
raise NoURLAvailableError
return yarl.URL(request.url).host

View File

@ -13,7 +13,9 @@ CLIENT_ID = "1234"
CLIENT_SECRET = "5678"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -27,7 +29,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"NEW_DOMAIN", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["url"] == (
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"

View File

@ -91,7 +91,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "single_instance_allowed"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -109,7 +111,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (

View File

@ -14,7 +14,9 @@ CLIENT_ID = "1234"
CLIENT_SECRET = "5678"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"home_connect", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (

View File

@ -12,7 +12,9 @@ PROJECT_ID = "project-id-4321"
SUBSCRIBER_ID = "subscriber-id-9876"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID)
assert result["url"] == (

View File

@ -42,7 +42,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "already_configured"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"netatmo", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
scope = "+".join(
[

View File

@ -333,7 +333,9 @@ async def test_abort_cloud_flow_if_local_device_exists(hass):
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_user_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -351,7 +353,13 @@ async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_requ
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"environment": ENV_CLOUD}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
client = await aiohttp_client(hass.http.app)
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")

View File

@ -52,7 +52,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "single_instance_allowed"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -69,7 +71,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (

View File

@ -40,7 +40,9 @@ async def test_zeroconf_abort_if_existing_entry(hass):
assert result["reason"] == "already_configured"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check a full flow."""
assert await setup.async_setup_component(
hass,
@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
@ -103,7 +111,7 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_abort_if_spotify_error(
hass, aiohttp_client, aioclient_mock, current_request
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check Spotify errors causes flow to abort."""
await setup.async_setup_component(
@ -120,7 +128,13 @@ async def test_abort_if_spotify_error(
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
client = await aiohttp_client(hass.http.app)
await client.get(f"/auth/external/callback?code=abcd&state={state}")
@ -144,7 +158,9 @@ async def test_abort_if_spotify_error(
assert result["reason"] == "connection_error"
async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_request):
async def test_reauthentication(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test Spotify reauthentication."""
await setup.async_setup_component(
hass,
@ -173,7 +189,13 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
client = await aiohttp_client(hass.http.app)
await client.get(f"/auth/external/callback?code=abcd&state={state}")
@ -202,7 +224,7 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re
async def test_reauth_account_mismatch(
hass, aiohttp_client, aioclient_mock, current_request
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test Spotify reauthentication with different account."""
await setup.async_setup_component(
@ -230,7 +252,13 @@ async def test_reauth_account_mismatch(
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
client = await aiohttp_client(hass.http.app)
await client.get(f"/auth/external/callback?code=abcd&state={state}")

View File

@ -40,7 +40,7 @@ async def test_abort_if_no_configuration(hass):
async def test_full_flow_implementation(
hass, aiohttp_client, aioclient_mock, current_request
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test registering an integration and finishing flow works."""
await setup_component(hass)
@ -53,7 +53,13 @@ async def test_full_flow_implementation(
assert result["step_id"] == "pick_implementation"
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"implementation": "eneco"}
@ -97,7 +103,9 @@ async def test_full_flow_implementation(
}
async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_request):
async def test_no_agreements(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test abort when there are no displays."""
await setup_component(hass)
result = await hass.config_entries.flow.async_init(
@ -105,7 +113,13 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
await hass.config_entries.flow.async_configure(
result["flow_id"], {"implementation": "eneco"}
)
@ -130,7 +144,7 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque
async def test_multiple_agreements(
hass, aiohttp_client, aioclient_mock, current_request
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test abort when there are no displays."""
await setup_component(hass)
@ -139,7 +153,13 @@ async def test_multiple_agreements(
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
await hass.config_entries.flow.async_configure(
result["flow_id"], {"implementation": "eneco"}
)
@ -174,7 +194,7 @@ async def test_multiple_agreements(
async def test_agreement_already_set_up(
hass, aiohttp_client, aioclient_mock, current_request
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test showing display form again if display already exists."""
await setup_component(hass)
@ -184,7 +204,13 @@ async def test_agreement_already_set_up(
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
await hass.config_entries.flow.async_configure(
result["flow_id"], {"implementation": "eneco"}
)
@ -208,14 +234,22 @@ async def test_agreement_already_set_up(
assert result3["reason"] == "already_configured"
async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request):
async def test_toon_abort(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test we abort on Toon error."""
await setup_component(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
await hass.config_entries.flow.async_configure(
result["flow_id"], {"implementation": "eneco"}
)
@ -239,7 +273,7 @@ async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request)
assert result2["reason"] == "connection_error"
async def test_import(hass):
async def test_import(hass, current_request_with_host):
"""Test if importing step works."""
await setup_component(hass)
@ -253,7 +287,9 @@ async def test_import(hass):
assert result["reason"] == "already_in_progress"
async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_request):
async def test_import_migration(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Test if importing step with migration works."""
old_entry = MockConfigEntry(domain=DOMAIN, unique_id=123, version=1)
old_entry.add_to_hass(hass)
@ -269,7 +305,13 @@ async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_re
assert flows[0]["context"][CONF_MIGRATE] == old_entry.entry_id
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": flows[0]["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": flows[0]["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
await hass.config_entries.flow.async_configure(
flows[0]["flow_id"], {"implementation": "eneco"}
)

View File

@ -197,7 +197,11 @@ class ComponentFactory:
assert result
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(
self._hass, {"flow_id": result["flow_id"]}
self._hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "http://127.0.0.1:8080/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (

View File

@ -71,7 +71,13 @@ async def test_config_reauth_profile(
)
# pylint: disable=protected-access
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
client: TestClient = await aiohttp_client(hass.http.app)
resp = await client.get(f"{AUTH_CALLBACK_PATH}?code=abcd&state={state}")

View File

@ -21,7 +21,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "single_instance_allowed"
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
@ -35,7 +37,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"xbox", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
scope = "+".join(["Xboxlive.signin", "Xboxlive.offline_access"])

View File

@ -7,6 +7,7 @@ import ssl
import threading
from aiohttp.test_utils import make_mocked_request
import multidict
import pytest
import requests_mock as _requests_mock
@ -22,11 +23,11 @@ from homeassistant.components.websocket_api.auth import (
from homeassistant.components.websocket_api.http import URL
from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED
from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import event
from homeassistant.helpers import config_entry_oauth2_flow, event
from homeassistant.setup import async_setup_component
from homeassistant.util import location
from tests.async_mock import MagicMock, Mock, patch
from tests.async_mock import MagicMock, patch
from tests.ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
pytest.register_assert_rewrite("tests.common")
@ -277,19 +278,29 @@ def hass_client(hass, aiohttp_client, hass_access_token):
@pytest.fixture
def current_request(hass):
def current_request():
"""Mock current request."""
with patch("homeassistant.helpers.network.current_request") as mock_request_context:
with patch("homeassistant.components.http.current_request") as mock_request_context:
mocked_request = make_mocked_request(
"GET",
"/some/request",
headers={"Host": "example.com"},
sslcontext=ssl.SSLContext(ssl.PROTOCOL_TLS),
)
mock_request_context.get = Mock(return_value=mocked_request)
mock_request_context.get.return_value = mocked_request
yield mock_request_context
@pytest.fixture
def current_request_with_host(current_request):
"""Mock current request with a host header."""
new_headers = multidict.CIMultiDict(current_request.get.return_value.headers)
new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com"
current_request.get.return_value = current_request.get.return_value.clone(
headers=new_headers
)
@pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token, hass):
"""Websocket client fixture connected to websocket server."""

View File

@ -6,7 +6,6 @@ import time
import pytest
from homeassistant import config_entries, data_entry_flow, setup
from homeassistant.config import async_process_ha_core_config
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.network import NoURLAvailableError
@ -146,14 +145,14 @@ async def test_abort_if_no_url_available(hass, flow_handler, local_impl):
async def test_abort_if_oauth_error(
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
hass,
flow_handler,
local_impl,
aiohttp_client,
aioclient_mock,
current_request_with_host,
):
"""Check bad oauth token."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
@ -171,7 +170,13 @@ async def test_abort_if_oauth_error(
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
@ -203,10 +208,6 @@ async def test_abort_if_oauth_error(
async def test_step_discovery(hass, flow_handler, local_impl):
"""Check flow triggers from discovery."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
@ -222,11 +223,6 @@ async def test_step_discovery(hass, flow_handler, local_impl):
async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
"""Test if aborts when discovered multiple times."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
@ -249,10 +245,6 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl):
"""Test if abort discovery when entries exists."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
@ -273,14 +265,14 @@ async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl)
async def test_full_flow(
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
hass,
flow_handler,
local_impl,
aiohttp_client,
aioclient_mock,
current_request_with_host,
):
"""Check full flow."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
@ -298,7 +290,13 @@ async def test_full_flow(
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (

View File

@ -500,7 +500,7 @@ async def test_get_url(hass: HomeAssistant):
with patch(
"homeassistant.helpers.network._get_request_host", return_value="example.com"
), patch("homeassistant.helpers.network.current_request"):
), patch("homeassistant.components.http.current_request"):
assert get_url(hass, require_current_request=True) == "https://example.com"
assert (
get_url(hass, require_current_request=True, require_ssl=True)
@ -512,7 +512,7 @@ async def test_get_url(hass: HomeAssistant):
with patch(
"homeassistant.helpers.network._get_request_host", return_value="example.local"
), patch("homeassistant.helpers.network.current_request"):
), patch("homeassistant.components.http.current_request"):
assert get_url(hass, require_current_request=True) == "http://example.local"
with pytest.raises(NoURLAvailableError):
@ -533,7 +533,7 @@ async def test_get_request_host(hass: HomeAssistant):
with pytest.raises(NoURLAvailableError):
_get_request_host()
with patch("homeassistant.helpers.network.current_request") as mock_request_context:
with patch("homeassistant.components.http.current_request") as mock_request_context:
mock_request = Mock()
mock_request.url = "http://example.com:8123/test/request"
mock_request_context.get = Mock(return_value=mock_request)