Storage auth (#15192)

* Support parallel loading

* Add storage mock

* Store auth

* Fix tests
This commit is contained in:
Paulus Schoutsen 2018-06-28 22:14:26 -04:00 committed by GitHub
parent a277470363
commit 2205090795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 324 additions and 92 deletions

View File

@ -21,6 +21,8 @@ from homeassistant.util import dt as dt_util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STORAGE_VERSION = 1
STORAGE_KEY = 'auth'
AUTH_PROVIDERS = Registry() AUTH_PROVIDERS = Registry()
@ -121,23 +123,12 @@ class User:
is_owner = attr.ib(type=bool, default=False) is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
# For persisting and see if saved?
# store = attr.ib(type=AuthStore, default=None)
# List of credentials of a user. # List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list)) credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
# Tokens associated with a user. # Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict)) refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
def as_dict(self):
"""Convert user object to a dictionary."""
return {
'id': self.id,
'is_owner': self.is_owner,
'is_active': self.is_active,
'name': self.name,
}
@attr.s(slots=True) @attr.s(slots=True)
@ -152,7 +143,7 @@ class RefreshToken:
default=ACCESS_TOKEN_EXPIRATION) default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str, token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64))) default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list)) access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
@attr.s(slots=True) @attr.s(slots=True)
@ -376,7 +367,7 @@ class AuthStore:
self.hass = hass self.hass = hass
self.users = None self.users = None
self.clients = None self.clients = None
self._load_lock = asyncio.Lock(loop=hass.loop) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def credentials_for_provider(self, provider_type, provider_id): async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id.""" """Return credentials for specific auth provider type and id."""
@ -494,10 +485,128 @@ class AuthStore:
async def async_load(self): async def async_load(self):
"""Load the users.""" """Load the users."""
async with self._load_lock: data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self.users is not None:
return
if data is None:
self.users = {} self.users = {}
self.clients = {} self.clients = {}
return
users = {
user_dict['id']: User(**user_dict) for user_dict in data['users']
}
for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(Credentials(
id=cred_dict['id'],
is_new=False,
auth_provider_type=cred_dict['auth_provider_type'],
auth_provider_id=cred_dict['auth_provider_id'],
data=cred_dict['data'],
))
refresh_tokens = {}
for rt_dict in data['refresh_tokens']:
token = RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
client_id=rt_dict['client_id'],
created_at=dt_util.parse_datetime(rt_dict['created_at']),
access_token_expiration=timedelta(
rt_dict['access_token_expiration']),
token=rt_dict['token'],
)
refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token
for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
clients = {
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
}
self.users = users
self.clients = clients
async def async_save(self): async def async_save(self):
"""Save users.""" """Save users."""
pass users = [
{
'id': user.id,
'is_owner': user.is_owner,
'is_active': user.is_active,
'name': user.name,
}
for user in self.users.values()
]
credentials = [
{
'id': credential.id,
'user_id': user.id,
'auth_provider_type': credential.auth_provider_type,
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self.users.values()
for credential in user.credentials
]
refresh_tokens = [
{
'id': refresh_token.id,
'user_id': user.id,
'client_id': refresh_token.client_id,
'created_at': refresh_token.created_at.isoformat(),
'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self.users.values()
for refresh_token in user.refresh_tokens.values()
]
access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self.users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
clients = [
{
'id': client.id,
'name': client.name,
'secret': client.secret,
'redirect_uris': client.redirect_uris,
}
for client in self.clients.values()
]
data = {
'users': users,
'clients': clients,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,
}
await self._store.async_save(data, delay=1)

View File

@ -53,6 +53,7 @@ class Store:
self._unsub_delay_listener = None self._unsub_delay_listener = None
self._unsub_stop_listener = None self._unsub_stop_listener = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._load_task = None
@property @property
def path(self): def path(self):
@ -64,7 +65,17 @@ class Store:
If the expected version does not match the given version, the migrate If the expected version does not match the given version, the migrate
function will be invoked with await migrate_func(version, config). function will be invoked with await migrate_func(version, config).
Will ensure that when a call comes in while another one is in progress,
the second call will wait and return the result of the first call.
""" """
if self._load_task is None:
self._load_task = self.hass.async_add_job(self._async_load())
return await self._load_task
async def _async_load(self):
"""Helper to load the data."""
if self._data is not None: if self._data is not None:
data = self._data data = self._data
else: else:
@ -75,9 +86,15 @@ class Store:
return None return None
if data['version'] == self.version: if data['version'] == self.version:
return data['data'] stored = data['data']
else:
_LOGGER.info('Migrating %s storage from %s to %s',
self.key, data['version'], self.version)
stored = await self._async_migrate_func(
data['version'], data['data'])
return await self._async_migrate_func(data['version'], data['data']) self._load_task = None
return stored
async def async_save(self, data: Dict, *, delay: Optional[int] = None): async def async_save(self, data: Dict, *, delay: Optional[int] = None):
"""Save data with an optional delay.""" """Save data with an optional delay."""

View File

@ -11,15 +11,15 @@ from tests.common import mock_coro
@pytest.fixture @pytest.fixture
def store(): def store(hass):
"""Mock store.""" """Mock store."""
return auth.AuthStore(Mock()) return auth.AuthStore(hass)
@pytest.fixture @pytest.fixture
def provider(store): def provider(hass, store):
"""Mock provider.""" """Mock provider."""
return insecure_example.ExampleAuthProvider(None, store, { return insecure_example.ExampleAuthProvider(hass, store, {
'type': 'insecure_example', 'type': 'insecure_example',
'users': [ 'users': [
{ {

View File

@ -2,6 +2,7 @@
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
import functools as ft import functools as ft
import json
import os import os
import sys import sys
from unittest.mock import patch, MagicMock, Mock from unittest.mock import patch, MagicMock, Mock
@ -15,7 +16,7 @@ from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.helpers import ( from homeassistant.helpers import (
intent, entity, restore_state, entity_registry, intent, entity, restore_state, entity_registry,
entity_platform) entity_platform, storage)
from homeassistant.util.unit_system import METRIC_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.dt as date_util import homeassistant.util.dt as date_util
import homeassistant.util.yaml as yaml import homeassistant.util.yaml as yaml
@ -705,3 +706,51 @@ class MockEntity(entity.Entity):
if attr in self._values: if attr in self._values:
return self._values[attr] return self._values[attr]
return getattr(super(), attr) return getattr(super(), attr)
@contextmanager
def mock_storage(data=None):
"""Mock storage.
Data is a dict {'key': {'version': version, 'data': data}}
Written data will be converted to JSON to ensure JSON parsing works.
"""
if data is None:
data = {}
orig_load = storage.Store._async_load
async def mock_async_load(store):
"""Mock version of load."""
if store._data is None:
# No data to load
if store.key not in data:
return None
store._data = data.get(store.key)
# Route through original load so that we trigger migration
loaded = await orig_load(store)
_LOGGER.info('Loading data for %s: %s', store.key, loaded)
return loaded
def mock_write_data(store, path, data_to_write):
"""Mock version of write data."""
# To ensure that the data can be serialized
_LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
data[store.key] = json.loads(json.dumps(data_to_write))
with patch('homeassistant.helpers.storage.Store._async_load',
side_effect=mock_async_load, autospec=True), \
patch('homeassistant.helpers.storage.Store._write_data',
side_effect=mock_write_data, autospec=True):
yield data
async def flush_store(store):
"""Make sure all delayed writes of a store are written."""
if store._data is None:
return
await store._async_handle_write_data()

View File

@ -12,7 +12,8 @@ from homeassistant import util
from homeassistant.util import location from homeassistant.util import location
from tests.common import ( from tests.common import (
async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro) async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro,
mock_storage as mock_storage)
from tests.test_util.aiohttp import mock_aiohttp_client from tests.test_util.aiohttp import mock_aiohttp_client
from tests.mock.zwave import MockNetwork, MockOption from tests.mock.zwave import MockNetwork, MockOption
@ -59,7 +60,14 @@ def verify_cleanup():
@pytest.fixture @pytest.fixture
def hass(loop): def hass_storage():
"""Fixture to mock storage."""
with mock_storage() as stored_data:
yield stored_data
@pytest.fixture
def hass(loop, hass_storage):
"""Fixture to provide a test instance of HASS.""" """Fixture to provide a test instance of HASS."""
hass = loop.run_until_complete(async_test_home_assistant(loop)) hass = loop.run_until_complete(async_test_home_assistant(loop))

View File

@ -1,4 +1,5 @@
"""Tests for the storage helper.""" """Tests for the storage helper."""
import asyncio
from datetime import timedelta from datetime import timedelta
from unittest.mock import patch from unittest.mock import patch
@ -16,32 +17,13 @@ MOCK_KEY = 'storage-test'
MOCK_DATA = {'hello': 'world'} MOCK_DATA = {'hello': 'world'}
@pytest.fixture
def mock_save():
"""Fixture to mock JSON save."""
written = []
with patch('homeassistant.util.json.save_json',
side_effect=lambda *args: written.append(args)):
yield written
@pytest.fixture
def mock_load(mock_save):
"""Fixture to mock JSON read."""
with patch('homeassistant.util.json.load_json',
side_effect=lambda *args: mock_save[-1][1]):
yield
@pytest.fixture @pytest.fixture
def store(hass): def store(hass):
"""Fixture of a store that prevents writing on HASS stop.""" """Fixture of a store that prevents writing on HASS stop."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) yield storage.Store(hass, MOCK_VERSION, MOCK_KEY)
store._async_ensure_stop_listener = lambda: None
yield store
async def test_loading(hass, store, mock_save, mock_load): async def test_loading(hass, store):
"""Test we can save and load data.""" """Test we can save and load data."""
await store.async_save(MOCK_DATA) await store.async_save(MOCK_DATA)
data = await store.async_load() data = await store.async_load()
@ -55,55 +37,96 @@ async def test_loading_non_existing(hass, store):
assert data is None assert data is None
async def test_saving_with_delay(hass, store, mock_save): async def test_loading_parallel(hass, store, hass_storage, caplog):
"""Test we can save and load data."""
hass_storage[store.key] = {
'version': MOCK_VERSION,
'data': MOCK_DATA,
}
results = await asyncio.gather(
store.async_load(),
store.async_load()
)
assert results[0] is MOCK_DATA
assert results[1] is MOCK_DATA
assert caplog.text.count('Loading data for {}'.format(store.key))
async def test_saving_with_delay(hass, store, hass_storage):
"""Test saving data after a delay.""" """Test saving data after a delay."""
await store.async_save(MOCK_DATA, delay=1) await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0 assert store.key not in hass_storage
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': MOCK_DATA,
}
async def test_saving_on_stop(hass, mock_save): async def test_saving_on_stop(hass, hass_storage):
"""Test delayed saves trigger when we quit Home Assistant.""" """Test delayed saves trigger when we quit Home Assistant."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await store.async_save(MOCK_DATA, delay=1) await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0 assert store.key not in hass_storage
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': MOCK_DATA,
}
async def test_loading_while_delay(hass, store, mock_save, mock_load): async def test_loading_while_delay(hass, store, hass_storage):
"""Test we load new data even if not written yet.""" """Test we load new data even if not written yet."""
await store.async_save({'delay': 'no'}) await store.async_save({'delay': 'no'})
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
await store.async_save({'delay': 'yes'}, delay=1) await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
data = await store.async_load() data = await store.async_load()
assert data == {'delay': 'yes'} assert data == {'delay': 'yes'}
async def test_writing_while_writing_delay(hass, store, mock_save, mock_load): async def test_writing_while_writing_delay(hass, store, hass_storage):
"""Test a write while a write with delay is active.""" """Test a write while a write with delay is active."""
await store.async_save({'delay': 'yes'}, delay=1) await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 0 assert store.key not in hass_storage
await store.async_save({'delay': 'no'}) await store.async_save({'delay': 'no'})
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_save) == 1 assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
'data': {'delay': 'no'},
}
data = await store.async_load() data = await store.async_load()
assert data == {'delay': 'no'} assert data == {'delay': 'no'}
async def test_migrator_no_existing_config(hass, store, mock_save): async def test_migrator_no_existing_config(hass, store, hass_storage):
"""Test migrator with no existing config.""" """Test migrator with no existing config."""
with patch('os.path.isfile', return_value=False), \ with patch('os.path.isfile', return_value=False), \
patch.object(store, 'async_load', patch.object(store, 'async_load',
@ -112,10 +135,10 @@ async def test_migrator_no_existing_config(hass, store, mock_save):
hass, 'old-path', store) hass, 'old-path', store)
assert data == {'cur': 'config'} assert data == {'cur': 'config'}
assert len(mock_save) == 0 assert store.key not in hass_storage
async def test_migrator_existing_config(hass, store, mock_save): async def test_migrator_existing_config(hass, store, hass_storage):
"""Test migrating existing config.""" """Test migrating existing config."""
with patch('os.path.isfile', return_value=True), \ with patch('os.path.isfile', return_value=True), \
patch('os.remove') as mock_remove, \ patch('os.remove') as mock_remove, \
@ -126,15 +149,14 @@ async def test_migrator_existing_config(hass, store, mock_save):
assert len(mock_remove.mock_calls) == 1 assert len(mock_remove.mock_calls) == 1
assert data == {'old': 'config'} assert data == {'old': 'config'}
assert len(mock_save) == 1 assert hass_storage[store.key] == {
assert mock_save[0][1] == {
'key': MOCK_KEY, 'key': MOCK_KEY,
'version': MOCK_VERSION, 'version': MOCK_VERSION,
'data': data, 'data': data,
} }
async def test_migrator_transforming_config(hass, store, mock_save): async def test_migrator_transforming_config(hass, store, hass_storage):
"""Test migrating config to new format.""" """Test migrating config to new format."""
async def old_conf_migrate_func(old_config): async def old_conf_migrate_func(old_config):
"""Migrate old config to new format.""" """Migrate old config to new format."""
@ -150,8 +172,7 @@ async def test_migrator_transforming_config(hass, store, mock_save):
assert len(mock_remove.mock_calls) == 1 assert len(mock_remove.mock_calls) == 1
assert data == {'new': 'config'} assert data == {'new': 'config'}
assert len(mock_save) == 1 assert hass_storage[store.key] == {
assert mock_save[0][1] == {
'key': MOCK_KEY, 'key': MOCK_KEY,
'version': MOCK_VERSION, 'version': MOCK_VERSION,
'data': data, 'data': data,

View File

@ -4,7 +4,7 @@ from unittest.mock import Mock
import pytest import pytest
from homeassistant import auth, data_entry_flow from homeassistant import auth, data_entry_flow
from tests.common import MockUser, ensure_auth_manager_loaded from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
@pytest.fixture @pytest.fixture
@ -53,9 +53,9 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
}] }]
async def test_create_new_user(mock_hass): async def test_create_new_user(hass, hass_storage):
"""Test creating new user.""" """Test creating new user."""
manager = await auth.auth_manager_from_config(mock_hass, [{ manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example', 'type': 'insecure_example',
'users': [{ 'users': [{
'username': 'test-user', 'username': 'test-user',
@ -124,9 +124,9 @@ async def test_login_as_existing_user(mock_hass):
assert user.name == 'Paulus' assert user.name == 'Paulus'
async def test_linking_user_to_two_auth_providers(mock_hass): async def test_linking_user_to_two_auth_providers(hass, hass_storage):
"""Test linking user to two auth providers.""" """Test linking user to two auth providers."""
manager = await auth.auth_manager_from_config(mock_hass, [{ manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example', 'type': 'insecure_example',
'users': [{ 'users': [{
'username': 'test-user', 'username': 'test-user',
@ -157,3 +157,41 @@ async def test_linking_user_to_two_auth_providers(mock_hass):
}) })
await manager.async_link_user(user, step['result']) await manager.async_link_user(user, step['result'])
assert len(user.credentials) == 2 assert len(user.credentials) == 2
async def test_saving_loading(hass, hass_storage):
"""Test storing and saving data.
Creates one of each type that we store to test we restore correctly.
"""
manager = await auth.auth_manager_from_config(hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
}]
}])
step = await manager.login_flow.async_init(('insecure_example', None))
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'test-user',
'password': 'test-pass',
})
user = await manager.async_get_or_create_user(step['result'])
client = await manager.async_create_client(
'test', redirect_uris=['https://example.com'])
refresh_token = await manager.async_create_refresh_token(user, client.id)
manager.async_create_access_token(refresh_token)
await flush_store(manager._store._store)
store2 = auth.AuthStore(hass)
await store2.async_load()
assert len(store2.users) == 1
assert store2.users[user.id] == user
assert len(store2.clients) == 1
assert store2.clients[client.id] == client

View File

@ -1,7 +1,7 @@
"""Test the config manager.""" """Test the config manager."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from unittest.mock import MagicMock, patch, mock_open from unittest.mock import MagicMock, patch
import pytest import pytest
@ -152,8 +152,7 @@ def test_domains_gets_uniques(manager):
assert manager.async_domains() == ['test', 'test2', 'test3'] assert manager.async_domains() == ['test', 'test2', 'test3']
@asyncio.coroutine async def test_saving_and_loading(hass):
def test_saving_and_loading(hass):
"""Test that we're saving and loading correctly.""" """Test that we're saving and loading correctly."""
loader.set_component( loader.set_component(
hass, 'test', hass, 'test',
@ -172,7 +171,7 @@ def test_saving_and_loading(hass):
) )
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
yield from hass.config_entries.flow.async_init('test') await hass.config_entries.flow.async_init('test')
class Test2Flow(data_entry_flow.FlowHandler): class Test2Flow(data_entry_flow.FlowHandler):
VERSION = 3 VERSION = 3
@ -186,27 +185,18 @@ def test_saving_and_loading(hass):
} }
) )
json_path = 'homeassistant.util.json.open'
with patch('homeassistant.config_entries.HANDLERS.get', with patch('homeassistant.config_entries.HANDLERS.get',
return_value=Test2Flow): return_value=Test2Flow):
yield from hass.config_entries.flow.async_init('test') await hass.config_entries.flow.async_init('test')
with patch(json_path, mock_open(), create=True) as mock_write:
# To trigger the call_later # To trigger the call_later
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
# To execute the save # To execute the save
yield from hass.async_block_till_done() await hass.async_block_till_done()
# Mock open calls are: open file, context enter, write, context leave
written = mock_write.mock_calls[2][1][0]
# Now load written data in new config manager # Now load written data in new config manager
manager = config_entries.ConfigEntries(hass, {}) manager = config_entries.ConfigEntries(hass, {})
await manager.async_load()
with patch('os.path.isfile', return_value=False), \
patch(json_path, mock_open(read_data=written), create=True):
yield from manager.async_load()
# Ensure same order # Ensure same order
for orig, loaded in zip(hass.config_entries.async_entries(), for orig, loaded in zip(hass.config_entries.async_entries(),