diff --git a/homeassistant/components/toon/oauth2.py b/homeassistant/components/toon/oauth2.py index 2622e0a9027..e3a83583ac6 100644 --- a/homeassistant/components/toon/oauth2.py +++ b/homeassistant/components/toon/oauth2.py @@ -90,8 +90,8 @@ class ToonLocalOAuth2Implementation(config_entry_oauth2_flow.LocalOAuth2Implemen """Initialize local Toon auth implementation.""" data = { "grant_type": "authorization_code", - "code": external_data, - "redirect_uri": self.redirect_uri, + "code": external_data["code"], + "redirect_uri": external_data["state"]["redirect_uri"], "tenant_id": self.tenant_id, } diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index 4d05ad7beab..526a774cc39 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -19,9 +19,9 @@ import voluptuous as vol from yarl import URL from homeassistant import config_entries -from homeassistant.components.http import HomeAssistantView +from homeassistant.components import http from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers.network import NoURLAvailableError, get_url +from homeassistant.helpers.network import NoURLAvailableError from .aiohttp_client import async_get_clientsession @@ -32,6 +32,7 @@ DATA_VIEW_REGISTERED = "oauth2_view_reg" DATA_IMPLEMENTATIONS = "oauth2_impl" DATA_PROVIDERS = "oauth2_providers" AUTH_CALLBACK_PATH = "/auth/external/callback" +HEADER_FRONTEND_BASE = "HA-Frontend-Base" CLOCK_OUT_OF_SYNC_MAX_SEC = 20 @@ -64,7 +65,7 @@ class AbstractOAuth2Implementation(ABC): Pass external data in with: await hass.config_entries.flow.async_configure( - flow_id=flow_id, user_input=external_data + flow_id=flow_id, user_input={'code': 'abcd', 'state': { … } ) """ @@ -124,7 +125,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation): @property def redirect_uri(self) -> str: """Return the redirect uri.""" - return f"{get_url(self.hass, require_current_request=True)}{AUTH_CALLBACK_PATH}" + req = http.current_request.get() + + if req is None: + raise RuntimeError("No current request in context") + + ha_host = req.headers.get(HEADER_FRONTEND_BASE) + + if ha_host is None: + raise RuntimeError("No header in request") + + return f"{ha_host}{AUTH_CALLBACK_PATH}" @property def extra_authorize_data(self) -> dict: @@ -133,14 +144,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation): async def async_generate_authorize_url(self, flow_id: str) -> str: """Generate a url for the user to authorize.""" + redirect_uri = self.redirect_uri return str( URL(self.authorize_url) .with_query( { "response_type": "code", "client_id": self.client_id, - "redirect_uri": self.redirect_uri, - "state": _encode_jwt(self.hass, {"flow_id": flow_id}), + "redirect_uri": redirect_uri, + "state": _encode_jwt( + self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri} + ), } ) .update_query(self.extra_authorize_data) @@ -151,8 +165,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation): return await self._token_request( { "grant_type": "authorization_code", - "code": external_data, - "redirect_uri": self.redirect_uri, + "code": external_data["code"], + "redirect_uri": external_data["state"]["redirect_uri"], } ) @@ -384,7 +398,7 @@ def async_add_implementation_provider( ] = async_provide_implementation -class OAuth2AuthorizeCallbackView(HomeAssistantView): +class OAuth2AuthorizeCallbackView(http.HomeAssistantView): """OAuth2 Authorization Callback View.""" requires_auth = False @@ -406,7 +420,8 @@ class OAuth2AuthorizeCallbackView(HomeAssistantView): return web.Response(text="Invalid state") await hass.config_entries.flow.async_configure( - flow_id=state["flow_id"], user_input=request.query["code"] + flow_id=state["flow_id"], + user_input={"state": state, "code": request.query["code"]}, ) return web.Response( diff --git a/homeassistant/helpers/network.py b/homeassistant/helpers/network.py index 3990662dc02..4e066eaa13c 100644 --- a/homeassistant/helpers/network.py +++ b/homeassistant/helpers/network.py @@ -4,7 +4,7 @@ from typing import Optional, cast import yarl -from homeassistant.components.http import current_request +from homeassistant.components import http from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import bind_hass @@ -49,7 +49,7 @@ def get_url( prefer_cloud: bool = False, ) -> str: """Get a URL to this instance.""" - if require_current_request and current_request.get() is None: + if require_current_request and http.current_request.get() is None: raise NoURLAvailableError order = [TYPE_URL_INTERNAL, TYPE_URL_EXTERNAL] @@ -125,7 +125,7 @@ def get_url( def _get_request_host() -> Optional[str]: """Get the host address of the current request.""" - request = current_request.get() + request = http.current_request.get() if request is None: raise NoURLAvailableError return yarl.URL(request.url).host diff --git a/script/scaffold/templates/config_flow_oauth2/tests/test_config_flow.py b/script/scaffold/templates/config_flow_oauth2/tests/test_config_flow.py index 36a42431cf3..ed974601646 100644 --- a/script/scaffold/templates/config_flow_oauth2/tests/test_config_flow.py +++ b/script/scaffold/templates/config_flow_oauth2/tests/test_config_flow.py @@ -13,7 +13,9 @@ CLIENT_ID = "1234" CLIENT_SECRET = "5678" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -27,7 +29,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( "NEW_DOMAIN", context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["url"] == ( f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}" diff --git a/tests/components/almond/test_config_flow.py b/tests/components/almond/test_config_flow.py index b2144205895..afcad55bf2a 100644 --- a/tests/components/almond/test_config_flow.py +++ b/tests/components/almond/test_config_flow.py @@ -91,7 +91,9 @@ async def test_abort_if_existing_entry(hass): assert result["reason"] == "single_instance_allowed" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -109,7 +111,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( diff --git a/tests/components/home_connect/test_config_flow.py b/tests/components/home_connect/test_config_flow.py index 5d65df98e5b..5c94f8b3362 100644 --- a/tests/components/home_connect/test_config_flow.py +++ b/tests/components/home_connect/test_config_flow.py @@ -14,7 +14,9 @@ CLIENT_ID = "1234" CLIENT_SECRET = "5678" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( "home_connect", context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( diff --git a/tests/components/nest/test_config_flow_sdm.py b/tests/components/nest/test_config_flow_sdm.py index 1df751f3980..6573b17980e 100644 --- a/tests/components/nest/test_config_flow_sdm.py +++ b/tests/components/nest/test_config_flow_sdm.py @@ -12,7 +12,9 @@ PROJECT_ID = "project-id-4321" SUBSCRIBER_ID = "subscriber-id-9876" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID) assert result["url"] == ( diff --git a/tests/components/netatmo/test_config_flow.py b/tests/components/netatmo/test_config_flow.py index 8cee7a8c750..74a5d8dcc92 100644 --- a/tests/components/netatmo/test_config_flow.py +++ b/tests/components/netatmo/test_config_flow.py @@ -42,7 +42,9 @@ async def test_abort_if_existing_entry(hass): assert result["reason"] == "already_configured" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( "netatmo", context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) scope = "+".join( [ diff --git a/tests/components/smappee/test_config_flow.py b/tests/components/smappee/test_config_flow.py index 7434d469f96..55d063c2b1c 100644 --- a/tests/components/smappee/test_config_flow.py +++ b/tests/components/smappee/test_config_flow.py @@ -333,7 +333,9 @@ async def test_abort_cloud_flow_if_local_device_exists(hass): assert len(hass.config_entries.async_entries(DOMAIN)) == 1 -async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_user_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -351,7 +353,13 @@ async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_requ result = await hass.config_entries.flow.async_configure( result["flow_id"], {"environment": ENV_CLOUD} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) client = await aiohttp_client(hass.http.app) resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") diff --git a/tests/components/somfy/test_config_flow.py b/tests/components/somfy/test_config_flow.py index 89b4fbe9b13..4276a6a18d4 100644 --- a/tests/components/somfy/test_config_flow.py +++ b/tests/components/somfy/test_config_flow.py @@ -52,7 +52,9 @@ async def test_abort_if_existing_entry(hass): assert result["reason"] == "single_instance_allowed" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -69,7 +71,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( diff --git a/tests/components/spotify/test_config_flow.py b/tests/components/spotify/test_config_flow.py index 3b3c85dd828..53e87e5bdae 100644 --- a/tests/components/spotify/test_config_flow.py +++ b/tests/components/spotify/test_config_flow.py @@ -40,7 +40,9 @@ async def test_zeroconf_abort_if_existing_entry(hass): assert result["reason"] == "already_configured" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check a full flow.""" assert await setup.async_setup_component( hass, @@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( @@ -103,7 +111,7 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): async def test_abort_if_spotify_error( - hass, aiohttp_client, aioclient_mock, current_request + hass, aiohttp_client, aioclient_mock, current_request_with_host ): """Check Spotify errors causes flow to abort.""" await setup.async_setup_component( @@ -120,7 +128,13 @@ async def test_abort_if_spotify_error( ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) client = await aiohttp_client(hass.http.app) await client.get(f"/auth/external/callback?code=abcd&state={state}") @@ -144,7 +158,9 @@ async def test_abort_if_spotify_error( assert result["reason"] == "connection_error" -async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_request): +async def test_reauthentication( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Test Spotify reauthentication.""" await setup.async_setup_component( hass, @@ -173,7 +189,13 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) client = await aiohttp_client(hass.http.app) await client.get(f"/auth/external/callback?code=abcd&state={state}") @@ -202,7 +224,7 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re async def test_reauth_account_mismatch( - hass, aiohttp_client, aioclient_mock, current_request + hass, aiohttp_client, aioclient_mock, current_request_with_host ): """Test Spotify reauthentication with different account.""" await setup.async_setup_component( @@ -230,7 +252,13 @@ async def test_reauth_account_mismatch( result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) client = await aiohttp_client(hass.http.app) await client.get(f"/auth/external/callback?code=abcd&state={state}") diff --git a/tests/components/toon/test_config_flow.py b/tests/components/toon/test_config_flow.py index 6fb7a7b53dc..b7eb3898b47 100644 --- a/tests/components/toon/test_config_flow.py +++ b/tests/components/toon/test_config_flow.py @@ -40,7 +40,7 @@ async def test_abort_if_no_configuration(hass): async def test_full_flow_implementation( - hass, aiohttp_client, aioclient_mock, current_request + hass, aiohttp_client, aioclient_mock, current_request_with_host ): """Test registering an integration and finishing flow works.""" await setup_component(hass) @@ -53,7 +53,13 @@ async def test_full_flow_implementation( assert result["step_id"] == "pick_implementation" # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) result2 = await hass.config_entries.flow.async_configure( result["flow_id"], {"implementation": "eneco"} @@ -97,7 +103,9 @@ async def test_full_flow_implementation( } -async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_request): +async def test_no_agreements( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Test abort when there are no displays.""" await setup_component(hass) result = await hass.config_entries.flow.async_init( @@ -105,7 +113,13 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) await hass.config_entries.flow.async_configure( result["flow_id"], {"implementation": "eneco"} ) @@ -130,7 +144,7 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque async def test_multiple_agreements( - hass, aiohttp_client, aioclient_mock, current_request + hass, aiohttp_client, aioclient_mock, current_request_with_host ): """Test abort when there are no displays.""" await setup_component(hass) @@ -139,7 +153,13 @@ async def test_multiple_agreements( ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) await hass.config_entries.flow.async_configure( result["flow_id"], {"implementation": "eneco"} ) @@ -174,7 +194,7 @@ async def test_multiple_agreements( async def test_agreement_already_set_up( - hass, aiohttp_client, aioclient_mock, current_request + hass, aiohttp_client, aioclient_mock, current_request_with_host ): """Test showing display form again if display already exists.""" await setup_component(hass) @@ -184,7 +204,13 @@ async def test_agreement_already_set_up( ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) await hass.config_entries.flow.async_configure( result["flow_id"], {"implementation": "eneco"} ) @@ -208,14 +234,22 @@ async def test_agreement_already_set_up( assert result3["reason"] == "already_configured" -async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request): +async def test_toon_abort( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Test we abort on Toon error.""" await setup_component(hass) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER} ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) await hass.config_entries.flow.async_configure( result["flow_id"], {"implementation": "eneco"} ) @@ -239,7 +273,7 @@ async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request) assert result2["reason"] == "connection_error" -async def test_import(hass): +async def test_import(hass, current_request_with_host): """Test if importing step works.""" await setup_component(hass) @@ -253,7 +287,9 @@ async def test_import(hass): assert result["reason"] == "already_in_progress" -async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_request): +async def test_import_migration( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Test if importing step with migration works.""" old_entry = MockConfigEntry(domain=DOMAIN, unique_id=123, version=1) old_entry.add_to_hass(hass) @@ -269,7 +305,13 @@ async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_re assert flows[0]["context"][CONF_MIGRATE] == old_entry.entry_id # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": flows[0]["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": flows[0]["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) await hass.config_entries.flow.async_configure( flows[0]["flow_id"], {"implementation": "eneco"} ) diff --git a/tests/components/withings/common.py b/tests/components/withings/common.py index a09876868a7..000900c3355 100644 --- a/tests/components/withings/common.py +++ b/tests/components/withings/common.py @@ -197,7 +197,11 @@ class ComponentFactory: assert result # pylint: disable=protected-access state = config_entry_oauth2_flow._encode_jwt( - self._hass, {"flow_id": result["flow_id"]} + self._hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "http://127.0.0.1:8080/auth/external/callback", + }, ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( diff --git a/tests/components/withings/test_config_flow.py b/tests/components/withings/test_config_flow.py index cb0ea5b29ab..8380c134013 100644 --- a/tests/components/withings/test_config_flow.py +++ b/tests/components/withings/test_config_flow.py @@ -71,7 +71,13 @@ async def test_config_reauth_profile( ) # pylint: disable=protected-access - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) client: TestClient = await aiohttp_client(hass.http.app) resp = await client.get(f"{AUTH_CALLBACK_PATH}?code=abcd&state={state}") diff --git a/tests/components/xbox/test_config_flow.py b/tests/components/xbox/test_config_flow.py index 176c5eea60a..516a57c039b 100644 --- a/tests/components/xbox/test_config_flow.py +++ b/tests/components/xbox/test_config_flow.py @@ -21,7 +21,9 @@ async def test_abort_if_existing_entry(hass): assert result["reason"] == "single_instance_allowed" -async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): +async def test_full_flow( + hass, aiohttp_client, aioclient_mock, current_request_with_host +): """Check full flow.""" assert await setup.async_setup_component( hass, @@ -35,7 +37,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request): result = await hass.config_entries.flow.async_init( "xbox", context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) scope = "+".join(["Xboxlive.signin", "Xboxlive.offline_access"]) diff --git a/tests/conftest.py b/tests/conftest.py index 285179e3a9b..d5c4c61ddf1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import ssl import threading from aiohttp.test_utils import make_mocked_request +import multidict import pytest import requests_mock as _requests_mock @@ -22,11 +23,11 @@ from homeassistant.components.websocket_api.auth import ( from homeassistant.components.websocket_api.http import URL from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED from homeassistant.exceptions import ServiceNotFound -from homeassistant.helpers import event +from homeassistant.helpers import config_entry_oauth2_flow, event from homeassistant.setup import async_setup_component from homeassistant.util import location -from tests.async_mock import MagicMock, Mock, patch +from tests.async_mock import MagicMock, patch from tests.ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS pytest.register_assert_rewrite("tests.common") @@ -277,19 +278,29 @@ def hass_client(hass, aiohttp_client, hass_access_token): @pytest.fixture -def current_request(hass): +def current_request(): """Mock current request.""" - with patch("homeassistant.helpers.network.current_request") as mock_request_context: + with patch("homeassistant.components.http.current_request") as mock_request_context: mocked_request = make_mocked_request( "GET", "/some/request", headers={"Host": "example.com"}, sslcontext=ssl.SSLContext(ssl.PROTOCOL_TLS), ) - mock_request_context.get = Mock(return_value=mocked_request) + mock_request_context.get.return_value = mocked_request yield mock_request_context +@pytest.fixture +def current_request_with_host(current_request): + """Mock current request with a host header.""" + new_headers = multidict.CIMultiDict(current_request.get.return_value.headers) + new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com" + current_request.get.return_value = current_request.get.return_value.clone( + headers=new_headers + ) + + @pytest.fixture def hass_ws_client(aiohttp_client, hass_access_token, hass): """Websocket client fixture connected to websocket server.""" diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 7ce71defb7e..157bbf3bc23 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -6,7 +6,6 @@ import time import pytest from homeassistant import config_entries, data_entry_flow, setup -from homeassistant.config import async_process_ha_core_config from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers.network import NoURLAvailableError @@ -146,14 +145,14 @@ async def test_abort_if_no_url_available(hass, flow_handler, local_impl): async def test_abort_if_oauth_error( - hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request + hass, + flow_handler, + local_impl, + aiohttp_client, + aioclient_mock, + current_request_with_host, ): """Check bad oauth token.""" - await async_process_ha_core_config( - hass, - {"external_url": "https://example.com"}, - ) - flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() @@ -171,7 +170,13 @@ async def test_abort_if_oauth_error( result["flow_id"], user_input={"implementation": TEST_DOMAIN} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( @@ -203,10 +208,6 @@ async def test_abort_if_oauth_error( async def test_step_discovery(hass, flow_handler, local_impl): """Check flow triggers from discovery.""" - await async_process_ha_core_config( - hass, - {"external_url": "https://example.com"}, - ) flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() @@ -222,11 +223,6 @@ async def test_step_discovery(hass, flow_handler, local_impl): async def test_abort_discovered_multiple(hass, flow_handler, local_impl): """Test if aborts when discovered multiple times.""" - await async_process_ha_core_config( - hass, - {"external_url": "https://example.com"}, - ) - flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() @@ -249,10 +245,6 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl): async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl): """Test if abort discovery when entries exists.""" - await async_process_ha_core_config( - hass, - {"external_url": "https://example.com"}, - ) flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() @@ -273,14 +265,14 @@ async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl) async def test_full_flow( - hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request + hass, + flow_handler, + local_impl, + aiohttp_client, + aioclient_mock, + current_request_with_host, ): """Check full flow.""" - await async_process_ha_core_config( - hass, - {"external_url": "https://example.com"}, - ) - flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() @@ -298,7 +290,13 @@ async def test_full_flow( result["flow_id"], user_input={"implementation": TEST_DOMAIN} ) - state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + state = config_entry_oauth2_flow._encode_jwt( + hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", + }, + ) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( diff --git a/tests/helpers/test_network.py b/tests/helpers/test_network.py index 495c9d511bd..c470fbd7834 100644 --- a/tests/helpers/test_network.py +++ b/tests/helpers/test_network.py @@ -500,7 +500,7 @@ async def test_get_url(hass: HomeAssistant): with patch( "homeassistant.helpers.network._get_request_host", return_value="example.com" - ), patch("homeassistant.helpers.network.current_request"): + ), patch("homeassistant.components.http.current_request"): assert get_url(hass, require_current_request=True) == "https://example.com" assert ( get_url(hass, require_current_request=True, require_ssl=True) @@ -512,7 +512,7 @@ async def test_get_url(hass: HomeAssistant): with patch( "homeassistant.helpers.network._get_request_host", return_value="example.local" - ), patch("homeassistant.helpers.network.current_request"): + ), patch("homeassistant.components.http.current_request"): assert get_url(hass, require_current_request=True) == "http://example.local" with pytest.raises(NoURLAvailableError): @@ -533,7 +533,7 @@ async def test_get_request_host(hass: HomeAssistant): with pytest.raises(NoURLAvailableError): _get_request_host() - with patch("homeassistant.helpers.network.current_request") as mock_request_context: + with patch("homeassistant.components.http.current_request") as mock_request_context: mock_request = Mock() mock_request.url = "http://example.com:8123/test/request" mock_request_context.get = Mock(return_value=mock_request)