From ca576d45acf44530c1fe932518132f9650ad12ab Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Mar 2023 16:03:41 -1000 Subject: [PATCH] Cache decode of JWT tokens (#90013) --- homeassistant/auth/__init__.py | 10 +- homeassistant/auth/jwt_wrapper.py | 116 +++++++++++++++++ homeassistant/components/http/auth.py | 3 +- tests/auth/test_init.py | 173 ++++++++++++++++++++++++++ tests/auth/test_jwt_wrapper.py | 12 ++ 5 files changed, 308 insertions(+), 6 deletions(-) create mode 100644 homeassistant/auth/jwt_wrapper.py create mode 100644 tests/auth/test_jwt_wrapper.py diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 5c401570dee..9a537174270 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.util import dt as dt_util -from . import auth_store, models +from . import auth_store, jwt_wrapper, models from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .providers import AuthProvider, LoginFlow, auth_provider_from_config @@ -555,9 +555,7 @@ class AuthManager: ) -> models.RefreshToken | None: """Return refresh token if an access token is valid.""" try: - unverif_claims = jwt.decode( - token, algorithms=["HS256"], options={"verify_signature": False} - ) + unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token) except jwt.InvalidTokenError: return None @@ -573,7 +571,9 @@ class AuthManager: issuer = refresh_token.id try: - jwt.decode(token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"]) + jwt_wrapper.verify_and_decode( + token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"] + ) except jwt.InvalidTokenError: return None diff --git a/homeassistant/auth/jwt_wrapper.py b/homeassistant/auth/jwt_wrapper.py new file mode 100644 index 00000000000..546e4afdcfa --- /dev/null +++ b/homeassistant/auth/jwt_wrapper.py @@ -0,0 +1,116 @@ +"""Provide a wrapper around JWT that caches decoding tokens. + +Since we decode the same tokens over and over again +we can cache the result of the decode of valid tokens +to speed up the process. +""" +from __future__ import annotations + +from datetime import timedelta +from functools import lru_cache, partial +from typing import Any + +from jwt import DecodeError, PyJWS, PyJWT + +from homeassistant.util.json import json_loads + +JWT_TOKEN_CACHE_SIZE = 16 +MAX_TOKEN_SIZE = 8192 + +_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss") + +_VERIFY_OPTIONS: dict[str, Any] = {f"verify_{key}": True for key in _VERIFY_KEYS} | { + "require": [] +} +_NO_VERIFY_OPTIONS = {f"verify_{key}": False for key in _VERIFY_KEYS} + + +class _PyJWSWithLoadCache(PyJWS): + """PyJWS with a dedicated load implementation.""" + + @lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE) + # We only ever have a global instance of this class + # so we do not have to worry about the LRU growing + # each time we create a new instance. + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: + """Load a JWS.""" + return super()._load(jwt) + + +_jws = _PyJWSWithLoadCache() + + +@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE) +def _decode_payload(json_payload: str) -> dict[str, Any]: + """Decode the payload from a JWS dictionary.""" + try: + payload = json_loads(json_payload) + except ValueError as err: + raise DecodeError(f"Invalid payload string: {err}") from err + if not isinstance(payload, dict): + raise DecodeError("Invalid payload string: must be a json object") + return payload + + +class _PyJWTWithVerify(PyJWT): + """PyJWT with a fast decode implementation.""" + + def decode_payload( + self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str] + ) -> dict[str, Any]: + """Decode a JWT's payload.""" + if len(jwt) > MAX_TOKEN_SIZE: + # Avoid caching impossible tokens + raise DecodeError("Token too large") + return _decode_payload( + _jws.decode_complete( + jwt=jwt, + key=key, + algorithms=algorithms, + options=options, + )["payload"] + ) + + def verify_and_decode( + self, + jwt: str, + key: str, + algorithms: list[str], + issuer: str | None = None, + leeway: int | float | timedelta = 0, + options: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Verify a JWT's signature and claims.""" + merged_options = {**_VERIFY_OPTIONS, **(options or {})} + payload = self.decode_payload( + jwt=jwt, + key=key, + options=merged_options, + algorithms=algorithms, + ) + # These should never be missing since we verify them + # but this is an additional safeguard to make sure + # nothing slips through. + assert "exp" in payload, "exp claim is required" + assert "iat" in payload, "iat claim is required" + self._validate_claims( # type: ignore[no-untyped-call] + payload=payload, + options=merged_options, + issuer=issuer, + leeway=leeway, + ) + return payload + + +_jwt = _PyJWTWithVerify() # type: ignore[no-untyped-call] +verify_and_decode = _jwt.verify_and_decode +unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)( + partial( + _jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS + ) +) + +__all__ = [ + "unverified_hs256_token_decode", + "verify_and_decode", +] diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 5213cd1b072..ec8c7de9899 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -13,6 +13,7 @@ from aiohttp.web import Application, Request, StreamResponse, middleware import jwt from yarl import URL +from homeassistant.auth import jwt_wrapper from homeassistant.auth.const import GROUP_ID_READ_ONLY from homeassistant.auth.models import User from homeassistant.components import websocket_api @@ -175,7 +176,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: return False try: - claims = jwt.decode( + claims = jwt_wrapper.verify_and_decode( signature, secret, algorithms=["HS256"], options={"verify_iss": False} ) except jwt.InvalidTokenError: diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index 3d3674e1373..83c08dd73ee 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -3,6 +3,7 @@ from datetime import timedelta from typing import Any from unittest.mock import Mock, patch +from freezegun import freeze_time import jwt import pytest import voluptuous as vol @@ -1127,3 +1128,175 @@ async def test_event_user_updated_fires(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert len(events) == 1 + + +async def test_access_token_with_invalid_signature(mock_hass) -> None: + """Test rejecting access tokens with an invalid signature.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Good Client", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(days=3000), + ) + assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + access_token = manager.async_create_access_token(refresh_token) + + rt = await manager.async_validate_access_token(access_token) + assert rt.id == refresh_token.id + + # Now we corrupt the signature + header, payload, signature = access_token.split(".") + invalid_signature = "a" * len(signature) + invalid_token = f"{header}.{payload}.{invalid_signature}" + + assert access_token != invalid_token + + result = await manager.async_validate_access_token(invalid_token) + assert result is None + + +async def test_access_token_with_null_signature(mock_hass) -> None: + """Test rejecting access tokens with a null signature.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Good Client", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(days=3000), + ) + assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + access_token = manager.async_create_access_token(refresh_token) + + rt = await manager.async_validate_access_token(access_token) + assert rt.id == refresh_token.id + + # Now we make the signature all nulls + header, payload, signature = access_token.split(".") + invalid_signature = "\0" * len(signature) + invalid_token = f"{header}.{payload}.{invalid_signature}" + + assert access_token != invalid_token + + result = await manager.async_validate_access_token(invalid_token) + assert result is None + + +async def test_access_token_with_empty_signature(mock_hass) -> None: + """Test rejecting access tokens with an empty signature.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Good Client", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(days=3000), + ) + assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + access_token = manager.async_create_access_token(refresh_token) + + rt = await manager.async_validate_access_token(access_token) + assert rt.id == refresh_token.id + + # Now we make the signature all nulls + header, payload, _ = access_token.split(".") + invalid_token = f"{header}.{payload}." + + assert access_token != invalid_token + + result = await manager.async_validate_access_token(invalid_token) + assert result is None + + +async def test_access_token_with_empty_key(mock_hass) -> None: + """Test rejecting access tokens with an empty key.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Good Client", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(days=3000), + ) + assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + + access_token = manager.async_create_access_token(refresh_token) + + await manager.async_remove_refresh_token(refresh_token) + # Now remove the token from the keyring + # so we will get an empty key + + assert await manager.async_validate_access_token(access_token) is None + + +async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None: + """Test rejecting access tokens with impossible sizes.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + assert await manager.async_validate_access_token("a" * 10000) is None + + +async def test_reject_token_with_invalid_json_payload(mock_hass) -> None: + """Test rejecting access tokens with invalid json payload.""" + jws = jwt.PyJWS() + token_with_invalid_json = jws.encode( + b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"} + ) + manager = await auth.auth_manager_from_config(mock_hass, [], []) + assert await manager.async_validate_access_token(token_with_invalid_json) is None + + +async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None: + """Test rejecting access tokens with not a dict json payload.""" + jws = jwt.PyJWS() + token_not_a_dict_json = jws.encode( + b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"} + ) + manager = await auth.auth_manager_from_config(mock_hass, [], []) + assert await manager.async_validate_access_token(token_not_a_dict_json) is None + + +async def test_access_token_that_expires_soon(mock_hass) -> None: + """Test access token from refresh token that expires very soon.""" + now = dt_util.utcnow() + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Token that expires very soon", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(seconds=1), + ) + assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + access_token = manager.async_create_access_token(refresh_token) + + rt = await manager.async_validate_access_token(access_token) + assert rt.id == refresh_token.id + + with freeze_time(now + timedelta(minutes=1)): + assert await manager.async_validate_access_token(access_token) is None + + +async def test_access_token_from_the_future(mock_hass) -> None: + """Test we reject an access token from the future.""" + now = dt_util.utcnow() + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + with freeze_time(now + timedelta(days=365)): + refresh_token = await manager.async_create_refresh_token( + user, + client_name="Token that expires very soon", + token_type=auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + access_token_expiration=timedelta(days=10), + ) + assert ( + refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + ) + access_token = manager.async_create_access_token(refresh_token) + + assert await manager.async_validate_access_token(access_token) is None + + with freeze_time(now + timedelta(days=365)): + rt = await manager.async_validate_access_token(access_token) + assert rt.id == refresh_token.id diff --git a/tests/auth/test_jwt_wrapper.py b/tests/auth/test_jwt_wrapper.py new file mode 100644 index 00000000000..297d4dd5d7f --- /dev/null +++ b/tests/auth/test_jwt_wrapper.py @@ -0,0 +1,12 @@ +"""Tests for the Home Assistant auth jwt_wrapper module.""" + +import jwt +import pytest + +from homeassistant.auth import jwt_wrapper + + +async def test_reject_access_token_with_impossible_large_size() -> None: + """Test rejecting access tokens with impossible sizes.""" + with pytest.raises(jwt.DecodeError): + jwt_wrapper.unverified_hs256_token_decode("a" * 10000)