mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Add PKCE implementation in oauth2 helper (#139509)
* Update config_entry_oauth2_flow.py * Specify type on request_data * Added LocalOAuth2ImplementationWithPkce * LocalOAuth2ImplementationWithPkce works more like specs * fix: Adding tests for pkce flow and feedback applied * fix last test for pkce * Clean test_abort_if_oauth_with_pkce_rejected * Improve assertion of code verifier and code challenge * Break long docstrings * Shorten docstring --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
fb2b3ce7d2
commit
76aef5be9f
@ -11,7 +11,9 @@ from __future__ import annotations
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
import asyncio
|
||||
from asyncio import Lock
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable
|
||||
import hashlib
|
||||
from http import HTTPStatus
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
@ -166,6 +168,11 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
"""Extra data that needs to be appended to the authorize url."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def extra_token_resolve_data(self) -> dict:
|
||||
"""Extra data for the token resolve request."""
|
||||
return {}
|
||||
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize."""
|
||||
redirect_uri = self.redirect_uri
|
||||
@ -186,13 +193,13 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
|
||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||
"""Resolve the authorization code to tokens."""
|
||||
return await self._token_request(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data["code"],
|
||||
"redirect_uri": external_data["state"]["redirect_uri"],
|
||||
}
|
||||
)
|
||||
request_data: dict = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data["code"],
|
||||
"redirect_uri": external_data["state"]["redirect_uri"],
|
||||
}
|
||||
request_data.update(self.extra_token_resolve_data)
|
||||
return await self._token_request(request_data)
|
||||
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh tokens."""
|
||||
@ -211,7 +218,7 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
|
||||
data["client_id"] = self.client_id
|
||||
|
||||
if self.client_secret is not None:
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
|
||||
_LOGGER.debug("Sending token request to %s", self.token_url)
|
||||
@ -233,6 +240,100 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
return cast(dict, await resp.json())
|
||||
|
||||
|
||||
class LocalOAuth2ImplementationWithPkce(LocalOAuth2Implementation):
|
||||
"""Local OAuth2 implementation with PKCE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
client_id: str,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
client_secret: str = "",
|
||||
code_verifier_length: int = 128,
|
||||
) -> None:
|
||||
"""Initialize local auth implementation."""
|
||||
super().__init__(
|
||||
hass,
|
||||
domain,
|
||||
client_id,
|
||||
client_secret,
|
||||
authorize_url,
|
||||
token_url,
|
||||
)
|
||||
|
||||
# Generate code verifier
|
||||
self.code_verifier = LocalOAuth2ImplementationWithPkce.generate_code_verifier(
|
||||
code_verifier_length
|
||||
)
|
||||
|
||||
@property
|
||||
def extra_authorize_data(self) -> dict:
|
||||
"""Extra data that needs to be appended to the authorize url.
|
||||
|
||||
If you want to override this method,
|
||||
calling super is mandatory (for adding scopes):
|
||||
```
|
||||
@def extra_authorize_data(self) -> dict:
|
||||
data: dict = {
|
||||
"scope": "openid profile email",
|
||||
}
|
||||
data.update(super().extra_authorize_data)
|
||||
return data
|
||||
```
|
||||
"""
|
||||
return {
|
||||
"code_challenge": LocalOAuth2ImplementationWithPkce.compute_code_challenge(
|
||||
self.code_verifier
|
||||
),
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
@property
|
||||
def extra_token_resolve_data(self) -> dict:
|
||||
"""Extra data that needs to be included in the token resolve request.
|
||||
|
||||
If you want to override this method,
|
||||
calling super is mandatory (for adding `someKey`):
|
||||
```
|
||||
@def extra_token_resolve_data(self) -> dict:
|
||||
data: dict = {
|
||||
"someKey": "someValue",
|
||||
}
|
||||
data.update(super().extra_token_resolve_data)
|
||||
return data
|
||||
```
|
||||
"""
|
||||
|
||||
return {"code_verifier": self.code_verifier}
|
||||
|
||||
@staticmethod
|
||||
def generate_code_verifier(code_verifier_length: int = 128) -> str:
|
||||
"""Generate a code verifier."""
|
||||
if not 43 <= code_verifier_length <= 128:
|
||||
msg = (
|
||||
"Parameter `code_verifier_length` must validate"
|
||||
"`43 <= code_verifier_length <= 128`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return secrets.token_urlsafe(96)[:code_verifier_length]
|
||||
|
||||
@staticmethod
|
||||
def compute_code_challenge(code_verifier: str) -> str:
|
||||
"""Compute the code challenge."""
|
||||
if not 43 <= len(code_verifier) <= 128:
|
||||
msg = (
|
||||
"Parameter `code_verifier` must validate "
|
||||
"`43 <= len(code_verifier) <= 128`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
encoded = base64.urlsafe_b64encode(hashed)
|
||||
return encoded.decode("ascii").replace("=", "")
|
||||
|
||||
|
||||
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
||||
"""Handle a config flow."""
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
"""Tests for the Somfy config flow."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
|
||||
from tests.common import MockConfigEntry, mock_platform
|
||||
from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
@ -27,6 +27,11 @@ ACCESS_TOKEN_1 = "mock-access-token-1"
|
||||
ACCESS_TOKEN_2 = "mock-access-token-2"
|
||||
AUTHORIZE_URL = "https://example.como/auth/authorize"
|
||||
TOKEN_URL = "https://example.como/auth/token"
|
||||
MOCK_SECRET_TOKEN_URLSAFE = (
|
||||
"token-"
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -40,6 +45,22 @@ async def local_impl(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def local_impl_pkce(
|
||||
hass: HomeAssistant,
|
||||
) -> AsyncGenerator[config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce]:
|
||||
"""Local implementation."""
|
||||
assert await setup.async_setup_component(hass, "auth", {})
|
||||
with patch(
|
||||
"homeassistant.helpers.config_entry_oauth2_flow.secrets.token_urlsafe",
|
||||
return_value=MOCK_SECRET_TOKEN_URLSAFE
|
||||
+ "bbbbbb", # Add some characters that should be removed by the logic.
|
||||
):
|
||||
yield config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce(
|
||||
hass, TEST_DOMAIN, CLIENT_ID, AUTHORIZE_URL, TOKEN_URL
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flow_handler(
|
||||
hass: HomeAssistant,
|
||||
@ -963,3 +984,143 @@ async def test_oauth2_without_secret_init(
|
||||
client = await hass_client_no_auth()
|
||||
resp = await client.get("/auth/external/callback?code=abcd&state=qwer")
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_abort_oauth_with_pkce_rejected(
|
||||
hass: HomeAssistant,
|
||||
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
|
||||
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Check bad oauth token."""
|
||||
flow_handler.async_register_implementation(hass, local_impl_pkce)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
|
||||
|
||||
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
|
||||
assert f"client_id={CLIENT_ID}" in result["url"]
|
||||
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
|
||||
assert f"state={state}" in result["url"]
|
||||
assert "scope=read+write" in result["url"]
|
||||
assert "response_type=code" in result["url"]
|
||||
assert f"code_challenge={code_challenge}" in result["url"]
|
||||
assert "code_challenge_method=S256" in result["url"]
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
resp = await client.get(
|
||||
f"/auth/external/callback?error=access_denied&state={state}"
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
assert result["reason"] == "user_rejected_authorize"
|
||||
assert result["description_placeholders"] == {"error": "access_denied"}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_oauth_with_pkce_adds_code_verifier_to_token_resolve(
|
||||
hass: HomeAssistant,
|
||||
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
|
||||
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Check pkce flow."""
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
domain=TEST_DOMAIN,
|
||||
async_setup_entry=AsyncMock(return_value=True),
|
||||
),
|
||||
)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", None)
|
||||
flow_handler.async_register_implementation(hass, local_impl_pkce)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
|
||||
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
|
||||
|
||||
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
|
||||
assert f"client_id={CLIENT_ID}" in result["url"]
|
||||
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
|
||||
assert f"state={state}" in result["url"]
|
||||
assert "scope=read+write" in result["url"]
|
||||
assert "response_type=code" in result["url"]
|
||||
assert f"code_challenge={code_challenge}" in result["url"]
|
||||
assert "code_challenge_method=S256" in result["url"]
|
||||
|
||||
# Setup the response when HA tries to fetch a token with the code
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"type": "bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
# trigger the callback
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
# Verify the token resolve request occurred
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == {
|
||||
"client_id": CLIENT_ID,
|
||||
"grant_type": "authorization_code",
|
||||
"code": "abcd",
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
"code_verifier": MOCK_SECRET_TOKEN_URLSAFE,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("code_verifier_length", [40, 129])
|
||||
def test_generate_code_verifier_invalid_length(code_verifier_length: int) -> None:
|
||||
"""Test generate_code_verifier with an invalid length."""
|
||||
with pytest.raises(ValueError):
|
||||
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.generate_code_verifier(
|
||||
code_verifier_length
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("code_verifier", ["", "yyy", "a" * 129])
|
||||
def test_compute_code_challenge_invalid_code_verifier(code_verifier: str) -> None:
|
||||
"""Test compute_code_challenge with an invalid code_verifier."""
|
||||
with pytest.raises(ValueError):
|
||||
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.compute_code_challenge(
|
||||
code_verifier
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user