mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Cache decode of JWT tokens (#90013)
This commit is contained in:
parent
8a591fa16e
commit
ca576d45ac
@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
|||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import auth_store, models
|
from . import auth_store, jwt_wrapper, models
|
||||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
|
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
|
||||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||||
@ -555,9 +555,7 @@ class AuthManager:
|
|||||||
) -> models.RefreshToken | None:
|
) -> models.RefreshToken | None:
|
||||||
"""Return refresh token if an access token is valid."""
|
"""Return refresh token if an access token is valid."""
|
||||||
try:
|
try:
|
||||||
unverif_claims = jwt.decode(
|
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
|
||||||
token, algorithms=["HS256"], options={"verify_signature": False}
|
|
||||||
)
|
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -573,7 +571,9 @@ class AuthManager:
|
|||||||
issuer = refresh_token.id
|
issuer = refresh_token.id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
jwt.decode(token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"])
|
jwt_wrapper.verify_and_decode(
|
||||||
|
token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"]
|
||||||
|
)
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
116
homeassistant/auth/jwt_wrapper.py
Normal file
116
homeassistant/auth/jwt_wrapper.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
"""Provide a wrapper around JWT that caches decoding tokens.
|
||||||
|
|
||||||
|
Since we decode the same tokens over and over again
|
||||||
|
we can cache the result of the decode of valid tokens
|
||||||
|
to speed up the process.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from jwt import DecodeError, PyJWS, PyJWT
|
||||||
|
|
||||||
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
|
JWT_TOKEN_CACHE_SIZE = 16
|
||||||
|
MAX_TOKEN_SIZE = 8192
|
||||||
|
|
||||||
|
_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss")
|
||||||
|
|
||||||
|
_VERIFY_OPTIONS: dict[str, Any] = {f"verify_{key}": True for key in _VERIFY_KEYS} | {
|
||||||
|
"require": []
|
||||||
|
}
|
||||||
|
_NO_VERIFY_OPTIONS = {f"verify_{key}": False for key in _VERIFY_KEYS}
|
||||||
|
|
||||||
|
|
||||||
|
class _PyJWSWithLoadCache(PyJWS):
|
||||||
|
"""PyJWS with a dedicated load implementation."""
|
||||||
|
|
||||||
|
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
|
||||||
|
# We only ever have a global instance of this class
|
||||||
|
# so we do not have to worry about the LRU growing
|
||||||
|
# each time we create a new instance.
|
||||||
|
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
|
||||||
|
"""Load a JWS."""
|
||||||
|
return super()._load(jwt)
|
||||||
|
|
||||||
|
|
||||||
|
_jws = _PyJWSWithLoadCache()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
|
||||||
|
def _decode_payload(json_payload: str) -> dict[str, Any]:
|
||||||
|
"""Decode the payload from a JWS dictionary."""
|
||||||
|
try:
|
||||||
|
payload = json_loads(json_payload)
|
||||||
|
except ValueError as err:
|
||||||
|
raise DecodeError(f"Invalid payload string: {err}") from err
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise DecodeError("Invalid payload string: must be a json object")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
class _PyJWTWithVerify(PyJWT):
|
||||||
|
"""PyJWT with a fast decode implementation."""
|
||||||
|
|
||||||
|
def decode_payload(
|
||||||
|
self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Decode a JWT's payload."""
|
||||||
|
if len(jwt) > MAX_TOKEN_SIZE:
|
||||||
|
# Avoid caching impossible tokens
|
||||||
|
raise DecodeError("Token too large")
|
||||||
|
return _decode_payload(
|
||||||
|
_jws.decode_complete(
|
||||||
|
jwt=jwt,
|
||||||
|
key=key,
|
||||||
|
algorithms=algorithms,
|
||||||
|
options=options,
|
||||||
|
)["payload"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def verify_and_decode(
|
||||||
|
self,
|
||||||
|
jwt: str,
|
||||||
|
key: str,
|
||||||
|
algorithms: list[str],
|
||||||
|
issuer: str | None = None,
|
||||||
|
leeway: int | float | timedelta = 0,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Verify a JWT's signature and claims."""
|
||||||
|
merged_options = {**_VERIFY_OPTIONS, **(options or {})}
|
||||||
|
payload = self.decode_payload(
|
||||||
|
jwt=jwt,
|
||||||
|
key=key,
|
||||||
|
options=merged_options,
|
||||||
|
algorithms=algorithms,
|
||||||
|
)
|
||||||
|
# These should never be missing since we verify them
|
||||||
|
# but this is an additional safeguard to make sure
|
||||||
|
# nothing slips through.
|
||||||
|
assert "exp" in payload, "exp claim is required"
|
||||||
|
assert "iat" in payload, "iat claim is required"
|
||||||
|
self._validate_claims( # type: ignore[no-untyped-call]
|
||||||
|
payload=payload,
|
||||||
|
options=merged_options,
|
||||||
|
issuer=issuer,
|
||||||
|
leeway=leeway,
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
_jwt = _PyJWTWithVerify() # type: ignore[no-untyped-call]
|
||||||
|
verify_and_decode = _jwt.verify_and_decode
|
||||||
|
unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)(
|
||||||
|
partial(
|
||||||
|
_jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"unverified_hs256_token_decode",
|
||||||
|
"verify_and_decode",
|
||||||
|
]
|
@ -13,6 +13,7 @@ from aiohttp.web import Application, Request, StreamResponse, middleware
|
|||||||
import jwt
|
import jwt
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
|
from homeassistant.auth import jwt_wrapper
|
||||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||||
from homeassistant.auth.models import User
|
from homeassistant.auth.models import User
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
@ -175,7 +176,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
claims = jwt.decode(
|
claims = jwt_wrapper.verify_and_decode(
|
||||||
signature, secret, algorithms=["HS256"], options={"verify_iss": False}
|
signature, secret, algorithms=["HS256"], options={"verify_iss": False}
|
||||||
)
|
)
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
|
@ -3,6 +3,7 @@ from datetime import timedelta
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -1127,3 +1128,175 @@ async def test_event_user_updated_fires(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with an invalid signature."""
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Good Client",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(days=3000),
|
||||||
|
)
|
||||||
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
rt = await manager.async_validate_access_token(access_token)
|
||||||
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
|
# Now we corrupt the signature
|
||||||
|
header, payload, signature = access_token.split(".")
|
||||||
|
invalid_signature = "a" * len(signature)
|
||||||
|
invalid_token = f"{header}.{payload}.{invalid_signature}"
|
||||||
|
|
||||||
|
assert access_token != invalid_token
|
||||||
|
|
||||||
|
result = await manager.async_validate_access_token(invalid_token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_with_null_signature(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with a null signature."""
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Good Client",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(days=3000),
|
||||||
|
)
|
||||||
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
rt = await manager.async_validate_access_token(access_token)
|
||||||
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
|
# Now we make the signature all nulls
|
||||||
|
header, payload, signature = access_token.split(".")
|
||||||
|
invalid_signature = "\0" * len(signature)
|
||||||
|
invalid_token = f"{header}.{payload}.{invalid_signature}"
|
||||||
|
|
||||||
|
assert access_token != invalid_token
|
||||||
|
|
||||||
|
result = await manager.async_validate_access_token(invalid_token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_with_empty_signature(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with an empty signature."""
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Good Client",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(days=3000),
|
||||||
|
)
|
||||||
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
rt = await manager.async_validate_access_token(access_token)
|
||||||
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
|
# Now we make the signature all nulls
|
||||||
|
header, payload, _ = access_token.split(".")
|
||||||
|
invalid_token = f"{header}.{payload}."
|
||||||
|
|
||||||
|
assert access_token != invalid_token
|
||||||
|
|
||||||
|
result = await manager.async_validate_access_token(invalid_token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_with_empty_key(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with an empty key."""
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Good Client",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(days=3000),
|
||||||
|
)
|
||||||
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
await manager.async_remove_refresh_token(refresh_token)
|
||||||
|
# Now remove the token from the keyring
|
||||||
|
# so we will get an empty key
|
||||||
|
|
||||||
|
assert await manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with impossible sizes."""
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
assert await manager.async_validate_access_token("a" * 10000) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with invalid json payload."""
|
||||||
|
jws = jwt.PyJWS()
|
||||||
|
token_with_invalid_json = jws.encode(
|
||||||
|
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||||
|
)
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
assert await manager.async_validate_access_token(token_with_invalid_json) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
||||||
|
"""Test rejecting access tokens with not a dict json payload."""
|
||||||
|
jws = jwt.PyJWS()
|
||||||
|
token_not_a_dict_json = jws.encode(
|
||||||
|
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||||
|
)
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
assert await manager.async_validate_access_token(token_not_a_dict_json) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_that_expires_soon(mock_hass) -> None:
|
||||||
|
"""Test access token from refresh token that expires very soon."""
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Token that expires very soon",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
rt = await manager.async_validate_access_token(access_token)
|
||||||
|
assert rt.id == refresh_token.id
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(minutes=1)):
|
||||||
|
assert await manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_access_token_from_the_future(mock_hass) -> None:
|
||||||
|
"""Test we reject an access token from the future."""
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
with freeze_time(now + timedelta(days=365)):
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user,
|
||||||
|
client_name="Token that expires very soon",
|
||||||
|
token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
access_token_expiration=timedelta(days=10),
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
)
|
||||||
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
assert await manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=365)):
|
||||||
|
rt = await manager.async_validate_access_token(access_token)
|
||||||
|
assert rt.id == refresh_token.id
|
||||||
|
12
tests/auth/test_jwt_wrapper.py
Normal file
12
tests/auth/test_jwt_wrapper.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""Tests for the Home Assistant auth jwt_wrapper module."""
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.auth import jwt_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reject_access_token_with_impossible_large_size() -> None:
|
||||||
|
"""Test rejecting access tokens with impossible sizes."""
|
||||||
|
with pytest.raises(jwt.DecodeError):
|
||||||
|
jwt_wrapper.unverified_hs256_token_decode("a" * 10000)
|
Loading…
x
Reference in New Issue
Block a user