Repository: https://github.com/home-assistant/core
Commit: 2205090795 (parent a277470363)

Storage auth (#15192)

* Support parallel loading
* Add storage mock
* Store auth
* Fix tests
@@ -21,6 +21,8 @@ from homeassistant.util import dt as dt_util
 _LOGGER = logging.getLogger(__name__)

+STORAGE_VERSION = 1
+STORAGE_KEY = 'auth'

 AUTH_PROVIDERS = Registry()
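
For context: these two constants key the new persistent store. Based on the wrapper format asserted in the storage tests later in this diff, a stored document would look roughly like the sketch below; the field values are placeholders, and the on-disk location is not shown anywhere in this diff.

stored = {
    'key': 'auth',     # STORAGE_KEY
    'version': 1,      # STORAGE_VERSION, compared on load to trigger migration
    'data': {},        # payload produced by AuthStore.async_save()
}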
@@ -121,23 +123,12 @@ class User:
     is_owner = attr.ib(type=bool, default=False)
     is_active = attr.ib(type=bool, default=False)
     name = attr.ib(type=str, default=None)

-    # For persisting and see if saved?
-    # store = attr.ib(type=AuthStore, default=None)
-
     # List of credentials of a user.
-    credentials = attr.ib(type=list, default=attr.Factory(list))
+    credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)

     # Tokens associated with a user.
-    refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict))
-
-    def as_dict(self):
-        """Convert user object to a dictionary."""
-        return {
-            'id': self.id,
-            'is_owner': self.is_owner,
-            'is_active': self.is_active,
-            'name': self.name,
-        }
+    refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)


 @attr.s(slots=True)
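
A note on the cmp=False flags added above: attrs then excludes those mutable collections from the generated __eq__, so two User instances compare equal on their identity fields alone. The new test_saving_loading test at the end of this diff relies on exactly that when it compares restored users against the originals. A minimal sketch of the behavior, using the attrs spelling current at the time (cmp= rather than the later eq=):

import attr


@attr.s(slots=True)
class Demo:
    id = attr.ib(type=str)
    items = attr.ib(type=list, default=attr.Factory(list), cmp=False)


assert Demo('a', items=['x']) == Demo('a')  # items is ignored in comparisons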
@@ -152,7 +143,7 @@ class RefreshToken:
                                          default=ACCESS_TOKEN_EXPIRATION)
     token = attr.ib(type=str,
                     default=attr.Factory(lambda: generate_secret(64)))
-    access_tokens = attr.ib(type=list, default=attr.Factory(list))
+    access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)


 @attr.s(slots=True)
@@ -376,7 +367,7 @@ class AuthStore:
         self.hass = hass
         self.users = None
         self.clients = None
         self._load_lock = asyncio.Lock(loop=hass.loop)
+        self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

     async def credentials_for_provider(self, provider_type, provider_id):
         """Return credentials for specific auth provider type and id."""
@@ -494,10 +485,128 @@ class AuthStore:

     async def async_load(self):
         """Load the users."""
         async with self._load_lock:
+            data = await self._store.async_load()
+
+            # Make sure that we're not overriding data if 2 loads happened at the
+            # same time
             if self.users is not None:
                 return

+            if data is None:
+                self.users = {}
+                self.clients = {}
+                return
+
+            users = {
+                user_dict['id']: User(**user_dict) for user_dict in data['users']
+            }
+
+            for cred_dict in data['credentials']:
+                users[cred_dict['user_id']].credentials.append(Credentials(
+                    id=cred_dict['id'],
+                    is_new=False,
+                    auth_provider_type=cred_dict['auth_provider_type'],
+                    auth_provider_id=cred_dict['auth_provider_id'],
+                    data=cred_dict['data'],
+                ))
+
+            refresh_tokens = {}
+
+            for rt_dict in data['refresh_tokens']:
+                token = RefreshToken(
+                    id=rt_dict['id'],
+                    user=users[rt_dict['user_id']],
+                    client_id=rt_dict['client_id'],
+                    created_at=dt_util.parse_datetime(rt_dict['created_at']),
+                    access_token_expiration=timedelta(
+                        seconds=rt_dict['access_token_expiration']),
+                    token=rt_dict['token'],
+                )
+                refresh_tokens[token.id] = token
+                users[rt_dict['user_id']].refresh_tokens[token.token] = token
+
+            for ac_dict in data['access_tokens']:
+                refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
+                token = AccessToken(
+                    refresh_token=refresh_token,
+                    created_at=dt_util.parse_datetime(ac_dict['created_at']),
+                    token=ac_dict['token'],
+                )
+                refresh_token.access_tokens.append(token)
+
+            clients = {
+                cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
+            }
+
+            self.users = users
+            self.clients = clients

     async def async_save(self):
         """Save users."""
-        pass
+        users = [
+            {
+                'id': user.id,
+                'is_owner': user.is_owner,
+                'is_active': user.is_active,
+                'name': user.name,
+            }
+            for user in self.users.values()
+        ]
+
+        credentials = [
+            {
+                'id': credential.id,
+                'user_id': user.id,
+                'auth_provider_type': credential.auth_provider_type,
+                'auth_provider_id': credential.auth_provider_id,
+                'data': credential.data,
+            }
+            for user in self.users.values()
+            for credential in user.credentials
+        ]
+
+        refresh_tokens = [
+            {
+                'id': refresh_token.id,
+                'user_id': user.id,
+                'client_id': refresh_token.client_id,
+                'created_at': refresh_token.created_at.isoformat(),
+                'access_token_expiration':
+                    refresh_token.access_token_expiration.total_seconds(),
+                'token': refresh_token.token,
+            }
+            for user in self.users.values()
+            for refresh_token in user.refresh_tokens.values()
+        ]
+
+        access_tokens = [
+            {
+                'id': user.id,
+                'refresh_token_id': refresh_token.id,
+                'created_at': access_token.created_at.isoformat(),
+                'token': access_token.token,
+            }
+            for user in self.users.values()
+            for refresh_token in user.refresh_tokens.values()
+            for access_token in refresh_token.access_tokens
+        ]
+
+        clients = [
+            {
+                'id': client.id,
+                'name': client.name,
+                'secret': client.secret,
+                'redirect_uris': client.redirect_uris,
+            }
+            for client in self.clients.values()
+        ]
+
+        data = {
+            'users': users,
+            'clients': clients,
+            'credentials': credentials,
+            'access_tokens': access_tokens,
+            'refresh_tokens': refresh_tokens,
+        }
+
+        await self._store.async_save(data, delay=1)
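
Taken together, the serializers above produce a plain, JSON-serializable dict with five top-level lists; Store.async_save(data, delay=1) then debounces the actual write by a second. A sketch of one saved document follows; every value is a made-up placeholder, only the keys come from the diff above.

data = {
    'users': [{
        'id': 'user-uuid-hex',
        'is_owner': False,
        'is_active': False,
        'name': 'Example User',
    }],
    'clients': [{
        'id': 'client-id',
        'name': 'test',
        'secret': 'client-secret',
        'redirect_uris': ['https://example.com'],
    }],
    'credentials': [{
        'id': 'credential-id',
        'user_id': 'user-uuid-hex',
        'auth_provider_type': 'insecure_example',
        'auth_provider_id': None,
        'data': {'username': 'test-user'},
    }],
    'refresh_tokens': [{
        'id': 'refresh-token-id',
        'user_id': 'user-uuid-hex',
        'client_id': 'client-id',
        'created_at': '2018-07-13T10:00:00+00:00',
        'access_token_expiration': 1800.0,   # total_seconds()
        'token': 'generated-64-char-secret',
    }],
    'access_tokens': [{
        'id': 'user-uuid-hex',   # note: the serializer stores user.id here
        'refresh_token_id': 'refresh-token-id',
        'created_at': '2018-07-13T10:00:00+00:00',
        'token': 'generated-secret',
    }],
}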
@@ -53,6 +53,7 @@ class Store:
         self._unsub_delay_listener = None
         self._unsub_stop_listener = None
         self._write_lock = asyncio.Lock()
+        self._load_task = None

     @property
     def path(self):
@@ -64,7 +65,17 @@ class Store:

         If the expected version does not match the given version, the migrate
         function will be invoked with await migrate_func(version, config).
+
+        Will ensure that when a call comes in while another one is in progress,
+        the second call will wait and return the result of the first call.
         """
+        if self._load_task is None:
+            self._load_task = self.hass.async_add_job(self._async_load())
+
+        return await self._load_task
+
+    async def _async_load(self):
+        """Helper to load the data."""
         if self._data is not None:
             data = self._data
         else:
@@ -75,9 +86,15 @@ class Store:
             return None

         if data['version'] == self.version:
-            return data['data']
+            stored = data['data']
+        else:
+            _LOGGER.info('Migrating %s storage from %s to %s',
+                         self.key, data['version'], self.version)
+            stored = await self._async_migrate_func(
+                data['version'], data['data'])

-        return await self._async_migrate_func(data['version'], data['data'])
+        self._load_task = None
+        return stored

     async def async_save(self, data: Dict, *, delay: Optional[int] = None):
         """Save data with an optional delay."""
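
The Store change above is the "support parallel loading" piece of this commit: the first async_load() caller creates a single load task, and every concurrent caller awaits that same task instead of hitting the disk again. A standalone sketch of the pattern (not Home Assistant code):

import asyncio


class CachedLoader:
    """Cache an in-flight load so concurrent callers share one result."""

    def __init__(self):
        self._load_task = None

    async def load(self):
        if self._load_task is None:
            # First caller schedules the real load; later callers join it.
            self._load_task = asyncio.ensure_future(self._do_load())
        return await self._load_task

    async def _do_load(self):
        await asyncio.sleep(0.1)  # stand-in for slow disk I/O
        return {'hello': 'world'}


async def main():
    loader = CachedLoader()
    first, second = await asyncio.gather(loader.load(), loader.load())
    assert first is second  # one load served both callers


asyncio.get_event_loop().run_until_complete(main())

Unlike this toy, the real Store clears self._load_task once the load finishes, so a later call re-reads fresh data.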
@@ -11,15 +11,15 @@ from tests.common import mock_coro


 @pytest.fixture
-def store():
+def store(hass):
     """Mock store."""
-    return auth.AuthStore(Mock())
+    return auth.AuthStore(hass)


 @pytest.fixture
-def provider(store):
+def provider(hass, store):
     """Mock provider."""
-    return insecure_example.ExampleAuthProvider(None, store, {
+    return insecure_example.ExampleAuthProvider(hass, store, {
         'type': 'insecure_example',
         'users': [
             {
@@ -2,6 +2,7 @@
 import asyncio
 from datetime import timedelta
 import functools as ft
+import json
 import os
 import sys
 from unittest.mock import patch, MagicMock, Mock
@@ -15,7 +16,7 @@ from homeassistant.setup import setup_component, async_setup_component
 from homeassistant.config import async_process_component_config
 from homeassistant.helpers import (
     intent, entity, restore_state, entity_registry,
-    entity_platform)
+    entity_platform, storage)
 from homeassistant.util.unit_system import METRIC_SYSTEM
 import homeassistant.util.dt as date_util
 import homeassistant.util.yaml as yaml
@@ -705,3 +706,51 @@ class MockEntity(entity.Entity):
         if attr in self._values:
             return self._values[attr]
         return getattr(super(), attr)
+
+
+@contextmanager
+def mock_storage(data=None):
+    """Mock storage.
+
+    Data is a dict {'key': {'version': version, 'data': data}}
+
+    Written data will be converted to JSON to ensure JSON parsing works.
+    """
+    if data is None:
+        data = {}
+
+    orig_load = storage.Store._async_load
+
+    async def mock_async_load(store):
+        """Mock version of load."""
+        if store._data is None:
+            # No data to load
+            if store.key not in data:
+                return None
+
+            store._data = data.get(store.key)
+
+        # Route through original load so that we trigger migration
+        loaded = await orig_load(store)
+        _LOGGER.info('Loading data for %s: %s', store.key, loaded)
+        return loaded
+
+    def mock_write_data(store, path, data_to_write):
+        """Mock version of write data."""
+        # To ensure that the data can be serialized
+        _LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
+        data[store.key] = json.loads(json.dumps(data_to_write))
+
+    with patch('homeassistant.helpers.storage.Store._async_load',
+               side_effect=mock_async_load, autospec=True), \
+        patch('homeassistant.helpers.storage.Store._write_data',
+              side_effect=mock_write_data, autospec=True):
+        yield data
+
+
+async def flush_store(store):
+    """Make sure all delayed writes of a store are written."""
+    if store._data is None:
+        return
+
+    await store._async_handle_write_data()
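
The new mock_storage helper patches Store at its private load/write seam, so everything above that seam (version checks, migration, delayed writes) still runs for real against an in-memory dict. A hypothetical in-repo usage, seeding one store and inspecting what was written:

from tests.common import mock_storage


def test_example():
    """Sketch: seed mock storage, then inspect writes."""
    with mock_storage({'auth': {'version': 1, 'data': {'users': []}}}) as data:
        # ... run code that loads/saves through homeassistant.helpers.storage ...
        assert 'auth' in data  # writes land in this dict, JSON round-tripped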
@@ -12,7 +12,8 @@ from homeassistant import util
 from homeassistant.util import location

 from tests.common import (
-    async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro)
+    async_test_home_assistant, INSTANCES, async_mock_mqtt_component, mock_coro,
+    mock_storage as mock_storage)
 from tests.test_util.aiohttp import mock_aiohttp_client
 from tests.mock.zwave import MockNetwork, MockOption
@@ -59,7 +60,14 @@ def verify_cleanup():


 @pytest.fixture
-def hass(loop):
+def hass_storage():
+    """Fixture to mock storage."""
+    with mock_storage() as stored_data:
+        yield stored_data
+
+
+@pytest.fixture
+def hass(loop, hass_storage):
     """Fixture to provide a test instance of HASS."""
     hass = loop.run_until_complete(async_test_home_assistant(loop))
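
Because the hass fixture now depends on hass_storage, every test built on hass automatically runs against mocked storage; a test that also requests hass_storage can seed or assert on it directly. For example, a sketch mirroring the storage tests that follow:

async def test_with_seeded_storage(hass, hass_storage):
    """Sketch: pre-populate mock storage before the code under test loads."""
    hass_storage['some-key'] = {'version': 1, 'data': {'hello': 'world'}}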
@@ -1,4 +1,5 @@
 """Tests for the storage helper."""
+import asyncio
 from datetime import timedelta
 from unittest.mock import patch
@@ -16,32 +17,13 @@ MOCK_KEY = 'storage-test'
 MOCK_DATA = {'hello': 'world'}


-@pytest.fixture
-def mock_save():
-    """Fixture to mock JSON save."""
-    written = []
-    with patch('homeassistant.util.json.save_json',
-               side_effect=lambda *args: written.append(args)):
-        yield written
-
-
-@pytest.fixture
-def mock_load(mock_save):
-    """Fixture to mock JSON read."""
-    with patch('homeassistant.util.json.load_json',
-               side_effect=lambda *args: mock_save[-1][1]):
-        yield
-
-
 @pytest.fixture
 def store(hass):
     """Fixture of a store that prevents writing on HASS stop."""
-    store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
-    store._async_ensure_stop_listener = lambda: None
-    yield store
+    yield storage.Store(hass, MOCK_VERSION, MOCK_KEY)


-async def test_loading(hass, store, mock_save, mock_load):
+async def test_loading(hass, store):
     """Test we can save and load data."""
     await store.async_save(MOCK_DATA)
     data = await store.async_load()
@@ -55,55 +37,96 @@ async def test_loading_non_existing(hass, store):
     assert data is None


-async def test_saving_with_delay(hass, store, mock_save):
+async def test_loading_parallel(hass, store, hass_storage, caplog):
+    """Test we can save and load data."""
+    hass_storage[store.key] = {
+        'version': MOCK_VERSION,
+        'data': MOCK_DATA,
+    }
+
+    results = await asyncio.gather(
+        store.async_load(),
+        store.async_load()
+    )
+
+    assert results[0] is MOCK_DATA
+    assert results[1] is MOCK_DATA
+    assert caplog.text.count('Loading data for {}'.format(store.key))
+
+
+async def test_saving_with_delay(hass, store, hass_storage):
     """Test saving data after a delay."""
     await store.async_save(MOCK_DATA, delay=1)
-    assert len(mock_save) == 0
+    assert store.key not in hass_storage

     async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
     await hass.async_block_till_done()
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': MOCK_DATA,
+    }


-async def test_saving_on_stop(hass, mock_save):
+async def test_saving_on_stop(hass, hass_storage):
     """Test delayed saves trigger when we quit Home Assistant."""
     store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
     await store.async_save(MOCK_DATA, delay=1)
-    assert len(mock_save) == 0
+    assert store.key not in hass_storage

     hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
     await hass.async_block_till_done()
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': MOCK_DATA,
+    }


-async def test_loading_while_delay(hass, store, mock_save, mock_load):
+async def test_loading_while_delay(hass, store, hass_storage):
     """Test we load new data even if not written yet."""
     await store.async_save({'delay': 'no'})
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': {'delay': 'no'},
+    }

     await store.async_save({'delay': 'yes'}, delay=1)
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': {'delay': 'no'},
+    }

     data = await store.async_load()
     assert data == {'delay': 'yes'}


-async def test_writing_while_writing_delay(hass, store, mock_save, mock_load):
+async def test_writing_while_writing_delay(hass, store, hass_storage):
     """Test a write while a write with delay is active."""
     await store.async_save({'delay': 'yes'}, delay=1)
-    assert len(mock_save) == 0
+    assert store.key not in hass_storage
     await store.async_save({'delay': 'no'})
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': {'delay': 'no'},
+    }

     async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
     await hass.async_block_till_done()
-    assert len(mock_save) == 1
+    assert hass_storage[store.key] == {
+        'version': MOCK_VERSION,
+        'key': MOCK_KEY,
+        'data': {'delay': 'no'},
+    }

     data = await store.async_load()
     assert data == {'delay': 'no'}


-async def test_migrator_no_existing_config(hass, store, mock_save):
+async def test_migrator_no_existing_config(hass, store, hass_storage):
     """Test migrator with no existing config."""
     with patch('os.path.isfile', return_value=False), \
         patch.object(store, 'async_load',
@@ -112,10 +135,10 @@ async def test_migrator_no_existing_config(hass, store, mock_save):
             hass, 'old-path', store)

     assert data == {'cur': 'config'}
-    assert len(mock_save) == 0
+    assert store.key not in hass_storage


-async def test_migrator_existing_config(hass, store, mock_save):
+async def test_migrator_existing_config(hass, store, hass_storage):
     """Test migrating existing config."""
     with patch('os.path.isfile', return_value=True), \
         patch('os.remove') as mock_remove, \
@@ -126,15 +149,14 @@ async def test_migrator_existing_config(hass, store, mock_save):

     assert len(mock_remove.mock_calls) == 1
     assert data == {'old': 'config'}
-    assert len(mock_save) == 1
-    assert mock_save[0][1] == {
+    assert hass_storage[store.key] == {
         'key': MOCK_KEY,
         'version': MOCK_VERSION,
         'data': data,
     }


-async def test_migrator_transforming_config(hass, store, mock_save):
+async def test_migrator_transforming_config(hass, store, hass_storage):
     """Test migrating config to new format."""
     async def old_conf_migrate_func(old_config):
         """Migrate old config to new format."""
@@ -150,8 +172,7 @@ async def test_migrator_transforming_config(hass, store, mock_save):

     assert len(mock_remove.mock_calls) == 1
     assert data == {'new': 'config'}
-    assert len(mock_save) == 1
-    assert mock_save[0][1] == {
+    assert hass_storage[store.key] == {
         'key': MOCK_KEY,
         'version': MOCK_VERSION,
         'data': data,
@@ -4,7 +4,7 @@ from unittest.mock import Mock
 import pytest

 from homeassistant import auth, data_entry_flow
-from tests.common import MockUser, ensure_auth_manager_loaded
+from tests.common import MockUser, ensure_auth_manager_loaded, flush_store


 @pytest.fixture
@@ -53,9 +53,9 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
     }]


-async def test_create_new_user(mock_hass):
+async def test_create_new_user(hass, hass_storage):
     """Test creating new user."""
-    manager = await auth.auth_manager_from_config(mock_hass, [{
+    manager = await auth.auth_manager_from_config(hass, [{
         'type': 'insecure_example',
         'users': [{
             'username': 'test-user',
@@ -124,9 +124,9 @@ async def test_login_as_existing_user(mock_hass):
     assert user.name == 'Paulus'


-async def test_linking_user_to_two_auth_providers(mock_hass):
+async def test_linking_user_to_two_auth_providers(hass, hass_storage):
     """Test linking user to two auth providers."""
-    manager = await auth.auth_manager_from_config(mock_hass, [{
+    manager = await auth.auth_manager_from_config(hass, [{
         'type': 'insecure_example',
         'users': [{
             'username': 'test-user',
@@ -157,3 +157,41 @@ async def test_linking_user_to_two_auth_providers(mock_hass):
     })
     await manager.async_link_user(user, step['result'])
     assert len(user.credentials) == 2
+
+
+async def test_saving_loading(hass, hass_storage):
+    """Test storing and saving data.
+
+    Creates one of each type that we store to test we restore correctly.
+    """
+    manager = await auth.auth_manager_from_config(hass, [{
+        'type': 'insecure_example',
+        'users': [{
+            'username': 'test-user',
+            'password': 'test-pass',
+        }]
+    }])
+
+    step = await manager.login_flow.async_init(('insecure_example', None))
+    step = await manager.login_flow.async_configure(step['flow_id'], {
+        'username': 'test-user',
+        'password': 'test-pass',
+    })
+    user = await manager.async_get_or_create_user(step['result'])
+
+    client = await manager.async_create_client(
+        'test', redirect_uris=['https://example.com'])
+
+    refresh_token = await manager.async_create_refresh_token(user, client.id)
+
+    manager.async_create_access_token(refresh_token)
+
+    await flush_store(manager._store._store)
+
+    store2 = auth.AuthStore(hass)
+    await store2.async_load()
+    assert len(store2.users) == 1
+    assert store2.users[user.id] == user

+    assert len(store2.clients) == 1
+    assert store2.clients[client.id] == client
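
One subtlety in test_saving_loading above: AuthStore.async_save() writes with delay=1, so nothing reaches the mocked storage until the delayed write runs; that is what await flush_store(manager._store._store) forces, and why the second AuthStore can then load real data. A standalone sketch of that debounced-write shape (a toy model, not the real Store):

import asyncio


class DelayedWriter:
    """Toy model of a store whose saves are deferred until flushed."""

    def __init__(self):
        self._pending = None
        self.written = None

    def save(self, data, delay=1):
        # A real store schedules a timer callback; here we just hold the data.
        self._pending = data

    async def flush(self):
        # Equivalent of flush_store(): perform the pending write now.
        if self._pending is not None:
            self.written = self._pending
            self._pending = None


async def main():
    writer = DelayedWriter()
    writer.save({'users': []})
    assert writer.written is None  # delayed write has not landed yet
    await writer.flush()
    assert writer.written == {'users': []}


asyncio.get_event_loop().run_until_complete(main())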
@@ -1,7 +1,7 @@
 """Test the config manager."""
 import asyncio
 from datetime import timedelta
-from unittest.mock import MagicMock, patch, mock_open
+from unittest.mock import MagicMock, patch

 import pytest

@@ -152,8 +152,7 @@ def test_domains_gets_uniques(manager):
     assert manager.async_domains() == ['test', 'test2', 'test3']


-@asyncio.coroutine
-def test_saving_and_loading(hass):
+async def test_saving_and_loading(hass):
     """Test that we're saving and loading correctly."""
     loader.set_component(
         hass, 'test',
@@ -172,7 +171,7 @@ def test_saving_and_loading(hass):
     )

     with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
-        yield from hass.config_entries.flow.async_init('test')
+        await hass.config_entries.flow.async_init('test')

     class Test2Flow(data_entry_flow.FlowHandler):
         VERSION = 3
@@ -186,27 +185,18 @@ def test_saving_and_loading(hass):
             }
         )

-    json_path = 'homeassistant.util.json.open'
-
     with patch('homeassistant.config_entries.HANDLERS.get',
                return_value=Test2Flow):
-        yield from hass.config_entries.flow.async_init('test')
+        await hass.config_entries.flow.async_init('test')

-    with patch(json_path, mock_open(), create=True) as mock_write:
-        # To trigger the call_later
-        async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
-        # To execute the save
-        yield from hass.async_block_till_done()
-
-    # Mock open calls are: open file, context enter, write, context leave
-    written = mock_write.mock_calls[2][1][0]
+    # To trigger the call_later
+    async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
+    # To execute the save
+    await hass.async_block_till_done()

     # Now load written data in new config manager
     manager = config_entries.ConfigEntries(hass, {})

-    with patch('os.path.isfile', return_value=False), \
-            patch(json_path, mock_open(read_data=written), create=True):
-        yield from manager.async_load()
+    await manager.async_load()

     # Ensure same order
     for orig, loaded in zip(hass.config_entries.async_entries(),