diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 5e434b74ca8..0c8346607ca 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -21,6 +21,8 @@ from homeassistant.util import dt as dt_util _LOGGER = logging.getLogger(__name__) +STORAGE_VERSION = 1 +STORAGE_KEY = 'auth' AUTH_PROVIDERS = Registry() @@ -121,23 +123,12 @@ class User: is_owner = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False) name = attr.ib(type=str, default=None) - # For persisting and see if saved? - # store = attr.ib(type=AuthStore, default=None) # List of credentials of a user. - credentials = attr.ib(type=list, default=attr.Factory(list)) + credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False) # Tokens associated with a user. - refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict)) - - def as_dict(self): - """Convert user object to a dictionary.""" - return { - 'id': self.id, - 'is_owner': self.is_owner, - 'is_active': self.is_active, - 'name': self.name, - } + refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False) @attr.s(slots=True) @@ -152,7 +143,7 @@ class RefreshToken: default=ACCESS_TOKEN_EXPIRATION) token = attr.ib(type=str, default=attr.Factory(lambda: generate_secret(64))) - access_tokens = attr.ib(type=list, default=attr.Factory(list)) + access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False) @attr.s(slots=True) @@ -376,7 +367,7 @@ class AuthStore: self.hass = hass self.users = None self.clients = None - self._load_lock = asyncio.Lock(loop=hass.loop) + self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) async def credentials_for_provider(self, provider_type, provider_id): """Return credentials for specific auth provider type and id.""" @@ -494,10 +485,128 @@ class AuthStore: async def async_load(self): """Load the users.""" - async with self._load_lock: + data = await self._store.async_load() + + # Make sure that we're not overriding data if 2 loads happened at the + # same time + if self.users is not None: + return + + if data is None: self.users = {} self.clients = {} + return + + users = { + user_dict['id']: User(**user_dict) for user_dict in data['users'] + } + + for cred_dict in data['credentials']: + users[cred_dict['user_id']].credentials.append(Credentials( + id=cred_dict['id'], + is_new=False, + auth_provider_type=cred_dict['auth_provider_type'], + auth_provider_id=cred_dict['auth_provider_id'], + data=cred_dict['data'], + )) + + refresh_tokens = {} + + for rt_dict in data['refresh_tokens']: + token = RefreshToken( + id=rt_dict['id'], + user=users[rt_dict['user_id']], + client_id=rt_dict['client_id'], + created_at=dt_util.parse_datetime(rt_dict['created_at']), + access_token_expiration=timedelta( + rt_dict['access_token_expiration']), + token=rt_dict['token'], + ) + refresh_tokens[token.id] = token + users[rt_dict['user_id']].refresh_tokens[token.token] = token + + for ac_dict in data['access_tokens']: + refresh_token = refresh_tokens[ac_dict['refresh_token_id']] + token = AccessToken( + refresh_token=refresh_token, + created_at=dt_util.parse_datetime(ac_dict['created_at']), + token=ac_dict['token'], + ) + refresh_token.access_tokens.append(token) + + clients = { + cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients'] + } + + self.users = users + self.clients = clients async def async_save(self): """Save users.""" - pass + users = [ + { + 'id': user.id, + 'is_owner': user.is_owner, + 'is_active': user.is_active, + 'name': user.name, + } + for user in self.users.values() + ] + + credentials = [ + { + 'id': credential.id, + 'user_id': user.id, + 'auth_provider_type': credential.auth_provider_type, + 'auth_provider_id': credential.auth_provider_id, + 'data': credential.data, + } + for user in self.users.values() + for credential in user.credentials + ] + + refresh_tokens = [ + { + 'id': refresh_token.id, + 'user_id': user.id, + 'client_id': refresh_token.client_id, + 'created_at': refresh_token.created_at.isoformat(), + 'access_token_expiration': + refresh_token.access_token_expiration.total_seconds(), + 'token': refresh_token.token, + } + for user in self.users.values() + for refresh_token in user.refresh_tokens.values() + ] + + access_tokens = [ + { + 'id': user.id, + 'refresh_token_id': refresh_token.id, + 'created_at': access_token.created_at.isoformat(), + 'token': access_token.token, + } + for user in self.users.values() + for refresh_token in user.refresh_tokens.values() + for access_token in refresh_token.access_tokens + ] + + clients = [ + { + 'id': client.id, + 'name': client.name, + 'secret': client.secret, + 'redirect_uris': client.redirect_uris, + } + for client in self.clients.values() + ] + + data = { + 'users': users, + 'clients': clients, + 'credentials': credentials, + 'access_tokens': access_tokens, + 'refresh_tokens': refresh_tokens, + } + + await self._store.async_save(data, delay=1) diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 18c3ddf7fcd..962074ec3af 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -53,6 +53,7 @@ class Store: self._unsub_delay_listener = None self._unsub_stop_listener = None self._write_lock = asyncio.Lock() + self._load_task = None @property def path(self): @@ -64,7 +65,17 @@ class Store: If the expected version does not match the given version, the migrate function will be invoked with await migrate_func(version, config). + + Will ensure that when a call comes in while another one is in progress, + the second call will wait and return the result of the first call. """ + if self._load_task is None: + self._load_task = self.hass.async_add_job(self._async_load()) + + return await self._load_task + + async def _async_load(self): + """Helper to load the data.""" if self._data is not None: data = self._data else: @@ -75,9 +86,15 @@ class Store: return None if data['version'] == self.version: - return data['data'] + stored = data['data'] + else: + _LOGGER.info('Migrating %s storage from %s to %s', + self.key, data['version'], self.version) + stored = await self._async_migrate_func( + data['version'], data['data']) - return await self._async_migrate_func(data['version'], data['data']) + self._load_task = None + return stored async def async_save(self, data: Dict, *, delay: Optional[int] = None): """Save data with an optional delay.""" diff --git a/tests/auth_providers/test_insecure_example.py b/tests/auth_providers/test_insecure_example.py index 0b481f93099..3377a60c45b 100644 --- a/tests/auth_providers/test_insecure_example.py +++ b/tests/auth_providers/test_insecure_example.py @@ -11,15 +11,15 @@ from tests.common import mock_coro @pytest.fixture -def store(): +def store(hass): """Mock store.""" - return auth.AuthStore(Mock()) + return auth.AuthStore(hass) @pytest.fixture -def provider(store): +def provider(hass, store): """Mock provider.""" - return insecure_example.ExampleAuthProvider(None, store, { + return insecure_example.ExampleAuthProvider(hass, store, { 'type': 'insecure_example', 'users': [ { diff --git a/tests/common.py b/tests/common.py index 56575bdb1e9..8eaee686b22 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,6 +2,7 @@ import asyncio from datetime import timedelta import functools as ft +import json import os import sys from unittest.mock import patch, MagicMock, Mock @@ -15,7 +16,7 @@ from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers import ( intent, entity, restore_state, entity_registry, - entity_platform) + entity_platform, storage) from homeassistant.util.unit_system import METRIC_SYSTEM import homeassistant.util.dt as date_util import homeassistant.util.yaml as yaml @@ -705,3 +706,51 @@ class MockEntity(entity.Entity): if attr in self._values: return self._values[attr] return getattr(super(), attr) + + +@contextmanager +def mock_storage(data=None): + """Mock storage. + + Data is a dict {'key': {'version': version, 'data': data}} + + Written data will be converted to JSON to ensure JSON parsing works. + """ + if data is None: + data = {} + + orig_load = storage.Store._async_load + + async def mock_async_load(store): + """Mock version of load.""" + if store._data is None: + # No data to load + if store.key not in data: + return None + + store._data = data.get(store.key) + + # Route through original load so that we trigger migration + loaded = await orig_load(store) + _LOGGER.info('Loading data for %s: %s', store.key, loaded) + return loaded + + def mock_write_data(store, path, data_to_write): + """Mock version of write data.""" + # To ensure that the data can be serialized + _LOGGER.info('Writing data to %s: %s', store.key, data_to_write) + data[store.key] = json.loads(json.dumps(data_to_write)) + + with patch('homeassistant.helpers.storage.Store._async_load', + side_effect=mock_async_load, autospec=True), \ + patch('homeassistant.helpers.storage.Store._write_data', + side_effect=mock_write_data, autospec=True): + yield data + + +async def flush_store(store): + """Make sure all delayed writes of a store are written.""" + if store._data is None: + return + + await store._async_handle_write_data() diff --git a/tests/conftest.py b/tests/conftest.py index 4d619c5ef61..0a350b62fc1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,8 @@ from homeassistant import util from homeassistant.util import location from tests.common import ( - async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro) + async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro, + mock_storage as mock_storage) from tests.test_util.aiohttp import mock_aiohttp_client from tests.mock.zwave import MockNetwork, MockOption @@ -59,7 +60,14 @@ def verify_cleanup(): @pytest.fixture -def hass(loop): +def hass_storage(): + """Fixture to mock storage.""" + with mock_storage() as stored_data: + yield stored_data + + +@pytest.fixture +def hass(loop, hass_storage): """Fixture to provide a test instance of HASS.""" hass = loop.run_until_complete(async_test_home_assistant(loop)) diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index 04de920b036..f414eaec97c 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -1,4 +1,5 @@ """Tests for the storage helper.""" +import asyncio from datetime import timedelta from unittest.mock import patch @@ -16,32 +17,13 @@ MOCK_KEY = 'storage-test' MOCK_DATA = {'hello': 'world'} -@pytest.fixture -def mock_save(): - """Fixture to mock JSON save.""" - written = [] - with patch('homeassistant.util.json.save_json', - side_effect=lambda *args: written.append(args)): - yield written - - -@pytest.fixture -def mock_load(mock_save): - """Fixture to mock JSON read.""" - with patch('homeassistant.util.json.load_json', - side_effect=lambda *args: mock_save[-1][1]): - yield - - @pytest.fixture def store(hass): """Fixture of a store that prevents writing on HASS stop.""" - store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) - store._async_ensure_stop_listener = lambda: None - yield store + yield storage.Store(hass, MOCK_VERSION, MOCK_KEY) -async def test_loading(hass, store, mock_save, mock_load): +async def test_loading(hass, store): """Test we can save and load data.""" await store.async_save(MOCK_DATA) data = await store.async_load() @@ -55,55 +37,96 @@ async def test_loading_non_existing(hass, store): assert data is None -async def test_saving_with_delay(hass, store, mock_save): +async def test_loading_parallel(hass, store, hass_storage, caplog): + """Test we can save and load data.""" + hass_storage[store.key] = { + 'version': MOCK_VERSION, + 'data': MOCK_DATA, + } + + results = await asyncio.gather( + store.async_load(), + store.async_load() + ) + + assert results[0] is MOCK_DATA + assert results[1] is MOCK_DATA + assert caplog.text.count('Loading data for {}'.format(store.key)) + + +async def test_saving_with_delay(hass, store, hass_storage): """Test saving data after a delay.""" await store.async_save(MOCK_DATA, delay=1) - assert len(mock_save) == 0 + assert store.key not in hass_storage async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) await hass.async_block_till_done() - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': MOCK_DATA, + } -async def test_saving_on_stop(hass, mock_save): +async def test_saving_on_stop(hass, hass_storage): """Test delayed saves trigger when we quit Home Assistant.""" store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) await store.async_save(MOCK_DATA, delay=1) - assert len(mock_save) == 0 + assert store.key not in hass_storage hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': MOCK_DATA, + } -async def test_loading_while_delay(hass, store, mock_save, mock_load): +async def test_loading_while_delay(hass, store, hass_storage): """Test we load new data even if not written yet.""" await store.async_save({'delay': 'no'}) - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': {'delay': 'no'}, + } await store.async_save({'delay': 'yes'}, delay=1) - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': {'delay': 'no'}, + } data = await store.async_load() assert data == {'delay': 'yes'} -async def test_writing_while_writing_delay(hass, store, mock_save, mock_load): +async def test_writing_while_writing_delay(hass, store, hass_storage): """Test a write while a write with delay is active.""" await store.async_save({'delay': 'yes'}, delay=1) - assert len(mock_save) == 0 + assert store.key not in hass_storage await store.async_save({'delay': 'no'}) - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': {'delay': 'no'}, + } async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) await hass.async_block_till_done() - assert len(mock_save) == 1 + assert hass_storage[store.key] == { + 'version': MOCK_VERSION, + 'key': MOCK_KEY, + 'data': {'delay': 'no'}, + } data = await store.async_load() assert data == {'delay': 'no'} -async def test_migrator_no_existing_config(hass, store, mock_save): +async def test_migrator_no_existing_config(hass, store, hass_storage): """Test migrator with no existing config.""" with patch('os.path.isfile', return_value=False), \ patch.object(store, 'async_load', @@ -112,10 +135,10 @@ async def test_migrator_no_existing_config(hass, store, mock_save): hass, 'old-path', store) assert data == {'cur': 'config'} - assert len(mock_save) == 0 + assert store.key not in hass_storage -async def test_migrator_existing_config(hass, store, mock_save): +async def test_migrator_existing_config(hass, store, hass_storage): """Test migrating existing config.""" with patch('os.path.isfile', return_value=True), \ patch('os.remove') as mock_remove, \ @@ -126,15 +149,14 @@ async def test_migrator_existing_config(hass, store, mock_save): assert len(mock_remove.mock_calls) == 1 assert data == {'old': 'config'} - assert len(mock_save) == 1 - assert mock_save[0][1] == { + assert hass_storage[store.key] == { 'key': MOCK_KEY, 'version': MOCK_VERSION, 'data': data, } -async def test_migrator_transforming_config(hass, store, mock_save): +async def test_migrator_transforming_config(hass, store, hass_storage): """Test migrating config to new format.""" async def old_conf_migrate_func(old_config): """Migrate old config to new format.""" @@ -150,8 +172,7 @@ async def test_migrator_transforming_config(hass, store, mock_save): assert len(mock_remove.mock_calls) == 1 assert data == {'new': 'config'} - assert len(mock_save) == 1 - assert mock_save[0][1] == { + assert hass_storage[store.key] == { 'key': MOCK_KEY, 'version': MOCK_VERSION, 'data': data, diff --git a/tests/test_auth.py b/tests/test_auth.py index 4bbf218fd23..116f92ca817 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -4,7 +4,7 @@ from unittest.mock import Mock import pytest from homeassistant import auth, data_entry_flow -from tests.common import MockUser, ensure_auth_manager_loaded +from tests.common import MockUser, ensure_auth_manager_loaded, flush_store @pytest.fixture @@ -53,9 +53,9 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass): }] -async def test_create_new_user(mock_hass): +async def test_create_new_user(hass, hass_storage): """Test creating new user.""" - manager = await auth.auth_manager_from_config(mock_hass, [{ + manager = await auth.auth_manager_from_config(hass, [{ 'type': 'insecure_example', 'users': [{ 'username': 'test-user', @@ -124,9 +124,9 @@ async def test_login_as_existing_user(mock_hass): assert user.name == 'Paulus' -async def test_linking_user_to_two_auth_providers(mock_hass): +async def test_linking_user_to_two_auth_providers(hass, hass_storage): """Test linking user to two auth providers.""" - manager = await auth.auth_manager_from_config(mock_hass, [{ + manager = await auth.auth_manager_from_config(hass, [{ 'type': 'insecure_example', 'users': [{ 'username': 'test-user', @@ -157,3 +157,41 @@ async def test_linking_user_to_two_auth_providers(mock_hass): }) await manager.async_link_user(user, step['result']) assert len(user.credentials) == 2 + + +async def test_saving_loading(hass, hass_storage): + """Test storing and saving data. + + Creates one of each type that we store to test we restore correctly. + """ + manager = await auth.auth_manager_from_config(hass, [{ + 'type': 'insecure_example', + 'users': [{ + 'username': 'test-user', + 'password': 'test-pass', + }] + }]) + + step = await manager.login_flow.async_init(('insecure_example', None)) + step = await manager.login_flow.async_configure(step['flow_id'], { + 'username': 'test-user', + 'password': 'test-pass', + }) + user = await manager.async_get_or_create_user(step['result']) + + client = await manager.async_create_client( + 'test', redirect_uris=['https://example.com']) + + refresh_token = await manager.async_create_refresh_token(user, client.id) + + manager.async_create_access_token(refresh_token) + + await flush_store(manager._store._store) + + store2 = auth.AuthStore(hass) + await store2.async_load() + assert len(store2.users) == 1 + assert store2.users[user.id] == user + + assert len(store2.clients) == 1 + assert store2.clients[client.id] == client diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index b65e0dd62e7..d7a7ec4b82b 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1,7 +1,7 @@ """Test the config manager.""" import asyncio from datetime import timedelta -from unittest.mock import MagicMock, patch, mock_open +from unittest.mock import MagicMock, patch import pytest @@ -152,8 +152,7 @@ def test_domains_gets_uniques(manager): assert manager.async_domains() == ['test', 'test2', 'test3'] -@asyncio.coroutine -def test_saving_and_loading(hass): +async def test_saving_and_loading(hass): """Test that we're saving and loading correctly.""" loader.set_component( hass, 'test', @@ -172,7 +171,7 @@ def test_saving_and_loading(hass): ) with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - yield from hass.config_entries.flow.async_init('test') + await hass.config_entries.flow.async_init('test') class Test2Flow(data_entry_flow.FlowHandler): VERSION = 3 @@ -186,27 +185,18 @@ def test_saving_and_loading(hass): } ) - json_path = 'homeassistant.util.json.open' - with patch('homeassistant.config_entries.HANDLERS.get', return_value=Test2Flow): - yield from hass.config_entries.flow.async_init('test') + await hass.config_entries.flow.async_init('test') - with patch(json_path, mock_open(), create=True) as mock_write: - # To trigger the call_later - async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) - # To execute the save - yield from hass.async_block_till_done() - - # Mock open calls are: open file, context enter, write, context leave - written = mock_write.mock_calls[2][1][0] + # To trigger the call_later + async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) + # To execute the save + await hass.async_block_till_done() # Now load written data in new config manager manager = config_entries.ConfigEntries(hass, {}) - - with patch('os.path.isfile', return_value=False), \ - patch(json_path, mock_open(read_data=written), create=True): - yield from manager.async_load() + await manager.async_load() # Ensure same order for orig, loaded in zip(hass.config_entries.async_entries(),