Add PKCE implementation in oauth2 helper (#139509)

* Update config_entry_oauth2_flow.py

* Specify type on request_data

* Added LocalOAuth2ImplementationWithPkce

* LocalOAuth2ImplementationWithPkce works more like specs

* fix: Adding tests for pkce flow and feedback applied

* fix last test for pkce

* Clean test_abort_if_oauth_with_pkce_rejected

* Improve assertion of code verifier and code challenge

* Break long docstrings

* Shorten docstring

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Stephan van Rooij 2025-03-17 14:16:52 +01:00 committed by GitHub
parent fb2b3ce7d2
commit 76aef5be9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 273 additions and 11 deletions

View File

@ -11,7 +11,9 @@ from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod
import asyncio
from asyncio import Lock
import base64
from collections.abc import Awaitable, Callable
import hashlib
from http import HTTPStatus
from json import JSONDecodeError
import logging
@ -166,6 +168,11 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
"""Extra data that needs to be appended to the authorize url."""
return {}
@property
def extra_token_resolve_data(self) -> dict:
"""Extra data for the token resolve request."""
return {}
async def async_generate_authorize_url(self, flow_id: str) -> str:
"""Generate a url for the user to authorize."""
redirect_uri = self.redirect_uri
@ -186,13 +193,13 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
async def async_resolve_external_data(self, external_data: Any) -> dict:
"""Resolve the authorization code to tokens."""
return await self._token_request(
{
"grant_type": "authorization_code",
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
}
)
request_data: dict = {
"grant_type": "authorization_code",
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
}
request_data.update(self.extra_token_resolve_data)
return await self._token_request(request_data)
async def _async_refresh_token(self, token: dict) -> dict:
"""Refresh tokens."""
@ -211,7 +218,7 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
data["client_id"] = self.client_id
if self.client_secret is not None:
if self.client_secret:
data["client_secret"] = self.client_secret
_LOGGER.debug("Sending token request to %s", self.token_url)
@ -233,6 +240,100 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
return cast(dict, await resp.json())
class LocalOAuth2ImplementationWithPkce(LocalOAuth2Implementation):
"""Local OAuth2 implementation with PKCE."""
def __init__(
self,
hass: HomeAssistant,
domain: str,
client_id: str,
authorize_url: str,
token_url: str,
client_secret: str = "",
code_verifier_length: int = 128,
) -> None:
"""Initialize local auth implementation."""
super().__init__(
hass,
domain,
client_id,
client_secret,
authorize_url,
token_url,
)
# Generate code verifier
self.code_verifier = LocalOAuth2ImplementationWithPkce.generate_code_verifier(
code_verifier_length
)
@property
def extra_authorize_data(self) -> dict:
"""Extra data that needs to be appended to the authorize url.
If you want to override this method,
calling super is mandatory (for adding scopes):
```
@def extra_authorize_data(self) -> dict:
data: dict = {
"scope": "openid profile email",
}
data.update(super().extra_authorize_data)
return data
```
"""
return {
"code_challenge": LocalOAuth2ImplementationWithPkce.compute_code_challenge(
self.code_verifier
),
"code_challenge_method": "S256",
}
@property
def extra_token_resolve_data(self) -> dict:
"""Extra data that needs to be included in the token resolve request.
If you want to override this method,
calling super is mandatory (for adding `someKey`):
```
@def extra_token_resolve_data(self) -> dict:
data: dict = {
"someKey": "someValue",
}
data.update(super().extra_token_resolve_data)
return data
```
"""
return {"code_verifier": self.code_verifier}
@staticmethod
def generate_code_verifier(code_verifier_length: int = 128) -> str:
"""Generate a code verifier."""
if not 43 <= code_verifier_length <= 128:
msg = (
"Parameter `code_verifier_length` must validate"
"`43 <= code_verifier_length <= 128`."
)
raise ValueError(msg)
return secrets.token_urlsafe(96)[:code_verifier_length]
@staticmethod
def compute_code_challenge(code_verifier: str) -> str:
"""Compute the code challenge."""
if not 43 <= len(code_verifier) <= 128:
msg = (
"Parameter `code_verifier` must validate "
"`43 <= len(code_verifier) <= 128`."
)
raise ValueError(msg)
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
encoded = base64.urlsafe_b64encode(hashed)
return encoded.decode("ascii").replace("=", "")
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
"""Handle a config flow."""

View File

@ -1,11 +1,11 @@
"""Tests for the Somfy config flow."""
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from http import HTTPStatus
import logging
import time
from typing import Any
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import aiohttp
import pytest
@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.network import NoURLAvailableError
from tests.common import MockConfigEntry, mock_platform
from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
@ -27,6 +27,11 @@ ACCESS_TOKEN_1 = "mock-access-token-1"
ACCESS_TOKEN_2 = "mock-access-token-2"
AUTHORIZE_URL = "https://example.como/auth/authorize"
TOKEN_URL = "https://example.como/auth/token"
MOCK_SECRET_TOKEN_URLSAFE = (
"token-"
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
)
@pytest.fixture
@ -40,6 +45,22 @@ async def local_impl(
)
@pytest.fixture
async def local_impl_pkce(
hass: HomeAssistant,
) -> AsyncGenerator[config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce]:
"""Local implementation."""
assert await setup.async_setup_component(hass, "auth", {})
with patch(
"homeassistant.helpers.config_entry_oauth2_flow.secrets.token_urlsafe",
return_value=MOCK_SECRET_TOKEN_URLSAFE
+ "bbbbbb", # Add some characters that should be removed by the logic.
):
yield config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce(
hass, TEST_DOMAIN, CLIENT_ID, AUTHORIZE_URL, TOKEN_URL
)
@pytest.fixture
def flow_handler(
hass: HomeAssistant,
@ -963,3 +984,143 @@ async def test_oauth2_without_secret_init(
client = await hass_client_no_auth()
resp = await client.get("/auth/external/callback?code=abcd&state=qwer")
assert resp.status == 400
@pytest.mark.usefixtures("current_request_with_host")
async def test_abort_oauth_with_pkce_rejected(
hass: HomeAssistant,
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
hass_client_no_auth: ClientSessionGenerator,
) -> None:
"""Check bad oauth token."""
flow_handler.async_register_implementation(hass, local_impl_pkce)
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
assert f"client_id={CLIENT_ID}" in result["url"]
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
assert f"state={state}" in result["url"]
assert "scope=read+write" in result["url"]
assert "response_type=code" in result["url"]
assert f"code_challenge={code_challenge}" in result["url"]
assert "code_challenge_method=S256" in result["url"]
client = await hass_client_no_auth()
resp = await client.get(
f"/auth/external/callback?error=access_denied&state={state}"
)
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "user_rejected_authorize"
assert result["description_placeholders"] == {"error": "access_denied"}
@pytest.mark.usefixtures("current_request_with_host")
async def test_oauth_with_pkce_adds_code_verifier_to_token_resolve(
hass: HomeAssistant,
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Check pkce flow."""
mock_integration(
hass,
MockModule(
domain=TEST_DOMAIN,
async_setup_entry=AsyncMock(return_value=True),
),
)
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", None)
flow_handler.async_register_implementation(hass, local_impl_pkce)
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
assert f"client_id={CLIENT_ID}" in result["url"]
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
assert f"state={state}" in result["url"]
assert "scope=read+write" in result["url"]
assert "response_type=code" in result["url"]
assert f"code_challenge={code_challenge}" in result["url"]
assert "code_challenge_method=S256" in result["url"]
# Setup the response when HA tries to fetch a token with the code
aioclient_mock.post(
TOKEN_URL,
json={
"refresh_token": REFRESH_TOKEN,
"access_token": ACCESS_TOKEN_1,
"type": "bearer",
"expires_in": 60,
},
)
client = await hass_client_no_auth()
# trigger the callback
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
# Verify the token resolve request occurred
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == {
"client_id": CLIENT_ID,
"grant_type": "authorization_code",
"code": "abcd",
"redirect_uri": "https://example.com/auth/external/callback",
"code_verifier": MOCK_SECRET_TOKEN_URLSAFE,
}
@pytest.mark.parametrize("code_verifier_length", [40, 129])
def test_generate_code_verifier_invalid_length(code_verifier_length: int) -> None:
"""Test generate_code_verifier with an invalid length."""
with pytest.raises(ValueError):
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.generate_code_verifier(
code_verifier_length
)
@pytest.mark.parametrize("code_verifier", ["", "yyy", "a" * 129])
def test_compute_code_challenge_invalid_code_verifier(code_verifier: str) -> None:
"""Test compute_code_challenge with an invalid code_verifier."""
with pytest.raises(ValueError):
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.compute_code_challenge(
code_verifier
)