mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 10:17:09 +00:00
Add PKCE implementation in oauth2 helper (#139509)
* Update config_entry_oauth2_flow.py * Specify type on request_data * Added LocalOAuth2ImplementationWithPkce * LocalOAuth2ImplementationWithPkce works more like specs * fix: Adding tests for pkce flow and feedback applied * fix last test for pkce * Clean test_abort_if_oauth_with_pkce_rejected * Improve assertion of code verifier and code challenge * Break long docstrings * Shorten docstring --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
fb2b3ce7d2
commit
76aef5be9f
@ -11,7 +11,9 @@ from __future__ import annotations
|
|||||||
from abc import ABC, ABCMeta, abstractmethod
|
from abc import ABC, ABCMeta, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
|
import base64
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
import hashlib
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
import logging
|
import logging
|
||||||
@ -166,6 +168,11 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
|||||||
"""Extra data that needs to be appended to the authorize url."""
|
"""Extra data that needs to be appended to the authorize url."""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_token_resolve_data(self) -> dict:
|
||||||
|
"""Extra data for the token resolve request."""
|
||||||
|
return {}
|
||||||
|
|
||||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||||
"""Generate a url for the user to authorize."""
|
"""Generate a url for the user to authorize."""
|
||||||
redirect_uri = self.redirect_uri
|
redirect_uri = self.redirect_uri
|
||||||
@ -186,13 +193,13 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
|||||||
|
|
||||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||||
"""Resolve the authorization code to tokens."""
|
"""Resolve the authorization code to tokens."""
|
||||||
return await self._token_request(
|
request_data: dict = {
|
||||||
{
|
"grant_type": "authorization_code",
|
||||||
"grant_type": "authorization_code",
|
"code": external_data["code"],
|
||||||
"code": external_data["code"],
|
"redirect_uri": external_data["state"]["redirect_uri"],
|
||||||
"redirect_uri": external_data["state"]["redirect_uri"],
|
}
|
||||||
}
|
request_data.update(self.extra_token_resolve_data)
|
||||||
)
|
return await self._token_request(request_data)
|
||||||
|
|
||||||
async def _async_refresh_token(self, token: dict) -> dict:
|
async def _async_refresh_token(self, token: dict) -> dict:
|
||||||
"""Refresh tokens."""
|
"""Refresh tokens."""
|
||||||
@ -211,7 +218,7 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
|||||||
|
|
||||||
data["client_id"] = self.client_id
|
data["client_id"] = self.client_id
|
||||||
|
|
||||||
if self.client_secret is not None:
|
if self.client_secret:
|
||||||
data["client_secret"] = self.client_secret
|
data["client_secret"] = self.client_secret
|
||||||
|
|
||||||
_LOGGER.debug("Sending token request to %s", self.token_url)
|
_LOGGER.debug("Sending token request to %s", self.token_url)
|
||||||
@ -233,6 +240,100 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
|||||||
return cast(dict, await resp.json())
|
return cast(dict, await resp.json())
|
||||||
|
|
||||||
|
|
||||||
|
class LocalOAuth2ImplementationWithPkce(LocalOAuth2Implementation):
|
||||||
|
"""Local OAuth2 implementation with PKCE."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
domain: str,
|
||||||
|
client_id: str,
|
||||||
|
authorize_url: str,
|
||||||
|
token_url: str,
|
||||||
|
client_secret: str = "",
|
||||||
|
code_verifier_length: int = 128,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize local auth implementation."""
|
||||||
|
super().__init__(
|
||||||
|
hass,
|
||||||
|
domain,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
authorize_url,
|
||||||
|
token_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate code verifier
|
||||||
|
self.code_verifier = LocalOAuth2ImplementationWithPkce.generate_code_verifier(
|
||||||
|
code_verifier_length
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_authorize_data(self) -> dict:
|
||||||
|
"""Extra data that needs to be appended to the authorize url.
|
||||||
|
|
||||||
|
If you want to override this method,
|
||||||
|
calling super is mandatory (for adding scopes):
|
||||||
|
```
|
||||||
|
@def extra_authorize_data(self) -> dict:
|
||||||
|
data: dict = {
|
||||||
|
"scope": "openid profile email",
|
||||||
|
}
|
||||||
|
data.update(super().extra_authorize_data)
|
||||||
|
return data
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"code_challenge": LocalOAuth2ImplementationWithPkce.compute_code_challenge(
|
||||||
|
self.code_verifier
|
||||||
|
),
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_token_resolve_data(self) -> dict:
|
||||||
|
"""Extra data that needs to be included in the token resolve request.
|
||||||
|
|
||||||
|
If you want to override this method,
|
||||||
|
calling super is mandatory (for adding `someKey`):
|
||||||
|
```
|
||||||
|
@def extra_token_resolve_data(self) -> dict:
|
||||||
|
data: dict = {
|
||||||
|
"someKey": "someValue",
|
||||||
|
}
|
||||||
|
data.update(super().extra_token_resolve_data)
|
||||||
|
return data
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
return {"code_verifier": self.code_verifier}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_code_verifier(code_verifier_length: int = 128) -> str:
|
||||||
|
"""Generate a code verifier."""
|
||||||
|
if not 43 <= code_verifier_length <= 128:
|
||||||
|
msg = (
|
||||||
|
"Parameter `code_verifier_length` must validate"
|
||||||
|
"`43 <= code_verifier_length <= 128`."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
return secrets.token_urlsafe(96)[:code_verifier_length]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_code_challenge(code_verifier: str) -> str:
|
||||||
|
"""Compute the code challenge."""
|
||||||
|
if not 43 <= len(code_verifier) <= 128:
|
||||||
|
msg = (
|
||||||
|
"Parameter `code_verifier` must validate "
|
||||||
|
"`43 <= len(code_verifier) <= 128`."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
hashed = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||||
|
encoded = base64.urlsafe_b64encode(hashed)
|
||||||
|
return encoded.decode("ascii").replace("=", "")
|
||||||
|
|
||||||
|
|
||||||
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
||||||
"""Handle a config flow."""
|
"""Handle a config flow."""
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
"""Tests for the Somfy config flow."""
|
"""Tests for the Somfy config flow."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
from homeassistant.helpers.network import NoURLAvailableError
|
from homeassistant.helpers.network import NoURLAvailableError
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, mock_platform
|
from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform
|
||||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
@ -27,6 +27,11 @@ ACCESS_TOKEN_1 = "mock-access-token-1"
|
|||||||
ACCESS_TOKEN_2 = "mock-access-token-2"
|
ACCESS_TOKEN_2 = "mock-access-token-2"
|
||||||
AUTHORIZE_URL = "https://example.como/auth/authorize"
|
AUTHORIZE_URL = "https://example.como/auth/authorize"
|
||||||
TOKEN_URL = "https://example.como/auth/token"
|
TOKEN_URL = "https://example.como/auth/token"
|
||||||
|
MOCK_SECRET_TOKEN_URLSAFE = (
|
||||||
|
"token-"
|
||||||
|
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -40,6 +45,22 @@ async def local_impl(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def local_impl_pkce(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> AsyncGenerator[config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce]:
|
||||||
|
"""Local implementation."""
|
||||||
|
assert await setup.async_setup_component(hass, "auth", {})
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.config_entry_oauth2_flow.secrets.token_urlsafe",
|
||||||
|
return_value=MOCK_SECRET_TOKEN_URLSAFE
|
||||||
|
+ "bbbbbb", # Add some characters that should be removed by the logic.
|
||||||
|
):
|
||||||
|
yield config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce(
|
||||||
|
hass, TEST_DOMAIN, CLIENT_ID, AUTHORIZE_URL, TOKEN_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def flow_handler(
|
def flow_handler(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -963,3 +984,143 @@ async def test_oauth2_without_secret_init(
|
|||||||
client = await hass_client_no_auth()
|
client = await hass_client_no_auth()
|
||||||
resp = await client.get("/auth/external/callback?code=abcd&state=qwer")
|
resp = await client.get("/auth/external/callback?code=abcd&state=qwer")
|
||||||
assert resp.status == 400
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("current_request_with_host")
|
||||||
|
async def test_abort_oauth_with_pkce_rejected(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
|
||||||
|
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
|
||||||
|
hass_client_no_auth: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Check bad oauth token."""
|
||||||
|
flow_handler.async_register_implementation(hass, local_impl_pkce)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
state = config_entry_oauth2_flow._encode_jwt(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"redirect_uri": "https://example.com/auth/external/callback",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
|
||||||
|
|
||||||
|
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
|
||||||
|
assert f"client_id={CLIENT_ID}" in result["url"]
|
||||||
|
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
|
||||||
|
assert f"state={state}" in result["url"]
|
||||||
|
assert "scope=read+write" in result["url"]
|
||||||
|
assert "response_type=code" in result["url"]
|
||||||
|
assert f"code_challenge={code_challenge}" in result["url"]
|
||||||
|
assert "code_challenge_method=S256" in result["url"]
|
||||||
|
|
||||||
|
client = await hass_client_no_auth()
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/external/callback?error=access_denied&state={state}"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "user_rejected_authorize"
|
||||||
|
assert result["description_placeholders"] == {"error": "access_denied"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("current_request_with_host")
|
||||||
|
async def test_oauth_with_pkce_adds_code_verifier_to_token_resolve(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
flow_handler: type[config_entry_oauth2_flow.AbstractOAuth2FlowHandler],
|
||||||
|
local_impl_pkce: config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce,
|
||||||
|
hass_client_no_auth: ClientSessionGenerator,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
) -> None:
|
||||||
|
"""Check pkce flow."""
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
domain=TEST_DOMAIN,
|
||||||
|
async_setup_entry=AsyncMock(return_value=True),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", None)
|
||||||
|
flow_handler.async_register_implementation(hass, local_impl_pkce)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
state = config_entry_oauth2_flow._encode_jwt(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"redirect_uri": "https://example.com/auth/external/callback",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
code_challenge = local_impl_pkce.compute_code_challenge(MOCK_SECRET_TOKEN_URLSAFE)
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.EXTERNAL_STEP
|
||||||
|
|
||||||
|
assert result["url"].startswith(f"{AUTHORIZE_URL}?")
|
||||||
|
assert f"client_id={CLIENT_ID}" in result["url"]
|
||||||
|
assert "redirect_uri=https://example.com/auth/external/callback" in result["url"]
|
||||||
|
assert f"state={state}" in result["url"]
|
||||||
|
assert "scope=read+write" in result["url"]
|
||||||
|
assert "response_type=code" in result["url"]
|
||||||
|
assert f"code_challenge={code_challenge}" in result["url"]
|
||||||
|
assert "code_challenge_method=S256" in result["url"]
|
||||||
|
|
||||||
|
# Setup the response when HA tries to fetch a token with the code
|
||||||
|
aioclient_mock.post(
|
||||||
|
TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"refresh_token": REFRESH_TOKEN,
|
||||||
|
"access_token": ACCESS_TOKEN_1,
|
||||||
|
"type": "bearer",
|
||||||
|
"expires_in": 60,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = await hass_client_no_auth()
|
||||||
|
# trigger the callback
|
||||||
|
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
|
||||||
|
# Verify the token resolve request occurred
|
||||||
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
|
assert aioclient_mock.mock_calls[0][2] == {
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": "abcd",
|
||||||
|
"redirect_uri": "https://example.com/auth/external/callback",
|
||||||
|
"code_verifier": MOCK_SECRET_TOKEN_URLSAFE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("code_verifier_length", [40, 129])
|
||||||
|
def test_generate_code_verifier_invalid_length(code_verifier_length: int) -> None:
|
||||||
|
"""Test generate_code_verifier with an invalid length."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.generate_code_verifier(
|
||||||
|
code_verifier_length
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("code_verifier", ["", "yyy", "a" * 129])
|
||||||
|
def test_compute_code_challenge_invalid_code_verifier(code_verifier: str) -> None:
|
||||||
|
"""Test compute_code_challenge with an invalid code_verifier."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
config_entry_oauth2_flow.LocalOAuth2ImplementationWithPkce.compute_code_challenge(
|
||||||
|
code_verifier
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user