Migrate home assistant auth provider to use storage helper (#15200)

This commit is contained in:
Paulus Schoutsen 2018-06-29 00:02:45 -04:00 committed by GitHub
parent 39971ee919
commit 26590e244c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 116 deletions

View File

@ -8,10 +8,10 @@ import voluptuous as vol
from homeassistant import auth, data_entry_flow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import json
PATH_DATA = '.users.json'
STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant'
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA)
@ -31,14 +31,22 @@ class InvalidUser(HomeAssistantError):
class Data:
"""Hold the user data."""
def __init__(self, path, data):
def __init__(self, hass):
"""Initialize the user data store."""
self.path = path
self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._data = None
async def async_load(self):
"""Load stored data."""
data = await self._store.async_load()
if data is None:
data = {
'salt': auth.generate_secret(),
'users': []
}
self._data = data
@property
@ -99,14 +107,9 @@ class Data:
else:
raise InvalidUser
def save(self):
async def async_save(self):
"""Save data."""
json.save_json(self.path, self._data)
def load_data(path):
"""Load auth data."""
return Data(path, json.load_json(path, None))
await self._store.async_save(self._data)
@auth.AUTH_PROVIDERS.register('homeassistant')
@ -121,12 +124,10 @@ class HassAuthProvider(auth.AuthProvider):
async def async_validate_login(self, username, password):
"""Helper to validate a username and password."""
def validate():
"""Validate creds."""
data = self._auth_data()
data.validate_login(username, password)
await self.hass.async_add_job(validate)
data = Data(self.hass)
await data.async_load()
await self.hass.async_add_executor_job(
data.validate_login, username, password)
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
@ -141,10 +142,6 @@ class HassAuthProvider(auth.AuthProvider):
'username': username
})
def _auth_data(self):
"""Return the auth provider data."""
return load_data(self.hass.config.path(PATH_DATA))
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""

View File

@ -1,7 +1,9 @@
"""Script to manage users for the Home Assistant auth provider."""
import argparse
import asyncio
import os
from homeassistant.core import HomeAssistant
from homeassistant.config import get_default_config_dir
from homeassistant.auth_providers import homeassistant as hass_auth
@ -17,7 +19,8 @@ def run(args):
default=get_default_config_dir(),
help="Directory that contains the Home Assistant configuration")
subparsers = parser.add_subparsers()
subparsers = parser.add_subparsers(dest='func')
subparsers.required = True
parser_list = subparsers.add_parser('list')
parser_list.set_defaults(func=list_users)
@ -37,11 +40,15 @@ def run(args):
parser_change_pw.set_defaults(func=change_password)
args = parser.parse_args(args)
path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA)
args.func(hass_auth.load_data(path), args)
loop = asyncio.get_event_loop()
hass = HomeAssistant(loop=loop)
hass.config.config_dir = os.path.join(os.getcwd(), args.config)
data = hass_auth.Data(hass)
loop.run_until_complete(data.async_load())
loop.run_until_complete(args.func(data, args))
def list_users(data, args):
async def list_users(data, args):
"""List the users."""
count = 0
for user in data.users:
@ -52,14 +59,14 @@ def list_users(data, args):
print("Total users:", count)
def add_user(data, args):
async def add_user(data, args):
"""Create a user."""
data.add_user(args.username, args.password)
data.save()
await data.async_save()
print("User created")
def validate_login(data, args):
async def validate_login(data, args):
"""Validate a login."""
try:
data.validate_login(args.username, args.password)
@ -68,11 +75,11 @@ def validate_login(data, args):
print("Auth invalid")
def change_password(data, args):
async def change_password(data, args):
"""Change password."""
try:
data.change_password(args.username, args.new_password)
data.save()
await data.async_save()
print("Password changed")
except hass_auth.InvalidUser:
print("User not found")

View File

@ -1,60 +1,48 @@
"""Test the Home Assistant local auth provider."""
from unittest.mock import patch, mock_open
import pytest
from homeassistant import data_entry_flow
from homeassistant.auth_providers import homeassistant as hass_auth
MOCK_PATH = '/bla/users.json'
JSON__OPEN_PATH = 'homeassistant.util.json.open'
@pytest.fixture
def data(hass):
"""Create a loaded data class."""
data = hass_auth.Data(hass)
hass.loop.run_until_complete(data.async_load())
return data
def test_initialize_empty_config_file_not_found():
"""Test that we initialize an empty config."""
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
data = hass_auth.load_data(MOCK_PATH)
assert data is not None
def test_adding_user():
async def test_adding_user(data, hass):
"""Test adding a user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
data.validate_login('test-user', 'test-pass')
def test_adding_user_duplicate_username():
async def test_adding_user_duplicate_username(data, hass):
"""Test adding a user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidUser):
data.add_user('test-user', 'other-pass')
def test_validating_password_invalid_user():
async def test_validating_password_invalid_user(data, hass):
"""Test validating an invalid user."""
data = hass_auth.Data(MOCK_PATH, None)
with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('non-existing', 'pw')
def test_validating_password_invalid_password():
async def test_validating_password_invalid_password(data, hass):
"""Test validating an invalid user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('test-user', 'invalid-pass')
def test_changing_password():
async def test_changing_password(data, hass):
"""Test adding a user."""
user = 'test-user'
data = hass_auth.Data(MOCK_PATH, None)
data.add_user(user, 'test-pass')
data.change_password(user, 'new-pass')
@ -64,61 +52,50 @@ def test_changing_password():
data.validate_login(user, 'new-pass')
def test_changing_password_raises_invalid_user():
async def test_changing_password_raises_invalid_user(data, hass):
"""Test that we initialize an empty config."""
data = hass_auth.Data(MOCK_PATH, None)
with pytest.raises(hass_auth.InvalidUser):
data.change_password('non-existing', 'pw')
async def test_login_flow_validates(hass):
async def test_login_flow_validates(data, hass):
"""Test login flow."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
await data.async_save()
provider = hass_auth.HassAuthProvider(hass, None, {})
flow = hass_auth.LoginFlow(provider)
result = await flow.async_step_init()
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
with patch.object(provider, '_auth_data', return_value=data):
result = await flow.async_step_init({
'username': 'incorrect-user',
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({
'username': 'incorrect-user',
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({
'username': 'test-user',
'password': 'incorrect-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({
'username': 'test-user',
'password': 'incorrect-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors']['base'] == 'invalid_auth'
result = await flow.async_step_init({
'username': 'test-user',
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
result = await flow.async_step_init({
'username': 'test-user',
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
async def test_saving_loading(hass):
async def test_saving_loading(data, hass):
"""Test saving and loading JSON."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
data.add_user('second-user', 'second-pass')
await data.async_save()
with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write:
await hass.async_add_job(data.save)
# Mock open calls are: open file, context enter, write, context leave
written = mock_write.mock_calls[2][1][0]
with patch('os.path.isfile', return_value=True), \
patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True):
await hass.async_add_job(hass_auth.load_data, MOCK_PATH)
data = hass_auth.Data(hass)
await data.async_load()
data.validate_login('test-user', 'test-pass')
data.validate_login('second-user', 'second-pass')

View File

@ -6,16 +6,21 @@ import pytest
from homeassistant.scripts import auth as script_auth
from homeassistant.auth_providers import homeassistant as hass_auth
MOCK_PATH = '/bla/users.json'
@pytest.fixture
def data(hass):
"""Create a loaded data class."""
data = hass_auth.Data(hass)
hass.loop.run_until_complete(data.async_load())
return data
def test_list_user(capsys):
async def test_list_user(data, capsys):
"""Test we can list users."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
data.add_user('second-user', 'second-pass')
script_auth.list_users(data, None)
await script_auth.list_users(data, None)
captured = capsys.readouterr()
@ -28,15 +33,12 @@ def test_list_user(capsys):
])
def test_add_user(capsys):
async def test_add_user(data, capsys, hass_storage):
"""Test we can add a user."""
data = hass_auth.Data(MOCK_PATH, None)
await script_auth.add_user(
data, Mock(username='paulus', password='test-pass'))
with patch.object(data, 'save') as mock_save:
script_auth.add_user(
data, Mock(username='paulus', password='test-pass'))
assert len(mock_save.mock_calls) == 1
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr()
assert captured.out == 'User created\n'
@ -45,37 +47,34 @@ def test_add_user(capsys):
data.validate_login('paulus', 'test-pass')
def test_validate_login(capsys):
async def test_validate_login(data, capsys):
"""Test we can validate a user login."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
script_auth.validate_login(
await script_auth.validate_login(
data, Mock(username='test-user', password='test-pass'))
captured = capsys.readouterr()
assert captured.out == 'Auth valid\n'
script_auth.validate_login(
await script_auth.validate_login(
data, Mock(username='test-user', password='invalid-pass'))
captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n'
script_auth.validate_login(
await script_auth.validate_login(
data, Mock(username='invalid-user', password='test-pass'))
captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n'
def test_change_password(capsys):
async def test_change_password(data, capsys, hass_storage):
"""Test we can change a password."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
with patch.object(data, 'save') as mock_save:
script_auth.change_password(
data, Mock(username='test-user', new_password='new-pass'))
await script_auth.change_password(
data, Mock(username='test-user', new_password='new-pass'))
assert len(mock_save.mock_calls) == 1
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr()
assert captured.out == 'Password changed\n'
data.validate_login('test-user', 'new-pass')
@ -83,18 +82,35 @@ def test_change_password(capsys):
data.validate_login('test-user', 'test-pass')
def test_change_password_invalid_user(capsys):
async def test_change_password_invalid_user(data, capsys, hass_storage):
"""Test changing password of non-existing user."""
data = hass_auth.Data(MOCK_PATH, None)
data.add_user('test-user', 'test-pass')
with patch.object(data, 'save') as mock_save:
script_auth.change_password(
data, Mock(username='invalid-user', new_password='new-pass'))
await script_auth.change_password(
data, Mock(username='invalid-user', new_password='new-pass'))
assert len(mock_save.mock_calls) == 0
assert hass_auth.STORAGE_KEY not in hass_storage
captured = capsys.readouterr()
assert captured.out == 'User not found\n'
data.validate_login('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('invalid-user', 'new-pass')
def test_parsing_args(loop):
"""Test we parse args correctly."""
called = False
async def mock_func(data, args2):
"""Mock function to be called."""
nonlocal called
called = True
assert data.hass.config.config_dir == '/somewhere/config'
assert args2 is args
args = Mock(config='/somewhere/config', func=mock_func)
with patch('argparse.ArgumentParser.parse_args', return_value=args):
script_auth.run(None)
assert called, 'Mock function did not get called'