Store redirect URI in context instead of asking each time (#77380)

* Store redirect URI in context instead of asking each time

* Fix tests
This commit is contained in:
Paulus Schoutsen 2022-08-29 19:28:42 -04:00 committed by GitHub
parent 2224d0f43a
commit 14f68ec1a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 30 additions and 21 deletions

View File

@ -193,7 +193,6 @@ class LoginFlowBaseView(HomeAssistantView):
self, self,
request: web.Request, request: web.Request,
client_id: str, client_id: str,
redirect_uri: str,
result: data_entry_flow.FlowResult, result: data_entry_flow.FlowResult,
) -> web.Response: ) -> web.Response:
"""Convert the flow result to a response.""" """Convert the flow result to a response."""
@ -214,10 +213,13 @@ class LoginFlowBaseView(HomeAssistantView):
hass: HomeAssistant = request.app["hass"] hass: HomeAssistant = request.app["hass"]
if not await indieauth.verify_redirect_uri(hass, client_id, redirect_uri): if not await indieauth.verify_redirect_uri(
hass, client_id, result["context"]["redirect_uri"]
):
return self.json_message("Invalid redirect URI", HTTPStatus.FORBIDDEN) return self.json_message("Invalid redirect URI", HTTPStatus.FORBIDDEN)
result.pop("data") result.pop("data")
result.pop("context")
result_obj: Credentials = result.pop("result") result_obj: Credentials = result.pop("result")
@ -278,6 +280,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
context={ context={
"ip_address": ip_address(request.remote), # type: ignore[arg-type] "ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user", "credential_only": data.get("type") == "link_user",
"redirect_uri": redirect_uri,
}, },
) )
except data_entry_flow.UnknownHandler: except data_entry_flow.UnknownHandler:
@ -287,9 +290,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
"Handler does not support init", HTTPStatus.BAD_REQUEST "Handler does not support init", HTTPStatus.BAD_REQUEST
) )
return await self._async_flow_result_to_response( return await self._async_flow_result_to_response(request, client_id, result)
request, client_id, redirect_uri, result
)
class LoginFlowResourceView(LoginFlowBaseView): class LoginFlowResourceView(LoginFlowBaseView):
@ -304,7 +305,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
@RequestDataValidator( @RequestDataValidator(
vol.Schema( vol.Schema(
{vol.Required("client_id"): str, vol.Required("redirect_uri"): str}, {vol.Required("client_id"): str},
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
) )
@ -314,7 +315,6 @@ class LoginFlowResourceView(LoginFlowBaseView):
) -> web.Response: ) -> web.Response:
"""Handle progressing a login flow request.""" """Handle progressing a login flow request."""
client_id: str = data.pop("client_id") client_id: str = data.pop("client_id")
redirect_uri: str = data.pop("redirect_uri")
if not indieauth.verify_client_id(client_id): if not indieauth.verify_client_id(client_id):
return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST) return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST)
@ -330,9 +330,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
except vol.Invalid: except vol.Invalid:
return self.json_message("User input malformed", HTTPStatus.BAD_REQUEST) return self.json_message("User input malformed", HTTPStatus.BAD_REQUEST)
return await self._async_flow_result_to_response( return await self._async_flow_result_to_response(request, client_id, result)
request, client_id, redirect_uri, result
)
async def delete(self, request: web.Request, flow_id: str) -> web.Response: async def delete(self, request: web.Request, flow_id: str) -> web.Response:
"""Cancel a flow in progress.""" """Cancel a flow in progress."""

View File

@ -121,6 +121,7 @@ def _prepare_config_flow_result_json(result, prepare_result_json):
data = result.copy() data = result.copy()
data["result"] = entry_json(result["result"]) data["result"] = entry_json(result["result"])
data.pop("data") data.pop("data")
data.pop("context")
return data return data

View File

@ -484,6 +484,7 @@ class FlowHandler:
data=data, data=data,
description=description, description=description,
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
context=self.context,
) )
@callback @callback

View File

@ -30,6 +30,7 @@ class _BaseFlowManagerView(HomeAssistantView):
data = result.copy() data = result.copy()
data.pop("result") data.pop("result")
data.pop("data") data.pop("data")
data.pop("context")
return data return data
if "data_schema" not in result: if "data_schema" not in result:

View File

@ -64,7 +64,6 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "test-user", "username": "test-user",
"password": "test-pass", "password": "test-pass",
}, },
@ -133,7 +132,6 @@ async def test_auth_code_checks_local_only_user(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "test-user", "username": "test-user",
"password": "test-pass", "password": "test-pass",
}, },

View File

@ -48,7 +48,6 @@ async def async_get_code(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "2nd-user", "username": "2nd-user",
"password": "2nd-pass", "password": "2nd-pass",
}, },

View File

@ -61,7 +61,6 @@ async def test_invalid_username_password(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "wrong-user", "username": "wrong-user",
"password": "test-pass", "password": "test-pass",
}, },
@ -82,7 +81,6 @@ async def test_invalid_username_password(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "test-user", "username": "test-user",
"password": "wrong-pass", "password": "wrong-pass",
}, },
@ -103,7 +101,6 @@ async def test_invalid_username_password(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": "http://some-other-domain.com",
"username": "wrong-user", "username": "wrong-user",
"password": "test-pass", "password": "test-pass",
}, },
@ -116,7 +113,21 @@ async def test_invalid_username_password(hass, aiohttp_client):
assert step["step_id"] == "init" assert step["step_id"] == "init"
assert step["errors"]["base"] == "invalid_auth" assert step["errors"]["base"] == "invalid_auth"
# Incorrect redirect URI
async def test_invalid_redirect_uri(hass, aiohttp_client):
"""Test invalid redirect URI."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.post(
"/auth/login_flow",
json={
"client_id": CLIENT_ID,
"handler": ["insecure_example", None],
"redirect_uri": "https://some-other-domain.com",
},
)
assert resp.status == HTTPStatus.OK
step = await resp.json()
with patch( with patch(
"homeassistant.components.auth.indieauth.fetch_redirect_uris", return_value=[] "homeassistant.components.auth.indieauth.fetch_redirect_uris", return_value=[]
), patch( ), patch(
@ -126,7 +137,6 @@ async def test_invalid_username_password(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": "http://some-other-domain.com",
"username": "test-user", "username": "test-user",
"password": "test-pass", "password": "test-pass",
}, },
@ -165,7 +175,6 @@ async def test_login_exist_user(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "test-user", "username": "test-user",
"password": "test-pass", "password": "test-pass",
}, },
@ -206,14 +215,13 @@ async def test_login_local_only_user(hass, aiohttp_client):
f"/auth/login_flow/{step['flow_id']}", f"/auth/login_flow/{step['flow_id']}",
json={ json={
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"redirect_uri": CLIENT_REDIRECT_URI,
"username": "test-user", "username": "test-user",
"password": "test-pass", "password": "test-pass",
}, },
) )
assert len(mock_not_allowed_do_auth.mock_calls) == 1
assert resp.status == HTTPStatus.FORBIDDEN assert resp.status == HTTPStatus.FORBIDDEN
assert len(mock_not_allowed_do_auth.mock_calls) == 1
assert await resp.json() == {"message": "Login blocked: User is local only"} assert await resp.json() == {"message": "Login blocked: User is local only"}

View File

@ -120,6 +120,7 @@ async def test_pairing(hass, mock_tv_pairable, mock_setup_entry):
) )
assert result == { assert result == {
"context": {"source": "user", "unique_id": "ABCDEFGHIJKLF"},
"flow_id": ANY, "flow_id": ANY,
"type": "create_entry", "type": "create_entry",
"description": None, "description": None,

View File

@ -117,6 +117,7 @@ async def test_user_form_pin_not_required(hass, two_factor_verify_form):
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
expected = { expected = {
"context": {"source": "user"},
"title": TEST_USERNAME, "title": TEST_USERNAME,
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
@ -286,6 +287,7 @@ async def test_pin_form_success(hass, pin_form):
assert len(mock_update_saved_pin.mock_calls) == 1 assert len(mock_update_saved_pin.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
expected = { expected = {
"context": {"source": "user"},
"title": TEST_USERNAME, "title": TEST_USERNAME,
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,