From 2ad0bd40362841646e097885098f0864c91220b4 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 17 Aug 2018 20:18:21 +0200 Subject: [PATCH] Split out storage delay save (#16017) * Split out storage delayed write * Update code using delayed save * Fix tests * Fix typing test * Add callback decorator --- homeassistant/auth/auth_store.py | 49 ++++++++++++++++++-------------- homeassistant/config_entries.py | 14 +++++---- homeassistant/helpers/storage.py | 32 ++++++++++++++++----- tests/helpers/test_storage.py | 8 +++--- 4 files changed, 65 insertions(+), 38 deletions(-) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 07ab40ceaea..5b26cf2f5f8 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -5,7 +5,7 @@ from logging import getLogger from typing import Any, Dict, List, Optional # noqa: F401 import hmac -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.util import dt as dt_util from . import models @@ -32,7 +32,7 @@ class AuthStore: async def async_get_users(self) -> List[models.User]: """Retrieve all users.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None return list(self._users.values()) @@ -40,7 +40,7 @@ class AuthStore: async def async_get_user(self, user_id: str) -> Optional[models.User]: """Retrieve a user by id.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None return self._users.get(user_id) @@ -52,7 +52,7 @@ class AuthStore: credentials: Optional[models.Credentials] = None) -> models.User: """Create a new user.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None kwargs = { @@ -73,7 +73,7 @@ class AuthStore: self._users[new_user.id] = new_user if credentials is None: - await self.async_save() + self._async_schedule_save() return new_user # Saving is done inside the link. @@ -84,33 +84,33 @@ class AuthStore: credentials: models.Credentials) -> None: """Add credentials to an existing user.""" user.credentials.append(credentials) - await self.async_save() + self._async_schedule_save() credentials.is_new = False async def async_remove_user(self, user: models.User) -> None: """Remove a user.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None self._users.pop(user.id) - await self.async_save() + self._async_schedule_save() async def async_activate_user(self, user: models.User) -> None: """Activate a user.""" user.is_active = True - await self.async_save() + self._async_schedule_save() async def async_deactivate_user(self, user: models.User) -> None: """Activate a user.""" user.is_active = False - await self.async_save() + self._async_schedule_save() async def async_remove_credentials( self, credentials: models.Credentials) -> None: """Remove credentials.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None for user in self._users.values(): @@ -125,7 +125,7 @@ class AuthStore: user.credentials.pop(found) break - await self.async_save() + self._async_schedule_save() async def async_create_refresh_token( self, user: models.User, client_id: Optional[str] = None) \ @@ -133,14 +133,14 @@ class AuthStore: """Create a new token for a user.""" refresh_token = models.RefreshToken(user=user, client_id=client_id) user.refresh_tokens[refresh_token.id] = refresh_token - await self.async_save() + self._async_schedule_save() return refresh_token async def async_get_refresh_token( self, token_id: str) -> Optional[models.RefreshToken]: """Get refresh token by id.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None for user in self._users.values(): @@ -154,7 +154,7 @@ class AuthStore: self, token: str) -> Optional[models.RefreshToken]: """Get refresh token by token.""" if self._users is None: - await self.async_load() + await self._async_load() assert self._users is not None found = None @@ -166,7 +166,7 @@ class AuthStore: return found - async def async_load(self) -> None: + async def _async_load(self) -> None: """Load the users.""" data = await self._store.async_load() @@ -218,11 +218,18 @@ class AuthStore: self._users = users - async def async_save(self) -> None: + @callback + def _async_schedule_save(self) -> None: """Save users.""" if self._users is None: - await self.async_load() - assert self._users is not None + return + + self._store.async_delay_save(self._data_to_save, 1) + + @callback + def _data_to_save(self) -> Dict: + """Return the data to store.""" + assert self._users is not None users = [ { @@ -262,10 +269,8 @@ class AuthStore: for refresh_token in user.refresh_tokens.values() ] - data = { + return { 'users': users, 'credentials': credentials, 'refresh_tokens': refresh_tokens, } - - await self._store.async_save(data, delay=1) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index b2e8389e449..e9c5bc07e57 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -320,7 +320,7 @@ class ConfigEntries: raise UnknownEntry entry = self._entries.pop(found) - await self._async_schedule_save() + self._async_schedule_save() unloaded = await entry.async_unload(self.hass) @@ -391,7 +391,7 @@ class ConfigEntries: source=context['source'], ) self._entries.append(entry) - await self._async_schedule_save() + self._async_schedule_save() # Setup entry if entry.domain in self.hass.config.components: @@ -439,12 +439,16 @@ class ConfigEntries: flow.init_step = source return flow - async def _async_schedule_save(self): + def _async_schedule_save(self): """Save the entity registry to a file.""" - data = { + self._store.async_delay_save(self._data_to_save, SAVE_DELAY) + + @callback + def _data_to_save(self): + """Return data to save.""" + return { 'entries': [entry.as_dict() for entry in self._entries] } - await self._store.async_save(data, delay=SAVE_DELAY) async def _old_conf_migrator(old_config): diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index a68b489868d..47d182d9a7c 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -2,7 +2,7 @@ import asyncio import logging import os -from typing import Dict, Optional +from typing import Dict, Optional, Callable from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback @@ -76,8 +76,13 @@ class Store: async def _async_load(self): """Helper to load the data.""" + # Check if we have a pending write if self._data is not None: data = self._data + + # If we didn't generate data yet, do it now. + if 'data_func' in data: + data['data'] = data.pop('data_func')() else: data = await self.hass.async_add_executor_job( json.load_json, self.path) @@ -95,8 +100,8 @@ class Store: self._load_task = None return stored - async def async_save(self, data: Dict, *, delay: Optional[int] = None): - """Save data with an optional delay.""" + async def async_save(self, data): + """Save data.""" self._data = { 'version': self.version, 'key': self.key, @@ -104,11 +109,20 @@ class Store: } self._async_cleanup_delay_listener() + self._async_cleanup_stop_listener() + await self._async_handle_write_data() - if delay is None: - self._async_cleanup_stop_listener() - await self._async_handle_write_data() - return + @callback + def async_delay_save(self, data_func: Callable[[], Dict], + delay: Optional[int] = None): + """Save data with an optional delay.""" + self._data = { + 'version': self.version, + 'key': self.key, + 'data_func': data_func, + } + + self._async_cleanup_delay_listener() self._unsub_delay_listener = async_call_later( self.hass, delay, self._async_callback_delayed_write) @@ -151,6 +165,10 @@ class Store: async def _async_handle_write_data(self, *_args): """Handler to handle writing the config.""" data = self._data + + if 'data_func' in data: + data['data'] = data.pop('data_func')() + self._data = None async with self._write_lock: diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index f414eaec97c..b35b2596802 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -56,7 +56,7 @@ async def test_loading_parallel(hass, store, hass_storage, caplog): async def test_saving_with_delay(hass, store, hass_storage): """Test saving data after a delay.""" - await store.async_save(MOCK_DATA, delay=1) + store.async_delay_save(lambda: MOCK_DATA, 1) assert store.key not in hass_storage async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) @@ -71,7 +71,7 @@ async def test_saving_with_delay(hass, store, hass_storage): async def test_saving_on_stop(hass, hass_storage): """Test delayed saves trigger when we quit Home Assistant.""" store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) - await store.async_save(MOCK_DATA, delay=1) + store.async_delay_save(lambda: MOCK_DATA, 1) assert store.key not in hass_storage hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) @@ -92,7 +92,7 @@ async def test_loading_while_delay(hass, store, hass_storage): 'data': {'delay': 'no'}, } - await store.async_save({'delay': 'yes'}, delay=1) + store.async_delay_save(lambda: {'delay': 'yes'}, 1) assert hass_storage[store.key] == { 'version': MOCK_VERSION, 'key': MOCK_KEY, @@ -105,7 +105,7 @@ async def test_loading_while_delay(hass, store, hass_storage): async def test_writing_while_writing_delay(hass, store, hass_storage): """Test a write while a write with delay is active.""" - await store.async_save({'delay': 'yes'}, delay=1) + store.async_delay_save(lambda: {'delay': 'yes'}, 1) assert store.key not in hass_storage await store.async_save({'delay': 'no'}) assert hass_storage[store.key] == {