mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
OAuth2 to use current request header (#43668)
This commit is contained in:
parent
69c2818c56
commit
f9fa24950b
@ -90,8 +90,8 @@ class ToonLocalOAuth2Implementation(config_entry_oauth2_flow.LocalOAuth2Implemen
|
||||
"""Initialize local Toon auth implementation."""
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"code": external_data["code"],
|
||||
"redirect_uri": external_data["state"]["redirect_uri"],
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
|
@ -19,9 +19,9 @@ import voluptuous as vol
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components import http
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.network import NoURLAvailableError, get_url
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
|
||||
from .aiohttp_client import async_get_clientsession
|
||||
|
||||
@ -32,6 +32,7 @@ DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||
DATA_PROVIDERS = "oauth2_providers"
|
||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
|
||||
|
||||
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
|
||||
|
||||
@ -64,7 +65,7 @@ class AbstractOAuth2Implementation(ABC):
|
||||
Pass external data in with:
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=flow_id, user_input=external_data
|
||||
flow_id=flow_id, user_input={'code': 'abcd', 'state': { … }
|
||||
)
|
||||
|
||||
"""
|
||||
@ -124,7 +125,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
@property
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
return f"{get_url(self.hass, require_current_request=True)}{AUTH_CALLBACK_PATH}"
|
||||
req = http.current_request.get()
|
||||
|
||||
if req is None:
|
||||
raise RuntimeError("No current request in context")
|
||||
|
||||
ha_host = req.headers.get(HEADER_FRONTEND_BASE)
|
||||
|
||||
if ha_host is None:
|
||||
raise RuntimeError("No header in request")
|
||||
|
||||
return f"{ha_host}{AUTH_CALLBACK_PATH}"
|
||||
|
||||
@property
|
||||
def extra_authorize_data(self) -> dict:
|
||||
@ -133,14 +144,17 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize."""
|
||||
redirect_uri = self.redirect_uri
|
||||
return str(
|
||||
URL(self.authorize_url)
|
||||
.with_query(
|
||||
{
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": _encode_jwt(self.hass, {"flow_id": flow_id}),
|
||||
"redirect_uri": redirect_uri,
|
||||
"state": _encode_jwt(
|
||||
self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri}
|
||||
),
|
||||
}
|
||||
)
|
||||
.update_query(self.extra_authorize_data)
|
||||
@ -151,8 +165,8 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
return await self._token_request(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"code": external_data["code"],
|
||||
"redirect_uri": external_data["state"]["redirect_uri"],
|
||||
}
|
||||
)
|
||||
|
||||
@ -384,7 +398,7 @@ def async_add_implementation_provider(
|
||||
] = async_provide_implementation
|
||||
|
||||
|
||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
class OAuth2AuthorizeCallbackView(http.HomeAssistantView):
|
||||
"""OAuth2 Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
@ -406,7 +420,8 @@ class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
return web.Response(text="Invalid state")
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=state["flow_id"], user_input=request.query["code"]
|
||||
flow_id=state["flow_id"],
|
||||
user_input={"state": state, "code": request.query["code"]},
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
|
@ -4,7 +4,7 @@ from typing import Optional, cast
|
||||
|
||||
import yarl
|
||||
|
||||
from homeassistant.components.http import current_request
|
||||
from homeassistant.components import http
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -49,7 +49,7 @@ def get_url(
|
||||
prefer_cloud: bool = False,
|
||||
) -> str:
|
||||
"""Get a URL to this instance."""
|
||||
if require_current_request and current_request.get() is None:
|
||||
if require_current_request and http.current_request.get() is None:
|
||||
raise NoURLAvailableError
|
||||
|
||||
order = [TYPE_URL_INTERNAL, TYPE_URL_EXTERNAL]
|
||||
@ -125,7 +125,7 @@ def get_url(
|
||||
|
||||
def _get_request_host() -> Optional[str]:
|
||||
"""Get the host address of the current request."""
|
||||
request = current_request.get()
|
||||
request = http.current_request.get()
|
||||
if request is None:
|
||||
raise NoURLAvailableError
|
||||
return yarl.URL(request.url).host
|
||||
|
@ -13,7 +13,9 @@ CLIENT_ID = "1234"
|
||||
CLIENT_SECRET = "5678"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -27,7 +29,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"NEW_DOMAIN", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["url"] == (
|
||||
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
||||
|
@ -91,7 +91,9 @@ async def test_abort_if_existing_entry(hass):
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -109,7 +111,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
|
@ -14,7 +14,9 @@ CLIENT_ID = "1234"
|
||||
CLIENT_SECRET = "5678"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"home_connect", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
|
@ -12,7 +12,9 @@ PROJECT_ID = "project-id-4321"
|
||||
SUBSCRIBER_ID = "subscriber-id-9876"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID)
|
||||
assert result["url"] == (
|
||||
|
@ -42,7 +42,9 @@ async def test_abort_if_existing_entry(hass):
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"netatmo", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
scope = "+".join(
|
||||
[
|
||||
|
@ -333,7 +333,9 @@ async def test_abort_cloud_flow_if_local_device_exists(hass):
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
|
||||
|
||||
async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_user_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -351,7 +353,13 @@ async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_requ
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"environment": ENV_CLOUD}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
@ -52,7 +52,9 @@ async def test_abort_if_existing_entry(hass):
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -69,7 +71,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
|
@ -40,7 +40,9 @@ async def test_zeroconf_abort_if_existing_entry(hass):
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check a full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
@ -103,7 +111,7 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
|
||||
|
||||
async def test_abort_if_spotify_error(
|
||||
hass, aiohttp_client, aioclient_mock, current_request
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check Spotify errors causes flow to abort."""
|
||||
await setup.async_setup_component(
|
||||
@ -120,7 +128,13 @@ async def test_abort_if_spotify_error(
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
@ -144,7 +158,9 @@ async def test_abort_if_spotify_error(
|
||||
assert result["reason"] == "connection_error"
|
||||
|
||||
|
||||
async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_reauthentication(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test Spotify reauthentication."""
|
||||
await setup.async_setup_component(
|
||||
hass,
|
||||
@ -173,7 +189,13 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re
|
||||
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
@ -202,7 +224,7 @@ async def test_reauthentication(hass, aiohttp_client, aioclient_mock, current_re
|
||||
|
||||
|
||||
async def test_reauth_account_mismatch(
|
||||
hass, aiohttp_client, aioclient_mock, current_request
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test Spotify reauthentication with different account."""
|
||||
await setup.async_setup_component(
|
||||
@ -230,7 +252,13 @@ async def test_reauth_account_mismatch(
|
||||
result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {})
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
|
@ -40,7 +40,7 @@ async def test_abort_if_no_configuration(hass):
|
||||
|
||||
|
||||
async def test_full_flow_implementation(
|
||||
hass, aiohttp_client, aioclient_mock, current_request
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test registering an integration and finishing flow works."""
|
||||
await setup_component(hass)
|
||||
@ -53,7 +53,13 @@ async def test_full_flow_implementation(
|
||||
assert result["step_id"] == "pick_implementation"
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"implementation": "eneco"}
|
||||
@ -97,7 +103,9 @@ async def test_full_flow_implementation(
|
||||
}
|
||||
|
||||
|
||||
async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_no_agreements(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test abort when there are no displays."""
|
||||
await setup_component(hass)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
@ -105,7 +113,13 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"implementation": "eneco"}
|
||||
)
|
||||
@ -130,7 +144,7 @@ async def test_no_agreements(hass, aiohttp_client, aioclient_mock, current_reque
|
||||
|
||||
|
||||
async def test_multiple_agreements(
|
||||
hass, aiohttp_client, aioclient_mock, current_request
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test abort when there are no displays."""
|
||||
await setup_component(hass)
|
||||
@ -139,7 +153,13 @@ async def test_multiple_agreements(
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"implementation": "eneco"}
|
||||
)
|
||||
@ -174,7 +194,7 @@ async def test_multiple_agreements(
|
||||
|
||||
|
||||
async def test_agreement_already_set_up(
|
||||
hass, aiohttp_client, aioclient_mock, current_request
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test showing display form again if display already exists."""
|
||||
await setup_component(hass)
|
||||
@ -184,7 +204,13 @@ async def test_agreement_already_set_up(
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"implementation": "eneco"}
|
||||
)
|
||||
@ -208,14 +234,22 @@ async def test_agreement_already_set_up(
|
||||
assert result3["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_toon_abort(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test we abort on Toon error."""
|
||||
await setup_component(hass)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"implementation": "eneco"}
|
||||
)
|
||||
@ -239,7 +273,7 @@ async def test_toon_abort(hass, aiohttp_client, aioclient_mock, current_request)
|
||||
assert result2["reason"] == "connection_error"
|
||||
|
||||
|
||||
async def test_import(hass):
|
||||
async def test_import(hass, current_request_with_host):
|
||||
"""Test if importing step works."""
|
||||
await setup_component(hass)
|
||||
|
||||
@ -253,7 +287,9 @@ async def test_import(hass):
|
||||
assert result["reason"] == "already_in_progress"
|
||||
|
||||
|
||||
async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_import_migration(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Test if importing step with migration works."""
|
||||
old_entry = MockConfigEntry(domain=DOMAIN, unique_id=123, version=1)
|
||||
old_entry.add_to_hass(hass)
|
||||
@ -269,7 +305,13 @@ async def test_import_migration(hass, aiohttp_client, aioclient_mock, current_re
|
||||
assert flows[0]["context"][CONF_MIGRATE] == old_entry.entry_id
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": flows[0]["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": flows[0]["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flows[0]["flow_id"], {"implementation": "eneco"}
|
||||
)
|
||||
|
@ -197,7 +197,11 @@ class ComponentFactory:
|
||||
assert result
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
self._hass, {"flow_id": result["flow_id"]}
|
||||
self._hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "http://127.0.0.1:8080/auth/external/callback",
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
|
@ -71,7 +71,13 @@ async def test_config_reauth_profile(
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
client: TestClient = await aiohttp_client(hass.http.app)
|
||||
resp = await client.get(f"{AUTH_CALLBACK_PATH}?code=abcd&state={state}")
|
||||
|
@ -21,7 +21,9 @@ async def test_abort_if_existing_entry(hass):
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
async def test_full_flow(
|
||||
hass, aiohttp_client, aioclient_mock, current_request_with_host
|
||||
):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
@ -35,7 +37,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"xbox", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
scope = "+".join(["Xboxlive.signin", "Xboxlive.offline_access"])
|
||||
|
||||
|
@ -7,6 +7,7 @@ import ssl
|
||||
import threading
|
||||
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
import multidict
|
||||
import pytest
|
||||
import requests_mock as _requests_mock
|
||||
|
||||
@ -22,11 +23,11 @@ from homeassistant.components.websocket_api.auth import (
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.helpers import event
|
||||
from homeassistant.helpers import config_entry_oauth2_flow, event
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import location
|
||||
|
||||
from tests.async_mock import MagicMock, Mock, patch
|
||||
from tests.async_mock import MagicMock, patch
|
||||
from tests.ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
|
||||
|
||||
pytest.register_assert_rewrite("tests.common")
|
||||
@ -277,19 +278,29 @@ def hass_client(hass, aiohttp_client, hass_access_token):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_request(hass):
|
||||
def current_request():
|
||||
"""Mock current request."""
|
||||
with patch("homeassistant.helpers.network.current_request") as mock_request_context:
|
||||
with patch("homeassistant.components.http.current_request") as mock_request_context:
|
||||
mocked_request = make_mocked_request(
|
||||
"GET",
|
||||
"/some/request",
|
||||
headers={"Host": "example.com"},
|
||||
sslcontext=ssl.SSLContext(ssl.PROTOCOL_TLS),
|
||||
)
|
||||
mock_request_context.get = Mock(return_value=mocked_request)
|
||||
mock_request_context.get.return_value = mocked_request
|
||||
yield mock_request_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_request_with_host(current_request):
|
||||
"""Mock current request with a host header."""
|
||||
new_headers = multidict.CIMultiDict(current_request.get.return_value.headers)
|
||||
new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com"
|
||||
current_request.get.return_value = current_request.get.return_value.clone(
|
||||
headers=new_headers
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client, hass_access_token, hass):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
|
@ -6,7 +6,6 @@ import time
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow, setup
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
|
||||
@ -146,14 +145,14 @@ async def test_abort_if_no_url_available(hass, flow_handler, local_impl):
|
||||
|
||||
|
||||
async def test_abort_if_oauth_error(
|
||||
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
|
||||
hass,
|
||||
flow_handler,
|
||||
local_impl,
|
||||
aiohttp_client,
|
||||
aioclient_mock,
|
||||
current_request_with_host,
|
||||
):
|
||||
"""Check bad oauth token."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": "https://example.com"},
|
||||
)
|
||||
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
@ -171,7 +170,13 @@ async def test_abort_if_oauth_error(
|
||||
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
@ -203,10 +208,6 @@ async def test_abort_if_oauth_error(
|
||||
|
||||
async def test_step_discovery(hass, flow_handler, local_impl):
|
||||
"""Check flow triggers from discovery."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": "https://example.com"},
|
||||
)
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
@ -222,11 +223,6 @@ async def test_step_discovery(hass, flow_handler, local_impl):
|
||||
|
||||
async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
|
||||
"""Test if aborts when discovered multiple times."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": "https://example.com"},
|
||||
)
|
||||
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
@ -249,10 +245,6 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
|
||||
|
||||
async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl):
|
||||
"""Test if abort discovery when entries exists."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": "https://example.com"},
|
||||
)
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
@ -273,14 +265,14 @@ async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl)
|
||||
|
||||
|
||||
async def test_full_flow(
|
||||
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
|
||||
hass,
|
||||
flow_handler,
|
||||
local_impl,
|
||||
aiohttp_client,
|
||||
aioclient_mock,
|
||||
current_request_with_host,
|
||||
):
|
||||
"""Check full flow."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": "https://example.com"},
|
||||
)
|
||||
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
@ -298,7 +290,13 @@ async def test_full_flow(
|
||||
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
|
@ -500,7 +500,7 @@ async def test_get_url(hass: HomeAssistant):
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.network._get_request_host", return_value="example.com"
|
||||
), patch("homeassistant.helpers.network.current_request"):
|
||||
), patch("homeassistant.components.http.current_request"):
|
||||
assert get_url(hass, require_current_request=True) == "https://example.com"
|
||||
assert (
|
||||
get_url(hass, require_current_request=True, require_ssl=True)
|
||||
@ -512,7 +512,7 @@ async def test_get_url(hass: HomeAssistant):
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.network._get_request_host", return_value="example.local"
|
||||
), patch("homeassistant.helpers.network.current_request"):
|
||||
), patch("homeassistant.components.http.current_request"):
|
||||
assert get_url(hass, require_current_request=True) == "http://example.local"
|
||||
|
||||
with pytest.raises(NoURLAvailableError):
|
||||
@ -533,7 +533,7 @@ async def test_get_request_host(hass: HomeAssistant):
|
||||
with pytest.raises(NoURLAvailableError):
|
||||
_get_request_host()
|
||||
|
||||
with patch("homeassistant.helpers.network.current_request") as mock_request_context:
|
||||
with patch("homeassistant.components.http.current_request") as mock_request_context:
|
||||
mock_request = Mock()
|
||||
mock_request.url = "http://example.com:8123/test/request"
|
||||
mock_request_context.get = Mock(return_value=mock_request)
|
||||
|
Loading…
x
Reference in New Issue
Block a user