mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Migrate home assistant auth provider to use storage helper (#15200)
This commit is contained in:
parent
39971ee919
commit
26590e244c
@ -8,10 +8,10 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import json
|
||||
|
||||
|
||||
PATH_DATA = '.users.json'
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||
|
||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
@ -31,14 +31,22 @@ class InvalidUser(HomeAssistantError):
|
||||
class Data:
|
||||
"""Hold the user data."""
|
||||
|
||||
def __init__(self, path, data):
|
||||
def __init__(self, hass):
|
||||
"""Initialize the user data store."""
|
||||
self.path = path
|
||||
self.hass = hass
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
self._data = None
|
||||
|
||||
async def async_load(self):
|
||||
"""Load stored data."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
if data is None:
|
||||
data = {
|
||||
'salt': auth.generate_secret(),
|
||||
'users': []
|
||||
}
|
||||
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
@ -99,14 +107,9 @@ class Data:
|
||||
else:
|
||||
raise InvalidUser
|
||||
|
||||
def save(self):
|
||||
async def async_save(self):
|
||||
"""Save data."""
|
||||
json.save_json(self.path, self._data)
|
||||
|
||||
|
||||
def load_data(path):
|
||||
"""Load auth data."""
|
||||
return Data(path, json.load_json(path, None))
|
||||
await self._store.async_save(self._data)
|
||||
|
||||
|
||||
@auth.AUTH_PROVIDERS.register('homeassistant')
|
||||
@ -121,12 +124,10 @@ class HassAuthProvider(auth.AuthProvider):
|
||||
|
||||
async def async_validate_login(self, username, password):
|
||||
"""Helper to validate a username and password."""
|
||||
def validate():
|
||||
"""Validate creds."""
|
||||
data = self._auth_data()
|
||||
data.validate_login(username, password)
|
||||
|
||||
await self.hass.async_add_job(validate)
|
||||
data = Data(self.hass)
|
||||
await data.async_load()
|
||||
await self.hass.async_add_executor_job(
|
||||
data.validate_login, username, password)
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
@ -141,10 +142,6 @@ class HassAuthProvider(auth.AuthProvider):
|
||||
'username': username
|
||||
})
|
||||
|
||||
def _auth_data(self):
|
||||
"""Return the auth provider data."""
|
||||
return load_data(self.hass.config.path(PATH_DATA))
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
@ -1,7 +1,9 @@
|
||||
"""Script to manage users for the Home Assistant auth provider."""
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config import get_default_config_dir
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
|
||||
@ -17,7 +19,8 @@ def run(args):
|
||||
default=get_default_config_dir(),
|
||||
help="Directory that contains the Home Assistant configuration")
|
||||
|
||||
subparsers = parser.add_subparsers()
|
||||
subparsers = parser.add_subparsers(dest='func')
|
||||
subparsers.required = True
|
||||
parser_list = subparsers.add_parser('list')
|
||||
parser_list.set_defaults(func=list_users)
|
||||
|
||||
@ -37,11 +40,15 @@ def run(args):
|
||||
parser_change_pw.set_defaults(func=change_password)
|
||||
|
||||
args = parser.parse_args(args)
|
||||
path = os.path.join(os.getcwd(), args.config, hass_auth.PATH_DATA)
|
||||
args.func(hass_auth.load_data(path), args)
|
||||
loop = asyncio.get_event_loop()
|
||||
hass = HomeAssistant(loop=loop)
|
||||
hass.config.config_dir = os.path.join(os.getcwd(), args.config)
|
||||
data = hass_auth.Data(hass)
|
||||
loop.run_until_complete(data.async_load())
|
||||
loop.run_until_complete(args.func(data, args))
|
||||
|
||||
|
||||
def list_users(data, args):
|
||||
async def list_users(data, args):
|
||||
"""List the users."""
|
||||
count = 0
|
||||
for user in data.users:
|
||||
@ -52,14 +59,14 @@ def list_users(data, args):
|
||||
print("Total users:", count)
|
||||
|
||||
|
||||
def add_user(data, args):
|
||||
async def add_user(data, args):
|
||||
"""Create a user."""
|
||||
data.add_user(args.username, args.password)
|
||||
data.save()
|
||||
await data.async_save()
|
||||
print("User created")
|
||||
|
||||
|
||||
def validate_login(data, args):
|
||||
async def validate_login(data, args):
|
||||
"""Validate a login."""
|
||||
try:
|
||||
data.validate_login(args.username, args.password)
|
||||
@ -68,11 +75,11 @@ def validate_login(data, args):
|
||||
print("Auth invalid")
|
||||
|
||||
|
||||
def change_password(data, args):
|
||||
async def change_password(data, args):
|
||||
"""Change password."""
|
||||
try:
|
||||
data.change_password(args.username, args.new_password)
|
||||
data.save()
|
||||
await data.async_save()
|
||||
print("Password changed")
|
||||
except hass_auth.InvalidUser:
|
||||
print("User not found")
|
||||
|
@ -1,60 +1,48 @@
|
||||
"""Test the Home Assistant local auth provider."""
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
|
||||
|
||||
MOCK_PATH = '/bla/users.json'
|
||||
JSON__OPEN_PATH = 'homeassistant.util.json.open'
|
||||
@pytest.fixture
|
||||
def data(hass):
|
||||
"""Create a loaded data class."""
|
||||
data = hass_auth.Data(hass)
|
||||
hass.loop.run_until_complete(data.async_load())
|
||||
return data
|
||||
|
||||
|
||||
def test_initialize_empty_config_file_not_found():
|
||||
"""Test that we initialize an empty config."""
|
||||
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
|
||||
data = hass_auth.load_data(MOCK_PATH)
|
||||
|
||||
assert data is not None
|
||||
|
||||
|
||||
def test_adding_user():
|
||||
async def test_adding_user(data, hass):
|
||||
"""Test adding a user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
|
||||
|
||||
def test_adding_user_duplicate_username():
|
||||
async def test_adding_user_duplicate_username(data, hass):
|
||||
"""Test adding a user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
with pytest.raises(hass_auth.InvalidUser):
|
||||
data.add_user('test-user', 'other-pass')
|
||||
|
||||
|
||||
def test_validating_password_invalid_user():
|
||||
async def test_validating_password_invalid_user(data, hass):
|
||||
"""Test validating an invalid user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
data.validate_login('non-existing', 'pw')
|
||||
|
||||
|
||||
def test_validating_password_invalid_password():
|
||||
async def test_validating_password_invalid_password(data, hass):
|
||||
"""Test validating an invalid user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
data.validate_login('test-user', 'invalid-pass')
|
||||
|
||||
|
||||
def test_changing_password():
|
||||
async def test_changing_password(data, hass):
|
||||
"""Test adding a user."""
|
||||
user = 'test-user'
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user(user, 'test-pass')
|
||||
data.change_password(user, 'new-pass')
|
||||
|
||||
@ -64,61 +52,50 @@ def test_changing_password():
|
||||
data.validate_login(user, 'new-pass')
|
||||
|
||||
|
||||
def test_changing_password_raises_invalid_user():
|
||||
async def test_changing_password_raises_invalid_user(data, hass):
|
||||
"""Test that we initialize an empty config."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
|
||||
with pytest.raises(hass_auth.InvalidUser):
|
||||
data.change_password('non-existing', 'pw')
|
||||
|
||||
|
||||
async def test_login_flow_validates(hass):
|
||||
async def test_login_flow_validates(data, hass):
|
||||
"""Test login flow."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
await data.async_save()
|
||||
|
||||
provider = hass_auth.HassAuthProvider(hass, None, {})
|
||||
flow = hass_auth.LoginFlow(provider)
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
||||
with patch.object(provider, '_auth_data', return_value=data):
|
||||
result = await flow.async_step_init({
|
||||
'username': 'incorrect-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['errors']['base'] == 'invalid_auth'
|
||||
result = await flow.async_step_init({
|
||||
'username': 'incorrect-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['errors']['base'] == 'invalid_auth'
|
||||
|
||||
result = await flow.async_step_init({
|
||||
'username': 'test-user',
|
||||
'password': 'incorrect-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['errors']['base'] == 'invalid_auth'
|
||||
result = await flow.async_step_init({
|
||||
'username': 'test-user',
|
||||
'password': 'incorrect-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['errors']['base'] == 'invalid_auth'
|
||||
|
||||
result = await flow.async_step_init({
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
result = await flow.async_step_init({
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_saving_loading(hass):
|
||||
async def test_saving_loading(data, hass):
|
||||
"""Test saving and loading JSON."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_user('second-user', 'second-pass')
|
||||
await data.async_save()
|
||||
|
||||
with patch(JSON__OPEN_PATH, mock_open(), create=True) as mock_write:
|
||||
await hass.async_add_job(data.save)
|
||||
|
||||
# Mock open calls are: open file, context enter, write, context leave
|
||||
written = mock_write.mock_calls[2][1][0]
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch(JSON__OPEN_PATH, mock_open(read_data=written), create=True):
|
||||
await hass.async_add_job(hass_auth.load_data, MOCK_PATH)
|
||||
|
||||
data = hass_auth.Data(hass)
|
||||
await data.async_load()
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
data.validate_login('second-user', 'second-pass')
|
||||
|
@ -6,16 +6,21 @@ import pytest
|
||||
from homeassistant.scripts import auth as script_auth
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
|
||||
MOCK_PATH = '/bla/users.json'
|
||||
|
||||
@pytest.fixture
|
||||
def data(hass):
|
||||
"""Create a loaded data class."""
|
||||
data = hass_auth.Data(hass)
|
||||
hass.loop.run_until_complete(data.async_load())
|
||||
return data
|
||||
|
||||
|
||||
def test_list_user(capsys):
|
||||
async def test_list_user(data, capsys):
|
||||
"""Test we can list users."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
data.add_user('second-user', 'second-pass')
|
||||
|
||||
script_auth.list_users(data, None)
|
||||
await script_auth.list_users(data, None)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
|
||||
@ -28,15 +33,12 @@ def test_list_user(capsys):
|
||||
])
|
||||
|
||||
|
||||
def test_add_user(capsys):
|
||||
async def test_add_user(data, capsys, hass_storage):
|
||||
"""Test we can add a user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
await script_auth.add_user(
|
||||
data, Mock(username='paulus', password='test-pass'))
|
||||
|
||||
with patch.object(data, 'save') as mock_save:
|
||||
script_auth.add_user(
|
||||
data, Mock(username='paulus', password='test-pass'))
|
||||
|
||||
assert len(mock_save.mock_calls) == 1
|
||||
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'User created\n'
|
||||
@ -45,37 +47,34 @@ def test_add_user(capsys):
|
||||
data.validate_login('paulus', 'test-pass')
|
||||
|
||||
|
||||
def test_validate_login(capsys):
|
||||
async def test_validate_login(data, capsys):
|
||||
"""Test we can validate a user login."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
|
||||
script_auth.validate_login(
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='test-user', password='test-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth valid\n'
|
||||
|
||||
script_auth.validate_login(
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='test-user', password='invalid-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth invalid\n'
|
||||
|
||||
script_auth.validate_login(
|
||||
await script_auth.validate_login(
|
||||
data, Mock(username='invalid-user', password='test-pass'))
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Auth invalid\n'
|
||||
|
||||
|
||||
def test_change_password(capsys):
|
||||
async def test_change_password(data, capsys, hass_storage):
|
||||
"""Test we can change a password."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
|
||||
with patch.object(data, 'save') as mock_save:
|
||||
script_auth.change_password(
|
||||
data, Mock(username='test-user', new_password='new-pass'))
|
||||
await script_auth.change_password(
|
||||
data, Mock(username='test-user', new_password='new-pass'))
|
||||
|
||||
assert len(mock_save.mock_calls) == 1
|
||||
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'Password changed\n'
|
||||
data.validate_login('test-user', 'new-pass')
|
||||
@ -83,18 +82,35 @@ def test_change_password(capsys):
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
|
||||
|
||||
def test_change_password_invalid_user(capsys):
|
||||
async def test_change_password_invalid_user(data, capsys, hass_storage):
|
||||
"""Test changing password of non-existing user."""
|
||||
data = hass_auth.Data(MOCK_PATH, None)
|
||||
data.add_user('test-user', 'test-pass')
|
||||
|
||||
with patch.object(data, 'save') as mock_save:
|
||||
script_auth.change_password(
|
||||
data, Mock(username='invalid-user', new_password='new-pass'))
|
||||
await script_auth.change_password(
|
||||
data, Mock(username='invalid-user', new_password='new-pass'))
|
||||
|
||||
assert len(mock_save.mock_calls) == 0
|
||||
assert hass_auth.STORAGE_KEY not in hass_storage
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == 'User not found\n'
|
||||
data.validate_login('test-user', 'test-pass')
|
||||
with pytest.raises(hass_auth.InvalidAuth):
|
||||
data.validate_login('invalid-user', 'new-pass')
|
||||
|
||||
|
||||
def test_parsing_args(loop):
|
||||
"""Test we parse args correctly."""
|
||||
called = False
|
||||
|
||||
async def mock_func(data, args2):
|
||||
"""Mock function to be called."""
|
||||
nonlocal called
|
||||
called = True
|
||||
assert data.hass.config.config_dir == '/somewhere/config'
|
||||
assert args2 is args
|
||||
|
||||
args = Mock(config='/somewhere/config', func=mock_func)
|
||||
|
||||
with patch('argparse.ArgumentParser.parse_args', return_value=args):
|
||||
script_auth.run(None)
|
||||
|
||||
assert called, 'Mock function did not get called'
|
||||
|
Loading…
x
Reference in New Issue
Block a user