From ec15b0def2ee1d8d3715c172b7b15180b8599b19 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 20 Jan 2024 16:16:43 -1000 Subject: [PATCH] Always load auth storage at startup (#108543) --- homeassistant/auth/__init__.py | 4 +- homeassistant/auth/auth_store.py | 78 +++---------------- homeassistant/scripts/auth.py | 2 + tests/auth/providers/test_command_line.py | 6 +- tests/auth/providers/test_insecure_example.py | 6 +- .../providers/test_legacy_api_password.py | 6 +- tests/auth/providers/test_trusted_networks.py | 6 +- tests/auth/test_auth_store.py | 15 +++- tests/auth/test_init.py | 1 + 9 files changed, 43 insertions(+), 81 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 000dde90faa..ac9bbaaf593 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -47,6 +47,7 @@ async def auth_manager_from_config( mfa modules exist in configs. """ store = auth_store.AuthStore(hass) + await store.async_load() if provider_configs: providers = await asyncio.gather( *( @@ -73,8 +74,7 @@ async def auth_manager_from_config( for module in modules: module_hash[module.id] = module - manager = AuthManager(hass, store, provider_hash, module_hash) - return manager + return AuthManager(hass, store, provider_hash, module_hash) class AuthManagerFlowManager(data_entry_flow.FlowManager): diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 534016a54e0..5de5d087a65 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -1,7 +1,6 @@ """Storage for auth models.""" from __future__ import annotations -import asyncio from datetime import timedelta import hmac from logging import getLogger @@ -42,44 +41,28 @@ class AuthStore: def __init__(self, hass: HomeAssistant) -> None: """Initialize the auth store.""" self.hass = hass - self._users: dict[str, models.User] | None = None - self._groups: dict[str, models.Group] | None = None - self._perm_lookup: PermissionLookup | None = None + self._loaded = False + self._users: dict[str, models.User] = None # type: ignore[assignment] + self._groups: dict[str, models.Group] = None # type: ignore[assignment] + self._perm_lookup: PermissionLookup = None # type: ignore[assignment] self._store = Store[dict[str, list[dict[str, Any]]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) - self._lock = asyncio.Lock() async def async_get_groups(self) -> list[models.Group]: """Retrieve all users.""" - if self._groups is None: - await self._async_load() - assert self._groups is not None - return list(self._groups.values()) async def async_get_group(self, group_id: str) -> models.Group | None: """Retrieve all users.""" - if self._groups is None: - await self._async_load() - assert self._groups is not None - return self._groups.get(group_id) async def async_get_users(self) -> list[models.User]: """Retrieve all users.""" - if self._users is None: - await self._async_load() - assert self._users is not None - return list(self._users.values()) async def async_get_user(self, user_id: str) -> models.User | None: """Retrieve a user by id.""" - if self._users is None: - await self._async_load() - assert self._users is not None - return self._users.get(user_id) async def async_create_user( @@ -93,12 +76,6 @@ class AuthStore: local_only: bool | None = None, ) -> models.User: """Create a new user.""" - if self._users is None: - await self._async_load() - - assert self._users is not None - assert self._groups is not None - groups = [] for group_id in group_ids or []: if (group := self._groups.get(group_id)) is None: @@ -144,10 +121,6 @@ class AuthStore: async def async_remove_user(self, user: models.User) -> None: """Remove a user.""" - if self._users is None: - await self._async_load() - assert self._users is not None - self._users.pop(user.id) self._async_schedule_save() @@ -160,8 +133,6 @@ class AuthStore: local_only: bool | None = None, ) -> None: """Update a user.""" - assert self._groups is not None - if group_ids is not None: groups = [] for grid in group_ids: @@ -193,10 +164,6 @@ class AuthStore: async def async_remove_credentials(self, credentials: models.Credentials) -> None: """Remove credentials.""" - if self._users is None: - await self._async_load() - assert self._users is not None - for user in self._users.values(): found = None @@ -244,10 +211,6 @@ class AuthStore: self, refresh_token: models.RefreshToken ) -> None: """Remove a refresh token.""" - if self._users is None: - await self._async_load() - assert self._users is not None - for user in self._users.values(): if user.refresh_tokens.pop(refresh_token.id, None): self._async_schedule_save() @@ -257,10 +220,6 @@ class AuthStore: self, token_id: str ) -> models.RefreshToken | None: """Get refresh token by id.""" - if self._users is None: - await self._async_load() - assert self._users is not None - for user in self._users.values(): refresh_token = user.refresh_tokens.get(token_id) if refresh_token is not None: @@ -272,10 +231,6 @@ class AuthStore: self, token: str ) -> models.RefreshToken | None: """Get refresh token by token.""" - if self._users is None: - await self._async_load() - assert self._users is not None - found = None for user in self._users.values(): @@ -294,25 +249,18 @@ class AuthStore: refresh_token.last_used_ip = remote_ip self._async_schedule_save() - async def _async_load(self) -> None: + async def async_load(self) -> None: """Load the users.""" - async with self._lock: - if self._users is not None: - return - await self._async_load_task() + if self._loaded: + raise RuntimeError("Auth storage is already loaded") + self._loaded = True - async def _async_load_task(self) -> None: - """Load the users.""" dev_reg = dr.async_get(self.hass) ent_reg = er.async_get(self.hass) data = await self._store.async_load() - # Make sure that we're not overriding data if 2 loads happened at the - # same time - if self._users is not None: - return - - self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg) + perm_lookup = PermissionLookup(ent_reg, dev_reg) + self._perm_lookup = perm_lookup if data is None or not isinstance(data, dict): self._set_defaults() @@ -495,17 +443,11 @@ class AuthStore: @callback def _async_schedule_save(self) -> None: """Save users.""" - if self._users is None: - return - self._store.async_delay_save(self._data_to_save, 1) @callback def _data_to_save(self) -> dict[str, list[dict[str, Any]]]: """Return the data to store.""" - assert self._users is not None - assert self._groups is not None - users = [ { "id": user.id, diff --git a/homeassistant/scripts/auth.py b/homeassistant/scripts/auth.py index 5714e5814a4..dd3b9b7ba48 100644 --- a/homeassistant/scripts/auth.py +++ b/homeassistant/scripts/auth.py @@ -9,6 +9,7 @@ from homeassistant.auth import auth_manager_from_config from homeassistant.auth.providers import homeassistant as hass_auth from homeassistant.config import get_default_config_dir from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er # mypy: allow-untyped-calls, allow-untyped-defs @@ -51,6 +52,7 @@ def run(args): async def run_command(args): """Run the command.""" hass = HomeAssistant(os.path.join(os.getcwd(), args.config)) + await asyncio.gather(dr.async_load(hass), er.async_load(hass)) hass.auth = await auth_manager_from_config(hass, [{"type": "homeassistant"}], []) provider = hass.auth.auth_providers[0] await provider.async_initialize() diff --git a/tests/auth/providers/test_command_line.py b/tests/auth/providers/test_command_line.py index a92d41a8c5f..016ce767bad 100644 --- a/tests/auth/providers/test_command_line.py +++ b/tests/auth/providers/test_command_line.py @@ -13,9 +13,11 @@ from homeassistant.const import CONF_TYPE @pytest.fixture -def store(hass): +async def store(hass): """Mock store.""" - return auth_store.AuthStore(hass) + store = auth_store.AuthStore(hass) + await store.async_load() + return store @pytest.fixture diff --git a/tests/auth/providers/test_insecure_example.py b/tests/auth/providers/test_insecure_example.py index 6054b7937c6..ceb8b02ae65 100644 --- a/tests/auth/providers/test_insecure_example.py +++ b/tests/auth/providers/test_insecure_example.py @@ -9,9 +9,11 @@ from homeassistant.auth.providers import insecure_example @pytest.fixture -def store(hass): +async def store(hass): """Mock store.""" - return auth_store.AuthStore(hass) + store = auth_store.AuthStore(hass) + await store.async_load() + return store @pytest.fixture diff --git a/tests/auth/providers/test_legacy_api_password.py b/tests/auth/providers/test_legacy_api_password.py index 3d89c577ebf..75c4f733285 100644 --- a/tests/auth/providers/test_legacy_api_password.py +++ b/tests/auth/providers/test_legacy_api_password.py @@ -14,9 +14,11 @@ CONFIG = {"type": "legacy_api_password", "api_password": "test-password"} @pytest.fixture -def store(hass): +async def store(hass): """Mock store.""" - return auth_store.AuthStore(hass) + store = auth_store.AuthStore(hass) + await store.async_load() + return store @pytest.fixture diff --git a/tests/auth/providers/test_trusted_networks.py b/tests/auth/providers/test_trusted_networks.py index a098eea28e0..3ccff990b9c 100644 --- a/tests/auth/providers/test_trusted_networks.py +++ b/tests/auth/providers/test_trusted_networks.py @@ -16,9 +16,11 @@ from homeassistant.setup import async_setup_component @pytest.fixture -def store(hass): +async def store(hass): """Mock store.""" - return auth_store.AuthStore(hass) + store = auth_store.AuthStore(hass) + await store.async_load() + return store @pytest.fixture diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 860abe76577..778095388a8 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -3,6 +3,8 @@ import asyncio from typing import Any from unittest.mock import patch +import pytest + from homeassistant.auth import auth_store from homeassistant.core import HomeAssistant @@ -67,6 +69,7 @@ async def test_loading_no_group_data_format( } store = auth_store.AuthStore(hass) + await store.async_load() groups = await store.async_get_groups() assert len(groups) == 3 admin_group = groups[0] @@ -165,6 +168,7 @@ async def test_loading_all_access_group_data_format( } store = auth_store.AuthStore(hass) + await store.async_load() groups = await store.async_get_groups() assert len(groups) == 3 admin_group = groups[0] @@ -205,6 +209,7 @@ async def test_loading_empty_data( ) -> None: """Test we correctly load with no existing data.""" store = auth_store.AuthStore(hass) + await store.async_load() groups = await store.async_get_groups() assert len(groups) == 3 admin_group = groups[0] @@ -232,7 +237,7 @@ async def test_system_groups_store_id_and_name( Name is stored so that we remain backwards compat with < 0.82. """ store = auth_store.AuthStore(hass) - await store._async_load() + await store.async_load() data = store._data_to_save() assert len(data["users"]) == 0 assert data["groups"] == [ @@ -242,8 +247,8 @@ async def test_system_groups_store_id_and_name( ] -async def test_loading_race_condition(hass: HomeAssistant) -> None: - """Test only one storage load called when concurrent loading occurred .""" +async def test_loading_only_once(hass: HomeAssistant) -> None: + """Test only one storage load is allowed.""" store = auth_store.AuthStore(hass) with patch( "homeassistant.helpers.entity_registry.async_get" @@ -252,6 +257,10 @@ async def test_loading_race_condition(hass: HomeAssistant) -> None: ) as mock_dev_registry, patch( "homeassistant.helpers.storage.Store.async_load", return_value=None ) as mock_load: + await store.async_load() + with pytest.raises(RuntimeError, match="Auth storage is already loaded"): + await store.async_load() + results = await asyncio.gather(store.async_get_users(), store.async_get_users()) mock_ent_registry.assert_called_once_with(hass) diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index 9e9b48a07f6..53c4c680700 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -343,6 +343,7 @@ async def test_saving_loading( await flush_store(manager._store._store) store2 = auth_store.AuthStore(hass) + await store2.async_load() users = await store2.async_get_users() assert len(users) == 1 assert users[0].permissions == user.permissions