Storage auth (#15192)

* Support parallel loading

* Add storage mock

* Store auth

* Fix tests
This commit is contained in:
Paulus Schoutsen 2018-06-28 22:14:26 -04:00 committed by GitHub
parent a277470363
commit 2205090795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 324 additions and 92 deletions

View File

@ -21,6 +21,8 @@ from homeassistant.util import dt as dt_util
_LOGGER = logging.getLogger(__name__)
STORAGE_VERSION = 1
STORAGE_KEY = 'auth'
AUTH_PROVIDERS = Registry()
@ -121,23 +123,12 @@ class User:
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default=None)
# For persisting and see if saved?
# store = attr.ib(type=AuthStore, default=None)
# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list))
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
# Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict))
def as_dict(self):
"""Convert user object to a dictionary."""
return {
'id': self.id,
'is_owner': self.is_owner,
'is_active': self.is_active,
'name': self.name,
}
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
@attr.s(slots=True)
@ -152,7 +143,7 @@ class RefreshToken:
default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list))
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
@attr.s(slots=True)
@ -376,7 +367,7 @@ class AuthStore:
self.hass = hass
self.users = None
self.clients = None
self._load_lock = asyncio.Lock(loop=hass.loop)
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id."""
@ -494,10 +485,128 @@ class AuthStore:
async def async_load(self):
"""Load the users."""
async with self._load_lock:
data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self.users is not None:
return
if data is None:
self.users = {}
self.clients = {}
return
users = {
user_dict['id']: User(**user_dict) for user_dict in data['users']
}
for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(Credentials(
id=cred_dict['id'],
is_new=False,
auth_provider_type=cred_dict['auth_provider_type'],
auth_provider_id=cred_dict['auth_provider_id'],
data=cred_dict['data'],
))
refresh_tokens = {}
for rt_dict in data['refresh_tokens']:
token = RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
client_id=rt_dict['client_id'],
created_at=dt_util.parse_datetime(rt_dict['created_at']),
access_token_expiration=timedelta(
rt_dict['access_token_expiration']),
token=rt_dict['token'],
)
refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token
for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
clients = {
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
}
self.users = users
self.clients = clients
async def async_save(self):
"""Save users."""
pass
users = [
{
'id': user.id,
'is_owner': user.is_owner,
'is_active': user.is_active,
'name': user.name,
}
for user in self.users.values()
]
credentials = [
{
'id': credential.id,
'user_id': user.id,
'auth_provider_type': credential.auth_provider_type,
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self.users.values()
for credential in user.credentials
]
refresh_tokens = [
{
'id': refresh_token.id,
'user_id': user.id,
'client_id': refresh_token.client_id,
'created_at': refresh_token.created_at.isoformat(),
'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self.users.values()
for refresh_token in user.refresh_tokens.values()
]
access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self.users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
clients = [
{
'id': client.id,
'name': client.name,
'secret': client.secret,
'redirect_uris': client.redirect_uris,
}
for client in self.clients.values()
]
data = {
'users': users,
'clients': clients,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,
}
await self._store.async_save(data, delay=1)

View File

@ -53,6 +53,7 @@ class Store:
self._unsub_delay_listener = None
self._unsub_stop_listener = None
self._write_lock = asyncio.Lock()
self._load_task = None
@property
def path(self):
@ -64,7 +65,17 @@ class Store:
If the expected version does not match the given version, the migrate
function will be invoked with await migrate_func(version, config).
Will ensure that when a call comes in while another one is in progress,
the second call will wait and return the result of the first call.
"""
if self._load_task is None:
self._load_task = self.hass.async_add_job(self._async_load())
return await self._load_task
async def _async_load(self):
"""Helper to load the data."""
if self._data is not None:
data = self._data
else:
@ -75,9 +86,15 @@ class Store:
return None
if data['version'] == self.version:
return data['data']
stored = data['data']
else:
_LOGGER.info('Migrating %s storage from %s to %s',
self.key, data['version'], self.version)
stored = await self._async_migrate_func(
data['version'], data['data'])
return await self._async_migrate_func(data['version'], data['data'])
self._load_task = None
return stored
async def async_save(self, data: Dict, *, delay: Optional[int] = None):
"""Save data with an optional delay."""

View File

@ -11,15 +11,15 @@ from tests.common import mock_coro
@pytest.fixture
def store():
def store(hass):
"""Mock store."""
return auth.AuthStore(Mock())
return auth.AuthStore(hass)
@pytest.fixture
def provider(store):
def provider(hass, store):
"""Mock provider."""
return insecure_example.ExampleAuthProvider(None, store, {
return insecure_example.ExampleAuthProvider(hass, store, {
'type': 'insecure_example',
'users': [
{

View File

@ -2,6 +2,7 @@
import asyncio
from datetime import timedelta
import functools as ft
import json
import os
import sys
from unittest.mock import patch, MagicMock, Mock
@ -15,7 +16,7 @@ from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
from homeassistant.helpers import (
intent, entity, restore_state, entity_registry,
entity_platform)
entity_platform, storage)
from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.dt as date_util
import homeassistant.util.yaml as yaml
@ -705,3 +706,51 @@ class MockEntity(entity.Entity):
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)
@contextmanager
def mock_storage(data=None):
"""Mock storage.
Data is a dict {'key': {'version': version, 'data': data}}
Written data will be converted to JSON to ensure JSON parsing works.
"""
if data is None:
data = {}
orig_load = storage.Store._async_load
async def mock_async_load(store):
"""Mock version of load."""
if store._data is None:
# No data to load
if store.key not in data:
return None
store._data = data.get(store.key)
# Route through original load so that we trigger migration
loaded = await orig_load(store)
_LOGGER.info('Loading data for %s: %s', store.key, loaded)
return loaded
def mock_write_data(store, path, data_to_write):
"""Mock version of write data."""
# To ensure that the data can be serialized
_LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
data[store.key] = json.loads(json.dumps(data_to_write))
with patch('homeassistant.helpers.storage.Store._async_load',
side_effect=mock_async_load, autospec=True), \
patch('homeassistant.helpers.storage.Store._write_data',
side_effect=mock_write_data, autospec=True):
yield data
async def flush_store(store):
"""Make sure all delayed writes of a store are written."""
if store._data is None:
return
await store._async_handle_write_data()

View File

@ -12,7 +12,8 @@ from homeassistant import util
from homeassistant.util import location
from tests.common import (
async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro)
async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro,
mock_storage as mock_storage)
from tests.test_util.aiohttp import mock_aiohttp_client
from tests.mock.zwave import MockNetwork, MockOption
@ -59,7 +60,14 @@ def verify_cleanup():
@pytest.fixture
def hass(loop):
def hass_storage():
"""Fixture to mock storage."""
with mock_storage() as stored_data:
yield stored_data
@pytest.fixture
def hass(loop, hass_storage):
"""Fixture to provide a test instance of HASS."""
hass = loop.run_until_complete(async_test_home_assistant(loop))

View File

@ -1,4 +1,5 @@
"""Tests for the storage helper."""
import asyncio
from datetime import timedelta
from unittest.mock import patch
@ -16,32 +17,13 @@ MOCK_KEY = 'storage-test'
MOCK_DATA = {'hello': 'world'}
@pytest.fixture
def mock_save():
"""Fixture to mock JSON save."""
written = []
with patch('homeassistant.util.json.save_json',
side_effect=lambda *args: written.append(args)):
yield written
@pytest.fixture
def mock_load(mock_save):
"""Fixture to mock JSON read."""
with patch('homeassistant.util.json.load_json',
side_effect=lambda *args: mock_save[-1][1]):
yield
@pytest.fixture
def store(hass):
"""Fixture of a store that prevents writing on HASS stop."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
store._async_ensure_stop_listener = lambda: None
yield store
yield storage.Store(hass, MOCK_VERSION, MOCK_KEY)
async def test_loading(hass, store, mock_save, mock_load):
async def test_loading(hass, store):
"""Test we can save and load data."""
await store.async_save(MOCK_DATA)
data = await store.async_load()
@ -55,55 +37,96 @@ async def test_loading_non_existing(hass, store):
assert data is None
async def test_saving_with_delay(hass, store, mock_save):
async def test_loading_parallel(hass, store, hass_storage, caplog):
"""Test we can save and load data."""
hass_storage[store.key] = {
'version': MOCK_VERSION,
'data': MOCK_DATA,
}
results = await asyncio.gather(
store.async_load(),
store.async_load()
)
assert results[0] is MOCK_DATA
assert results[1] is MOCK_DATA
assert caplog.text.count('Loading data for {}'.format(store.key))
async def test_saving_with_delay(hass, store, hass_storage):
"""Test saving data after a delay."""
await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0
assert store.key not in hass_storage
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': MOCK_DATA,
}
async def test_saving_on_stop(hass, mock_save):
async def test_saving_on_stop(hass, hass_storage):
"""Test delayed saves trigger when we quit Home Assistant."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0
assert store.key not in hass_storage
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': MOCK_DATA,
}
async def test_loading_while_delay(hass, store, mock_save, mock_load):
async def test_loading_while_delay(hass, store, hass_storage):
"""Test we load new data even if not written yet."""
await store.async_save({'delay': 'no'})
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
data = await store.async_load()
assert data == {'delay': 'yes'}
async def test_writing_while_writing_delay(hass, store, mock_save, mock_load):
async def test_writing_while_writing_delay(hass, store, hass_storage):
"""Test a write while a write with delay is active."""
await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 0
assert store.key not in hass_storage
await store.async_save({'delay': 'no'})
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
assert len(mock_save) == 1
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
data = await store.async_load()
assert data == {'delay': 'no'}
async def test_migrator_no_existing_config(hass, store, mock_save):
async def test_migrator_no_existing_config(hass, store, hass_storage):
"""Test migrator with no existing config."""
with patch('os.path.isfile', return_value=False), \
patch.object(store, 'async_load',
@ -112,10 +135,10 @@ async def test_migrator_no_existing_config(hass, store, mock_save):
hass, 'old-path', store)
assert data == {'cur': 'config'}
assert len(mock_save) == 0
assert store.key not in hass_storage
async def test_migrator_existing_config(hass, store, mock_save):
async def test_migrator_existing_config(hass, store, hass_storage):
"""Test migrating existing config."""
with patch('os.path.isfile', return_value=True), \
patch('os.remove') as mock_remove, \
@ -126,15 +149,14 @@ async def test_migrator_existing_config(hass, store, mock_save):
assert len(mock_remove.mock_calls) == 1
assert data == {'old': 'config'}
assert len(mock_save) == 1
assert mock_save[0][1] == {
assert hass_storage[store.key] == {
'key': MOCK_KEY,
'version': MOCK_VERSION,
'data': data,
}
async def test_migrator_transforming_config(hass, store, mock_save):
async def test_migrator_transforming_config(hass, store, hass_storage):
"""Test migrating config to new format."""
async def old_conf_migrate_func(old_config):
"""Migrate old config to new format."""
@ -150,8 +172,7 @@ async def test_migrator_transforming_config(hass, store, mock_save):
assert len(mock_remove.mock_calls) == 1
assert data == {'new': 'config'}
assert len(mock_save) == 1
assert mock_save[0][1] == {
assert hass_storage[store.key] == {
'key': MOCK_KEY,
'version': MOCK_VERSION,
'data': data,

View File

@ -4,7 +4,7 @@ from unittest.mock import Mock
import pytest
from homeassistant import auth, data_entry_flow
from tests.common import MockUser, ensure_auth_manager_loaded
from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
@pytest.fixture
@ -53,9 +53,9 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
}]
async def test_create_new_user(mock_hass):
async def test_create_new_user(hass, hass_storage):
"""Test creating new user."""
manager = await auth.auth_manager_from_config(mock_hass, [{
manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
@ -124,9 +124,9 @@ async def test_login_as_existing_user(mock_hass):
assert user.name == 'Paulus'
async def test_linking_user_to_two_auth_providers(mock_hass):
async def test_linking_user_to_two_auth_providers(hass, hass_storage):
"""Test linking user to two auth providers."""
manager = await auth.auth_manager_from_config(mock_hass, [{
manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
@ -157,3 +157,41 @@ async def test_linking_user_to_two_auth_providers(mock_hass):
})
await manager.async_link_user(user, step['result'])
assert len(user.credentials) == 2
async def test_saving_loading(hass, hass_storage):
"""Test storing and saving data.
Creates one of each type that we store to test we restore correctly.
"""
manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
}]
}])
step = await manager.login_flow.async_init(('insecure_example', None))
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'test-user',
'password': 'test-pass',
})
user = await manager.async_get_or_create_user(step['result'])
client = await manager.async_create_client(
'test', redirect_uris=['https://example.com'])
refresh_token = await manager.async_create_refresh_token(user, client.id)
manager.async_create_access_token(refresh_token)
await flush_store(manager._store._store)
store2 = auth.AuthStore(hass)
await store2.async_load()
assert len(store2.users) == 1
assert store2.users[user.id] == user
assert len(store2.clients) == 1
assert store2.clients[client.id] == client

View File

@ -1,7 +1,7 @@
"""Test the config manager."""
import asyncio
from datetime import timedelta
from unittest.mock import MagicMock, patch, mock_open
from unittest.mock import MagicMock, patch
import pytest
@ -152,8 +152,7 @@ def test_domains_gets_uniques(manager):
assert manager.async_domains() == ['test', 'test2', 'test3']
@asyncio.coroutine
def test_saving_and_loading(hass):
async def test_saving_and_loading(hass):
"""Test that we're saving and loading correctly."""
loader.set_component(
hass, 'test',
@ -172,7 +171,7 @@ def test_saving_and_loading(hass):
)
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
yield from hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init('test')
class Test2Flow(data_entry_flow.FlowHandler):
VERSION = 3
@ -186,27 +185,18 @@ def test_saving_and_loading(hass):
}
)
json_path = 'homeassistant.util.json.open'
with patch('homeassistant.config_entries.HANDLERS.get',
return_value=Test2Flow):
yield from hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init('test')
with patch(json_path, mock_open(), create=True) as mock_write:
# To trigger the call_later
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
# To execute the save
yield from hass.async_block_till_done()
# Mock open calls are: open file, context enter, write, context leave
written = mock_write.mock_calls[2][1][0]
# To trigger the call_later
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
# To execute the save
await hass.async_block_till_done()
# Now load written data in new config manager
manager = config_entries.ConfigEntries(hass, {})
with patch('os.path.isfile', return_value=False), \
patch(json_path, mock_open(read_data=written), create=True):
yield from manager.async_load()
await manager.async_load()
# Ensure same order
for orig, loaded in zip(hass.config_entries.async_entries(),