From f456e3a0718abd6b79ab4b484e2ebcf139221f6d Mon Sep 17 00:00:00 2001 From: karwosts <32912880+karwosts@users.noreply.github.com> Date: Mon, 29 Jan 2024 11:09:23 -0500 Subject: [PATCH] Allow delete_all_refresh_tokens to delete a specific token_type (#106119) * Allow delete_all_refresh_tokens to delete a specific token_type * add a test * minor string change * test updates * more test updates * more test updates * fix tests * do not delete current token * Update tests/components/auth/test_init.py * Update tests/components/auth/test_init.py * Option to not delete the current token --------- Co-authored-by: J. Nick Koston --- homeassistant/components/auth/__init__.py | 15 +++++- tests/components/auth/test_init.py | 63 ++++++++++++++++++++--- 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index f4a59f13486..f97647fff0e 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -604,6 +604,8 @@ async def websocket_delete_refresh_token( @websocket_api.websocket_command( { vol.Required("type"): "auth/delete_all_refresh_tokens", + vol.Optional("token_type"): cv.string, + vol.Optional("delete_current_token", default=True): bool, } ) @websocket_api.ws_require_user() @@ -614,6 +616,10 @@ async def websocket_delete_all_refresh_tokens( """Handle delete all refresh tokens request.""" current_refresh_token: RefreshToken remove_failed = False + token_type = msg.get("token_type") + delete_current_token = msg.get("delete_current_token") + limit_token_types = token_type is not None + for token in list(connection.user.refresh_tokens.values()): if token.id == connection.refresh_token_id: # Skip the current refresh token as it has revoke_callback, @@ -621,6 +627,8 @@ async def websocket_delete_all_refresh_tokens( # It will be removed after sending the result. current_refresh_token = token continue + if limit_token_types and token_type != token.token_type: + continue try: hass.auth.async_remove_refresh_token(token) except Exception as err: # pylint: disable=broad-except @@ -637,8 +645,11 @@ async def websocket_delete_all_refresh_tokens( else: connection.send_result(msg["id"], {}) - # This will close the connection so we need to send the result first. - hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token) + if delete_current_token and ( + not limit_token_types or current_refresh_token.token_type == token_type + ): + # This will close the connection so we need to send the result first. + hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token) @websocket_api.websocket_command( diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 666ee4cac07..3926cd2f82b 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -8,7 +8,11 @@ from freezegun.api import FrozenDateTimeFactory import pytest from homeassistant.auth import InvalidAuthError -from homeassistant.auth.models import Credentials +from homeassistant.auth.models import ( + TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + TOKEN_TYPE_NORMAL, + Credentials, +) from homeassistant.components import auth from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -567,22 +571,50 @@ async def test_ws_delete_all_refresh_tokens_error( assert refresh_token is None +@pytest.mark.parametrize( + ( + "delete_token_type", + "delete_current_token", + "expected_remaining_normal_tokens", + "expected_remaining_long_lived_tokens", + ), + [ + ({}, {}, 0, 0), + ({"token_type": TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN}, {}, 3, 0), + ({"token_type": TOKEN_TYPE_NORMAL}, {}, 0, 1), + ({"token_type": TOKEN_TYPE_NORMAL}, {"delete_current_token": False}, 1, 1), + ], +) async def test_ws_delete_all_refresh_tokens( hass: HomeAssistant, hass_admin_user: MockUser, hass_admin_credential: Credentials, hass_ws_client: WebSocketGenerator, hass_access_token: str, + delete_token_type: dict[str:str], + delete_current_token: dict[str:bool], + expected_remaining_normal_tokens: int, + expected_remaining_long_lived_tokens: int, ) -> None: - """Test deleting all refresh tokens.""" + """Test deleting all or some refresh tokens.""" assert await async_setup_component(hass, "auth", {"http": {}}) # one token already exists await hass.auth.async_create_refresh_token( hass_admin_user, CLIENT_ID, credential=hass_admin_credential ) + + # create a long lived token await hass.auth.async_create_refresh_token( - hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential + hass_admin_user, + f"{CLIENT_ID}_LL", + client_name="client_ll", + credential=hass_admin_credential, + token_type=TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, + ) + + await hass.auth.async_create_refresh_token( + hass_admin_user, f"{CLIENT_ID}_1", credential=hass_admin_credential ) ws_client = await hass_ws_client(hass, hass_access_token) @@ -592,20 +624,35 @@ async def test_ws_delete_all_refresh_tokens( result = await ws_client.receive_json() assert result["success"], result - tokens = result["result"] - await ws_client.send_json( { "id": 6, "type": "auth/delete_all_refresh_tokens", + **delete_token_type, + **delete_current_token, } ) result = await ws_client.receive_json() assert result, result["success"] - for token in tokens: - refresh_token = hass.auth.async_get_refresh_token(token["id"]) - assert refresh_token is None + + # We need to enumerate the user since we may remove the token + # that is used to authenticate the user which will prevent the websocket + # connection from working + remaining_tokens_by_type: dict[str, int] = { + TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN: 0, + TOKEN_TYPE_NORMAL: 0, + } + for refresh_token in hass_admin_user.refresh_tokens.values(): + remaining_tokens_by_type[refresh_token.token_type] += 1 + + assert ( + remaining_tokens_by_type[TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN] + == expected_remaining_long_lived_tokens + ) + assert ( + remaining_tokens_by_type[TOKEN_TYPE_NORMAL] == expected_remaining_normal_tokens + ) async def test_ws_sign_path(