User management (#15420)

* User management

* Lint

* Fix dict

* Reuse data instance

* OrderedDict all the way
This commit is contained in:
Paulus Schoutsen 2018-07-13 15:31:20 +02:00 committed by GitHub
parent 84858f5c19
commit 70fe463ef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 982 additions and 116 deletions

View File

@ -51,7 +51,7 @@ class AuthManager:
self.login_flow = data_entry_flow.FlowManager( self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow, hass, self._async_create_login_flow,
self._async_finish_login_flow) self._async_finish_login_flow)
self._access_tokens = {} self._access_tokens = OrderedDict()
@property @property
def active(self): def active(self):
@ -71,9 +71,13 @@ class AuthManager:
return False return False
@property @property
def async_auth_providers(self): def auth_providers(self):
"""Return a list of available auth providers.""" """Return a list of available auth providers."""
return self._providers.values() return list(self._providers.values())
async def async_get_users(self):
"""Retrieve all users."""
return await self._store.async_get_users()
async def async_get_user(self, user_id): async def async_get_user(self, user_id):
"""Retrieve a user.""" """Retrieve a user."""
@ -87,6 +91,13 @@ class AuthManager:
is_active=True, is_active=True,
) )
async def async_create_user(self, name):
"""Create a user."""
return await self._store.async_create_user(
name=name,
is_active=True,
)
async def async_get_or_create_user(self, credentials): async def async_get_or_create_user(self, credentials):
"""Get or create a user.""" """Get or create a user."""
if not credentials.is_new: if not credentials.is_new:
@ -98,6 +109,10 @@ class AuthManager:
raise ValueError('Unable to find the user.') raise ValueError('Unable to find the user.')
auth_provider = self._async_get_auth_provider(credentials) auth_provider = self._async_get_auth_provider(credentials)
if auth_provider is None:
raise RuntimeError('Credential with unknown provider encountered')
info = await auth_provider.async_user_meta_for_credentials( info = await auth_provider.async_user_meta_for_credentials(
credentials) credentials)
@ -122,8 +137,26 @@ class AuthManager:
async def async_remove_user(self, user): async def async_remove_user(self, user):
"""Remove a user.""" """Remove a user."""
tasks = [
self.async_remove_credentials(credentials)
for credentials in user.credentials
]
if tasks:
await asyncio.wait(tasks)
await self._store.async_remove_user(user) await self._store.async_remove_user(user)
async def async_remove_credentials(self, credentials):
"""Remove credentials."""
provider = self._async_get_auth_provider(credentials)
if (provider is not None and
hasattr(provider, 'async_will_remove_credentials')):
await provider.async_will_remove_credentials(credentials)
await self._store.async_remove_credentials(credentials)
async def async_create_refresh_token(self, user, client_id=None): async def async_create_refresh_token(self, user, client_id=None):
"""Create a new refresh token for a user.""" """Create a new refresh token for a user."""
if not user.is_active: if not user.is_active:
@ -168,10 +201,6 @@ class AuthManager:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self._providers[handler] auth_provider = self._providers[handler]
if not auth_provider.initialized:
auth_provider.initialized = True
await auth_provider.async_initialize()
return await auth_provider.async_credential_flow() return await auth_provider.async_credential_flow()
async def _async_finish_login_flow(self, result): async def _async_finish_login_flow(self, result):
@ -188,4 +217,4 @@ class AuthManager:
"""Helper to get auth provider from a set of credentials.""" """Helper to get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type, auth_provider_key = (credentials.auth_provider_type,
credentials.auth_provider_id) credentials.auth_provider_id)
return self._providers[auth_provider_key] return self._providers.get(auth_provider_key)

View File

@ -1,4 +1,5 @@
"""Storage for auth models.""" """Storage for auth models."""
from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -80,6 +81,22 @@ class AuthStore:
self._users.pop(user.id) self._users.pop(user.id)
await self.async_save() await self.async_save()
async def async_remove_credentials(self, credentials):
"""Remove credentials."""
for user in self._users.values():
found = None
for index, cred in enumerate(user.credentials):
if cred is credentials:
found = index
break
if found is not None:
user.credentials.pop(found)
break
await self.async_save()
async def async_create_refresh_token(self, user, client_id=None): async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user.""" """Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id) refresh_token = models.RefreshToken(user=user, client_id=client_id)
@ -108,14 +125,14 @@ class AuthStore:
if self._users is not None: if self._users is not None:
return return
users = OrderedDict()
if data is None: if data is None:
self._users = {} self._users = users
return return
users = { for user_dict in data['users']:
user_dict['id']: models.User(**user_dict) users[user_dict['id']] = models.User(**user_dict)
for user_dict in data['users']
}
for cred_dict in data['credentials']: for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(models.Credentials( users[cred_dict['user_id']].credentials.append(models.Credentials(
@ -126,7 +143,7 @@ class AuthStore:
data=cred_dict['data'], data=cred_dict['data'],
)) ))
refresh_tokens = {} refresh_tokens = OrderedDict()
for rt_dict in data['refresh_tokens']: for rt_dict in data['refresh_tokens']:
token = models.RefreshToken( token = models.RefreshToken(

View File

@ -77,8 +77,6 @@ class AuthProvider:
DEFAULT_TITLE = 'Unnamed auth provider' DEFAULT_TITLE = 'Unnamed auth provider'
initialized = False
def __init__(self, hass, store, config): def __init__(self, hass, store, config):
"""Initialize an auth provider.""" """Initialize an auth provider."""
self.hass = hass self.hass = hass
@ -125,12 +123,6 @@ class AuthProvider:
# Implement by extending class # Implement by extending class
async def async_initialize(self):
"""Initialize the auth provider.
Optional.
"""
async def async_credential_flow(self): async def async_credential_flow(self):
"""Return the data flow for logging in with auth provider.""" """Return the data flow for logging in with auth provider."""
raise NotImplementedError raise NotImplementedError

View File

@ -7,6 +7,8 @@ import hmac
import voluptuous as vol import voluptuous as vol
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.const import CONF_ID
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.auth.util import generate_secret from homeassistant.auth.util import generate_secret
@ -16,8 +18,17 @@ from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant' STORAGE_KEY = 'auth_provider.homeassistant'
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA) def _disallow_id(conf):
"""Disallow ID in config."""
if CONF_ID in conf:
raise vol.Invalid(
'ID is not allowed for the homeassistant auth provider.')
return conf
CONFIG_SCHEMA = vol.All(AUTH_PROVIDER_SCHEMA, _disallow_id)
class InvalidAuth(HomeAssistantError): class InvalidAuth(HomeAssistantError):
@ -88,8 +99,8 @@ class Data:
hashed = base64.b64encode(hashed).decode() hashed = base64.b64encode(hashed).decode()
return hashed return hashed
def add_user(self, username, password): def add_auth(self, username, password):
"""Add a user.""" """Add a new authenticated user/pass."""
if any(user['username'] == username for user in self.users): if any(user['username'] == username for user in self.users):
raise InvalidUser raise InvalidUser
@ -98,8 +109,22 @@ class Data:
'password': self.hash_password(password, True), 'password': self.hash_password(password, True),
}) })
@callback
def async_remove_auth(self, username):
"""Remove authentication."""
index = None
for i, user in enumerate(self.users):
if user['username'] == username:
index = i
break
if index is None:
raise InvalidUser
self.users.pop(index)
def change_password(self, username, new_password): def change_password(self, username, new_password):
"""Update the password of a user. """Update the password.
Raises InvalidUser if user cannot be found. Raises InvalidUser if user cannot be found.
""" """
@ -121,16 +146,24 @@ class HassAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Home Assistant Local' DEFAULT_TITLE = 'Home Assistant Local'
data = None
async def async_initialize(self):
"""Initialize the auth provider."""
self.data = Data(self.hass)
await self.data.async_load()
async def async_credential_flow(self): async def async_credential_flow(self):
"""Return a flow to login.""" """Return a flow to login."""
return LoginFlow(self) return LoginFlow(self)
async def async_validate_login(self, username, password): async def async_validate_login(self, username, password):
"""Helper to validate a username and password.""" """Helper to validate a username and password."""
data = Data(self.hass) if self.data is None:
await data.async_load() await self.async_initialize()
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
data.validate_login, username, password) self.data.validate_login, username, password)
async def async_get_or_create_credentials(self, flow_result): async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
@ -145,6 +178,24 @@ class HassAuthProvider(AuthProvider):
'username': username 'username': username
}) })
async def async_user_meta_for_credentials(self, credentials):
"""Get extra info for this credential."""
return {
'name': credentials.data['username']
}
async def async_will_remove_credentials(self, credentials):
"""When credentials get removed, also remove the auth."""
if self.data is None:
await self.async_initialize()
try:
self.data.async_remove_auth(credentials.data['username'])
await self.data.async_save()
except InvalidUser:
# Can happen if somehow we didn't clean up a credential
pass
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow.""" """Handler for the login flow."""

View File

@ -152,7 +152,7 @@ class AuthProvidersView(HomeAssistantView):
'name': provider.name, 'name': provider.name,
'id': provider.id, 'id': provider.id,
'type': provider.type, 'type': provider.type,
} for provider in request.app['hass'].auth.async_auth_providers]) } for provider in request.app['hass'].auth.auth_providers])
class LoginFlowIndexView(FlowManagerIndexView): class LoginFlowIndexView(FlowManagerIndexView):

View File

@ -66,8 +66,8 @@ CAMERA_SERVICE_SNAPSHOT = CAMERA_SERVICE_SCHEMA.extend({
WS_TYPE_CAMERA_THUMBNAIL = 'camera_thumbnail' WS_TYPE_CAMERA_THUMBNAIL = 'camera_thumbnail'
SCHEMA_WS_CAMERA_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ SCHEMA_WS_CAMERA_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
'type': WS_TYPE_CAMERA_THUMBNAIL, vol.Required('type'): WS_TYPE_CAMERA_THUMBNAIL,
'entity_id': cv.entity_id vol.Required('entity_id'): cv.entity_id
}) })

View File

@ -49,6 +49,10 @@ async def async_setup(hass, config):
tasks = [setup_panel(panel_name) for panel_name in SECTIONS] tasks = [setup_panel(panel_name) for panel_name in SECTIONS]
if hass.auth.active:
tasks.append(setup_panel('auth'))
tasks.append(setup_panel('auth_provider_homeassistant'))
for panel_name in ON_DEMAND: for panel_name in ON_DEMAND:
if panel_name in hass.config.components: if panel_name in hass.config.components:
tasks.append(setup_panel(panel_name)) tasks.append(setup_panel(panel_name))

View File

@ -0,0 +1,113 @@
"""Offer API to configure Home Assistant auth."""
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components import websocket_api
WS_TYPE_LIST = 'config/auth/list'
SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_LIST,
})
WS_TYPE_DELETE = 'config/auth/delete'
SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_DELETE,
vol.Required('user_id'): str,
})
WS_TYPE_CREATE = 'config/auth/create'
SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_CREATE,
vol.Required('name'): str,
})
async def async_setup(hass):
"""Enable the Home Assistant views."""
hass.components.websocket_api.async_register_command(
WS_TYPE_LIST, websocket_list,
SCHEMA_WS_LIST
)
hass.components.websocket_api.async_register_command(
WS_TYPE_DELETE, websocket_delete,
SCHEMA_WS_DELETE
)
hass.components.websocket_api.async_register_command(
WS_TYPE_CREATE, websocket_create,
SCHEMA_WS_CREATE
)
return True
@callback
@websocket_api.require_owner
def websocket_list(hass, connection, msg):
"""Return a list of users."""
async def send_users():
"""Send users."""
result = [_user_info(u) for u in await hass.auth.async_get_users()]
connection.send_message_outside(
websocket_api.result_message(msg['id'], result))
hass.async_add_job(send_users())
@callback
@websocket_api.require_owner
def websocket_delete(hass, connection, msg):
"""Delete a user."""
async def delete_user():
"""Delete user."""
if msg['user_id'] == connection.request.get('hass_user').id:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'no_delete_self',
'Unable to delete your own account'))
return
user = await hass.auth.async_get_user(msg['user_id'])
if not user:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'not_found', 'User not found'))
return
await hass.auth.async_remove_user(user)
connection.send_message_outside(
websocket_api.result_message(msg['id']))
hass.async_add_job(delete_user())
@callback
@websocket_api.require_owner
def websocket_create(hass, connection, msg):
"""Create a user."""
async def create_user():
"""Create a user."""
user = await hass.auth.async_create_user(msg['name'])
connection.send_message_outside(
websocket_api.result_message(msg['id'], {
'user': _user_info(user)
}))
hass.async_add_job(create_user())
def _user_info(user):
"""Format a user."""
return {
'id': user.id,
'name': user.name,
'is_owner': user.is_owner,
'is_active': user.is_active,
'system_generated': user.system_generated,
'credentials': [
{
'type': c.auth_provider_type,
} for c in user.credentials
]
}

View File

@ -0,0 +1,120 @@
"""Offer API to configure the Home Assistant auth provider."""
import voluptuous as vol
from homeassistant.auth.providers import homeassistant as auth_ha
from homeassistant.core import callback
from homeassistant.components import websocket_api
WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create'
SCHEMA_WS_CREATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_CREATE,
vol.Required('user_id'): str,
vol.Required('username'): str,
vol.Required('password'): str,
})
WS_TYPE_DELETE = 'config/auth_provider/homeassistant/delete'
SCHEMA_WS_DELETE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_DELETE,
vol.Required('username'): str,
})
async def async_setup(hass):
"""Enable the Home Assistant views."""
hass.components.websocket_api.async_register_command(
WS_TYPE_CREATE, websocket_create,
SCHEMA_WS_CREATE
)
hass.components.websocket_api.async_register_command(
WS_TYPE_DELETE, websocket_delete,
SCHEMA_WS_DELETE
)
return True
def _get_provider(hass):
"""Get homeassistant auth provider."""
for prv in hass.auth.auth_providers:
if prv.type == 'homeassistant':
return prv
raise RuntimeError('Provider not found')
@callback
@websocket_api.require_owner
def websocket_create(hass, connection, msg):
"""Create credentials and attach to a user."""
async def create_creds():
"""Create credentials."""
provider = _get_provider(hass)
await provider.async_initialize()
user = await hass.auth.async_get_user(msg['user_id'])
if user is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'not_found', 'User not found'))
return
if user.system_generated:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'system_generated',
'Cannot add credentials to a system generated user.'))
return
try:
await hass.async_add_executor_job(
provider.data.add_auth, msg['username'], msg['password'])
except auth_ha.InvalidUser:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'username_exists', 'Username already exists'))
return
credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
await hass.auth.async_link_user(user, credentials)
await provider.data.async_save()
connection.to_write.put_nowait(websocket_api.result_message(msg['id']))
hass.async_add_job(create_creds())
@callback
@websocket_api.require_owner
def websocket_delete(hass, connection, msg):
"""Delete username and related credential."""
async def delete_creds():
"""Delete user credentials."""
provider = _get_provider(hass)
await provider.async_initialize()
credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
# if not new, an existing credential exists.
# Removing the credential will also remove the auth.
if not credentials.is_new:
await hass.auth.async_remove_credentials(credentials)
connection.to_write.put_nowait(
websocket_api.result_message(msg['id']))
return
try:
provider.data.async_remove_auth(msg['username'])
await provider.data.async_save()
except auth_ha.InvalidUser:
connection.to_write.put_nowait(websocket_api.error_message(
msg['id'], 'auth_not_found', 'Given username was not found.'))
return
connection.to_write.put_nowait(
websocket_api.result_message(msg['id']))
hass.async_add_job(delete_creds())

View File

@ -106,6 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
if access_token is None: if access_token is None:
return False return False
user = access_token.refresh_token.user
if not user.is_active:
return False
request['hass_user'] = access_token.refresh_token.user request['hass_user'] = access_token.refresh_token.user
return True return True

View File

@ -7,7 +7,7 @@ https://home-assistant.io/developers/websocket_api/
import asyncio import asyncio
from concurrent import futures from concurrent import futures
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial, wraps
import json import json
import logging import logging
@ -196,6 +196,23 @@ def async_register_command(hass, command, handler, schema):
handlers[command] = (handler, schema) handlers[command] = (handler, schema)
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
async def async_setup(hass, config): async def async_setup(hass, config):
"""Initialize the websocket API.""" """Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView) hass.http.register_view(WebsocketAPIView)
@ -325,6 +342,8 @@ class ActiveConnection:
token = self.hass.auth.async_get_access_token( token = self.hass.auth.async_get_access_token(
msg['access_token']) msg['access_token'])
authenticated = token is not None authenticated = token is not None
if authenticated:
request['hass_user'] = token.refresh_token.user
elif ((not self.hass.auth.active or elif ((not self.hass.auth.active or
self.hass.auth.support_legacy) and self.hass.auth.support_legacy) and

View File

@ -1,8 +1,10 @@
"""Script to manage users for the Home Assistant auth provider.""" """Script to manage users for the Home Assistant auth provider."""
import argparse import argparse
import asyncio import asyncio
import logging
import os import os
from homeassistant.auth import auth_manager_from_config
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
from homeassistant.auth.providers import homeassistant as hass_auth from homeassistant.auth.providers import homeassistant as hass_auth
@ -42,16 +44,28 @@ def run(args):
args = parser.parse_args(args) args = parser.parse_args(args)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
hass = HomeAssistant(loop=loop) hass = HomeAssistant(loop=loop)
loop.run_until_complete(run_command(hass, args))
# Triggers save on used storage helpers with delay (core auth)
logging.getLogger('homeassistant.core').setLevel(logging.WARNING)
loop.run_until_complete(hass.async_stop())
async def run_command(hass, args):
"""Run the command."""
hass.config.config_dir = os.path.join(os.getcwd(), args.config) hass.config.config_dir = os.path.join(os.getcwd(), args.config)
data = hass_auth.Data(hass) hass.auth = await auth_manager_from_config(hass, [{
loop.run_until_complete(data.async_load()) 'type': 'homeassistant',
loop.run_until_complete(args.func(data, args)) }])
provider = hass.auth.auth_providers[0]
await provider.async_initialize()
await args.func(hass, provider, args)
async def list_users(data, args): async def list_users(hass, provider, args):
"""List the users.""" """List the users."""
count = 0 count = 0
for user in data.users: for user in provider.data.users:
count += 1 count += 1
print(user['username']) print(user['username'])
@ -59,27 +73,40 @@ async def list_users(data, args):
print("Total users:", count) print("Total users:", count)
async def add_user(data, args): async def add_user(hass, provider, args):
"""Create a user.""" """Create a user."""
data.add_user(args.username, args.password) try:
await data.async_save() provider.data.add_auth(args.username, args.password)
except hass_auth.InvalidUser:
print("Username already exists!")
return
credentials = await provider.async_get_or_create_credentials({
'username': args.username
})
user = await hass.auth.async_create_user(args.username)
await hass.auth.async_link_user(user, credentials)
# Save username/password
await provider.data.async_save()
print("User created") print("User created")
async def validate_login(data, args): async def validate_login(hass, provider, args):
"""Validate a login.""" """Validate a login."""
try: try:
data.validate_login(args.username, args.password) provider.data.validate_login(args.username, args.password)
print("Auth valid") print("Auth valid")
except hass_auth.InvalidAuth: except hass_auth.InvalidAuth:
print("Auth invalid") print("Auth invalid")
async def change_password(data, args): async def change_password(hass, provider, args):
"""Change password.""" """Change password."""
try: try:
data.change_password(args.username, args.new_password) provider.data.change_password(args.username, args.new_password)
await data.async_save() await provider.data.async_save()
print("Password changed") print("Password changed")
except hass_auth.InvalidUser: except hass_auth.InvalidUser:
print("User not found") print("User not found")

View File

@ -1,8 +1,11 @@
"""Test the Home Assistant local auth provider.""" """Test the Home Assistant local auth provider."""
from unittest.mock import Mock
import pytest import pytest
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth.providers import homeassistant as hass_auth from homeassistant.auth.providers import (
auth_provider_from_config, homeassistant as hass_auth)
@pytest.fixture @pytest.fixture
@ -15,15 +18,15 @@ def data(hass):
async def test_adding_user(data, hass): async def test_adding_user(data, hass):
"""Test adding a user.""" """Test adding a user."""
data.add_user('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
async def test_adding_user_duplicate_username(data, hass): async def test_adding_user_duplicate_username(data, hass):
"""Test adding a user.""" """Test adding a user."""
data.add_user('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidUser): with pytest.raises(hass_auth.InvalidUser):
data.add_user('test-user', 'other-pass') data.add_auth('test-user', 'other-pass')
async def test_validating_password_invalid_user(data, hass): async def test_validating_password_invalid_user(data, hass):
@ -34,7 +37,7 @@ async def test_validating_password_invalid_user(data, hass):
async def test_validating_password_invalid_password(data, hass): async def test_validating_password_invalid_password(data, hass):
"""Test validating an invalid user.""" """Test validating an invalid user."""
data.add_user('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
data.validate_login('test-user', 'invalid-pass') data.validate_login('test-user', 'invalid-pass')
@ -43,7 +46,7 @@ async def test_validating_password_invalid_password(data, hass):
async def test_changing_password(data, hass): async def test_changing_password(data, hass):
"""Test adding a user.""" """Test adding a user."""
user = 'test-user' user = 'test-user'
data.add_user(user, 'test-pass') data.add_auth(user, 'test-pass')
data.change_password(user, 'new-pass') data.change_password(user, 'new-pass')
with pytest.raises(hass_auth.InvalidAuth): with pytest.raises(hass_auth.InvalidAuth):
@ -60,7 +63,7 @@ async def test_changing_password_raises_invalid_user(data, hass):
async def test_login_flow_validates(data, hass): async def test_login_flow_validates(data, hass):
"""Test login flow.""" """Test login flow."""
data.add_user('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
await data.async_save() await data.async_save()
provider = hass_auth.HassAuthProvider(hass, None, {}) provider = hass_auth.HassAuthProvider(hass, None, {})
@ -91,11 +94,21 @@ async def test_login_flow_validates(data, hass):
async def test_saving_loading(data, hass): async def test_saving_loading(data, hass):
"""Test saving and loading JSON.""" """Test saving and loading JSON."""
data.add_user('test-user', 'test-pass') data.add_auth('test-user', 'test-pass')
data.add_user('second-user', 'second-pass') data.add_auth('second-user', 'second-pass')
await data.async_save() await data.async_save()
data = hass_auth.Data(hass) data = hass_auth.Data(hass)
await data.async_load() await data.async_load()
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
data.validate_login('second-user', 'second-pass') data.validate_login('second-user', 'second-pass')
async def test_not_allow_set_id():
"""Test we are not allowed to set an ID in config."""
hass = Mock()
provider = await auth_provider_from_config(hass, None, {
'type': 'homeassistant',
'id': 'invalid',
})
assert provider is None

View File

@ -46,7 +46,7 @@ async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
'name': provider.name, 'name': provider.name,
'id': provider.id, 'id': provider.id,
'type': provider.type, 'type': provider.type,
} for provider in manager.async_auth_providers] } for provider in manager.auth_providers]
assert providers == [{ assert providers == [{
'name': 'Test Name', 'name': 'Test Name',
'type': 'insecure_example', 'type': 'insecure_example',

View File

@ -1,5 +1,6 @@
"""Test the helper method for writing tests.""" """Test the helper method for writing tests."""
import asyncio import asyncio
from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
import functools as ft import functools as ft
import json import json
@ -12,7 +13,8 @@ import threading
from contextlib import contextmanager from contextlib import contextmanager
from homeassistant import auth, core as ha, data_entry_flow, config_entries from homeassistant import auth, core as ha, data_entry_flow, config_entries
from homeassistant.auth import models as auth_models, auth_store from homeassistant.auth import (
models as auth_models, auth_store, providers as auth_providers)
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -312,11 +314,12 @@ def mock_registry(hass, mock_entries=None):
class MockUser(auth_models.User): class MockUser(auth_models.User):
"""Mock a user in Home Assistant.""" """Mock a user in Home Assistant."""
def __init__(self, id='mock-id', is_owner=True, is_active=True, def __init__(self, id='mock-id', is_owner=False, is_active=True,
name='Mock User'): name='Mock User', system_generated=False):
"""Initialize mock user.""" """Initialize mock user."""
super().__init__( super().__init__(
id=id, is_owner=is_owner, is_active=is_active, name=name) id=id, is_owner=is_owner, is_active=is_active, name=name,
system_generated=system_generated)
def add_to_hass(self, hass): def add_to_hass(self, hass):
"""Test helper to add entry to hass.""" """Test helper to add entry to hass."""
@ -329,12 +332,27 @@ class MockUser(auth_models.User):
return self return self
async def register_auth_provider(hass, config):
"""Helper to register an auth provider."""
provider = await auth_providers.auth_provider_from_config(
hass, hass.auth._store, config)
assert provider is not None, 'Invalid config specified'
key = (provider.type, provider.id)
providers = hass.auth._providers
if key in providers:
raise ValueError('Provider already registered')
providers[key] = provider
return provider
@ha.callback @ha.callback
def ensure_auth_manager_loaded(auth_mgr): def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded.""" """Ensure an auth manager is considered loaded."""
store = auth_mgr._store store = auth_mgr._store
if store._users is None: if store._users is None:
store._users = {} store._users = OrderedDict()
class MockModule(object): class MockModule(object):
@ -731,7 +749,13 @@ def mock_storage(data=None):
if store.key not in data: if store.key not in data:
return None return None
store._data = data.get(store.key) mock_data = data.get(store.key)
if 'data' not in mock_data or 'version' not in mock_data:
_LOGGER.error('Mock data needs "version" and "data"')
raise ValueError('Mock data needs "version" and "data"')
store._data = mock_data
# Route through original load so that we trigger migration # Route through original load so that we trigger migration
loaded = await orig_load(store) loaded = await orig_load(store)

View File

@ -0,0 +1,211 @@
"""Test config entries API."""
from unittest.mock import PropertyMock, patch
import pytest
from homeassistant.auth import models as auth_models
from homeassistant.components.config import auth as auth_config
from tests.common import MockUser, CLIENT_ID
@pytest.fixture(autouse=True)
def auth_active(hass):
"""Mock that auth is active."""
with patch('homeassistant.auth.AuthManager.active',
PropertyMock(return_value=True)):
yield
@pytest.fixture(autouse=True)
def setup_config(hass, aiohttp_client):
"""Fixture that sets up the auth provider homeassistant module."""
hass.loop.run_until_complete(auth_config.async_setup(hass))
async def test_list_requires_owner(hass, hass_ws_client, hass_access_token):
"""Test get users requires auth."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_LIST,
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'
async def test_list(hass, hass_ws_client):
"""Test get users."""
owner = MockUser(
id='abc',
name='Test Owner',
is_owner=True,
).add_to_hass(hass)
owner.credentials.append(auth_models.Credentials(
auth_provider_type='homeassistant',
auth_provider_id=None,
data={},
))
system = MockUser(
id='efg',
name='Test Hass.io',
system_generated=True
).add_to_hass(hass)
inactive = MockUser(
id='hij',
name='Inactive User',
is_active=False,
).add_to_hass(hass)
refresh_token = await hass.auth.async_create_refresh_token(
owner, CLIENT_ID)
access_token = hass.auth.async_create_access_token(refresh_token)
client = await hass_ws_client(hass, access_token)
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_LIST,
})
result = await client.receive_json()
assert result['success'], result
data = result['result']
assert len(data) == 3
assert data[0] == {
'id': owner.id,
'name': 'Test Owner',
'is_owner': True,
'is_active': True,
'system_generated': False,
'credentials': [{'type': 'homeassistant'}]
}
assert data[1] == {
'id': system.id,
'name': 'Test Hass.io',
'is_owner': False,
'is_active': True,
'system_generated': True,
'credentials': [],
}
assert data[2] == {
'id': inactive.id,
'name': 'Inactive User',
'is_owner': False,
'is_active': False,
'system_generated': False,
'credentials': [],
}
async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
"""Test delete command requires an owner."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_DELETE,
'user_id': 'abcd',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'
async def test_delete_unable_self_account(hass, hass_ws_client,
hass_access_token):
"""Test we cannot delete our own account."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_DELETE,
'user_id': hass_access_token.refresh_token.user.id,
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
"""Test we cannot delete an unknown user."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_DELETE,
'user_id': 'abcd',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'not_found'
async def test_delete(hass, hass_ws_client, hass_access_token):
"""Test delete command works."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
test_user = MockUser(
id='efg',
).add_to_hass(hass)
assert len(await hass.auth.async_get_users()) == 2
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_DELETE,
'user_id': test_user.id,
})
result = await client.receive_json()
assert result['success'], result
assert len(await hass.auth.async_get_users()) == 1
async def test_create(hass, hass_ws_client, hass_access_token):
"""Test create command works."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
assert len(await hass.auth.async_get_users()) == 1
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_CREATE,
'name': 'Paulus',
})
result = await client.receive_json()
assert result['success'], result
assert len(await hass.auth.async_get_users()) == 2
data_user = result['result']['user']
user = await hass.auth.async_get_user(data_user['id'])
assert user is not None
assert user.name == data_user['name']
assert user.is_active
assert not user.is_owner
assert not user.system_generated
async def test_create_requires_owner(hass, hass_ws_client, hass_access_token):
"""Test create command requires an owner."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_config.WS_TYPE_CREATE,
'name': 'YO',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'

View File

@ -0,0 +1,229 @@
"""Test config entries API."""
import pytest
from homeassistant.auth.providers import homeassistant as prov_ha
from homeassistant.components.config import (
auth_provider_homeassistant as auth_ha)
from tests.common import MockUser, register_auth_provider
@pytest.fixture(autouse=True)
def setup_config(hass, aiohttp_client):
"""Fixture that sets up the auth provider homeassistant module."""
hass.loop.run_until_complete(register_auth_provider(hass, {
'type': 'homeassistant'
}))
hass.loop.run_until_complete(auth_ha.async_setup(hass))
async def test_create_auth_system_generated_user(hass, hass_access_token,
hass_ws_client):
"""Test we can't add auth to system generated users."""
system_user = MockUser(system_generated=True).add_to_hass(hass)
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_CREATE,
'user_id': system_user.id,
'username': 'test-user',
'password': 'test-pass',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'system_generated'
async def test_create_auth_user_already_credentials():
"""Test we can't create auth for user with pre-existing credentials."""
# assert False
async def test_create_auth_unknown_user(hass_ws_client, hass,
hass_access_token):
"""Test create pointing at unknown user."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_CREATE,
'user_id': 'test-id',
'username': 'test-user',
'password': 'test-pass',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'not_found'
async def test_create_auth_requires_owner(hass, hass_ws_client,
hass_access_token):
"""Test create requires owner to call API."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_CREATE,
'user_id': 'test-id',
'username': 'test-user',
'password': 'test-pass',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'
async def test_create_auth(hass, hass_ws_client, hass_access_token,
hass_storage):
"""Test create auth command works."""
client = await hass_ws_client(hass, hass_access_token)
user = MockUser().add_to_hass(hass)
hass_access_token.refresh_token.user.is_owner = True
assert len(user.credentials) == 0
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_CREATE,
'user_id': user.id,
'username': 'test-user',
'password': 'test-pass',
})
result = await client.receive_json()
assert result['success'], result
assert len(user.credentials) == 1
creds = user.credentials[0]
assert creds.auth_provider_type == 'homeassistant'
assert creds.auth_provider_id is None
assert creds.data == {
'username': 'test-user'
}
assert prov_ha.STORAGE_KEY in hass_storage
entry = hass_storage[prov_ha.STORAGE_KEY]['data']['users'][0]
assert entry['username'] == 'test-user'
async def test_create_auth_duplicate_username(hass, hass_ws_client,
hass_access_token, hass_storage):
"""Test we can't create auth with a duplicate username."""
client = await hass_ws_client(hass, hass_access_token)
user = MockUser().add_to_hass(hass)
hass_access_token.refresh_token.user.is_owner = True
hass_storage[prov_ha.STORAGE_KEY] = {
'version': 1,
'data': {
'users': [{
'username': 'test-user'
}]
}
}
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_CREATE,
'user_id': user.id,
'username': 'test-user',
'password': 'test-pass',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'username_exists'
async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
hass_access_token):
"""Test deleting an auth without being connected to a user."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
hass_storage[prov_ha.STORAGE_KEY] = {
'version': 1,
'data': {
'users': [{
'username': 'test-user'
}]
}
}
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_DELETE,
'username': 'test-user',
})
result = await client.receive_json()
assert result['success'], result
assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0
async def test_delete_removes_credential(hass, hass_ws_client,
hass_access_token, hass_storage):
"""Test deleting auth that is connected to a user."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
user = MockUser().add_to_hass(hass)
user.credentials.append(
await hass.auth.auth_providers[0].async_get_or_create_credentials({
'username': 'test-user'}))
hass_storage[prov_ha.STORAGE_KEY] = {
'version': 1,
'data': {
'users': [{
'username': 'test-user'
}]
}
}
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_DELETE,
'username': 'test-user',
})
result = await client.receive_json()
assert result['success'], result
assert len(hass_storage[prov_ha.STORAGE_KEY]['data']['users']) == 0
async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
"""Test delete requires owner."""
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_DELETE,
'username': 'test-user',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'unauthorized'
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
"""Test trying to delete an unknown auth username."""
client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True
await client.send_json({
'id': 5,
'type': auth_ha.WS_TYPE_DELETE,
'username': 'test-user',
})
result = await client.receive_json()
assert not result['success'], result
assert result['error']['code'] == 'auth_not_found'

View File

@ -2,6 +2,7 @@
import pytest import pytest
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api
from tests.common import MockUser, CLIENT_ID from tests.common import MockUser, CLIENT_ID
@ -9,13 +10,27 @@ from tests.common import MockUser, CLIENT_ID
@pytest.fixture @pytest.fixture
def hass_ws_client(aiohttp_client): def hass_ws_client(aiohttp_client):
"""Websocket client fixture connected to websocket server.""" """Websocket client fixture connected to websocket server."""
async def create_client(hass): async def create_client(hass, access_token=None):
"""Create a websocket client.""" """Create a websocket client."""
wapi = hass.components.websocket_api wapi = hass.components.websocket_api
assert await async_setup_component(hass, 'websocket_api') assert await async_setup_component(hass, 'websocket_api')
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
websocket = await client.ws_connect(wapi.URL) websocket = await client.ws_connect(wapi.URL)
auth_resp = await websocket.receive_json()
if auth_resp['type'] == wapi.TYPE_AUTH_OK:
assert access_token is None, \
'Access token given but no auth required'
return websocket
assert access_token is not None, 'Access token required for fixture'
await websocket.send_json({
'type': websocket_api.TYPE_AUTH,
'access_token': access_token.token
})
auth_ok = await websocket.receive_json() auth_ok = await websocket.receive_json()
assert auth_ok['type'] == wapi.TYPE_AUTH_OK assert auth_ok['type'] == wapi.TYPE_AUTH_OK

View File

@ -1,13 +1,12 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access # pylint: disable=protected-access
from ipaddress import ip_network from ipaddress import ip_network
from unittest.mock import patch, Mock from unittest.mock import patch
import pytest import pytest
from aiohttp import BasicAuth, web from aiohttp import BasicAuth, web
from aiohttp.web_exceptions import HTTPUnauthorized from aiohttp.web_exceptions import HTTPUnauthorized
from homeassistant.auth.models import AccessToken, RefreshToken
from homeassistant.components.http.auth import setup_auth from homeassistant.components.http.auth import setup_auth
from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.real_ip import setup_real_ip from homeassistant.components.http.real_ip import setup_real_ip
@ -16,8 +15,6 @@ from homeassistant.setup import async_setup_component
from . import mock_real_ip from . import mock_real_ip
ACCESS_TOKEN = 'tk.1234'
API_PASSWORD = 'test1234' API_PASSWORD = 'test1234'
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases # Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
@ -39,33 +36,21 @@ async def mock_handler(request):
return web.Response(status=200) return web.Response(status=200)
def mock_async_get_access_token(token):
"""Return if token is valid."""
if token == ACCESS_TOKEN:
return Mock(spec=AccessToken,
token=ACCESS_TOKEN,
refresh_token=Mock(spec=RefreshToken))
else:
return None
@pytest.fixture @pytest.fixture
def app(): def app(hass):
"""Fixture to setup a web.Application.""" """Fixture to setup a web.Application."""
app = web.Application() app = web.Application()
mock_auth = Mock(async_get_access_token=mock_async_get_access_token) app['hass'] = hass
app['hass'] = Mock(auth=mock_auth)
app.router.add_get('/', mock_handler) app.router.add_get('/', mock_handler)
setup_real_ip(app, False, []) setup_real_ip(app, False, [])
return app return app
@pytest.fixture @pytest.fixture
def app2(): def app2(hass):
"""Fixture to setup a web.Application without real_ip middleware.""" """Fixture to setup a web.Application without real_ip middleware."""
app = web.Application() app = web.Application()
mock_auth = Mock(async_get_access_token=mock_async_get_access_token) app['hass'] = hass
app['hass'] = Mock(auth=mock_auth)
app.router.add_get('/', mock_handler) app.router.add_get('/', mock_handler)
return app return app
@ -171,33 +156,35 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
async def test_auth_active_access_with_access_token_in_header( async def test_auth_active_access_with_access_token_in_header(
app, aiohttp_client): app, aiohttp_client, hass_access_token):
"""Test access with access token in header.""" """Test access with access token in header."""
token = hass_access_token.token
setup_auth(app, [], True, api_password=None) setup_auth(app, [], True, api_password=None)
client = await aiohttp_client(app) client = await aiohttp_client(app)
req = await client.get( req = await client.get(
'/', headers={'Authorization': 'Bearer {}'.format(ACCESS_TOKEN)}) '/', headers={'Authorization': 'Bearer {}'.format(token)})
assert req.status == 200 assert req.status == 200
req = await client.get( req = await client.get(
'/', headers={'AUTHORIZATION': 'Bearer {}'.format(ACCESS_TOKEN)}) '/', headers={'AUTHORIZATION': 'Bearer {}'.format(token)})
assert req.status == 200 assert req.status == 200
req = await client.get( req = await client.get(
'/', headers={'authorization': 'Bearer {}'.format(ACCESS_TOKEN)}) '/', headers={'authorization': 'Bearer {}'.format(token)})
assert req.status == 200 assert req.status == 200
req = await client.get( req = await client.get(
'/', headers={'Authorization': ACCESS_TOKEN}) '/', headers={'Authorization': token})
assert req.status == 401 assert req.status == 401
req = await client.get( req = await client.get(
'/', headers={'Authorization': 'BEARER {}'.format(ACCESS_TOKEN)}) '/', headers={'Authorization': 'BEARER {}'.format(token)})
assert req.status == 401 assert req.status == 401
hass_access_token.refresh_token.user.is_active = False
req = await client.get( req = await client.get(
'/', headers={'Authorization': 'Bearer wrong-pass'}) '/', headers={'Authorization': 'Bearer {}'.format(token)})
assert req.status == 401 assert req.status == 401

View File

@ -21,7 +21,7 @@ if os.environ.get('UVLOOP') == '1':
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.DEBUG)
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

View File

@ -6,21 +6,26 @@ import pytest
from homeassistant.scripts import auth as script_auth from homeassistant.scripts import auth as script_auth
from homeassistant.auth.providers import homeassistant as hass_auth from homeassistant.auth.providers import homeassistant as hass_auth
from tests.common import register_auth_provider
@pytest.fixture @pytest.fixture
def data(hass): def provider(hass):
"""Create a loaded data class.""" """Home Assistant auth provider."""
data = hass_auth.Data(hass) provider = hass.loop.run_until_complete(register_auth_provider(hass, {
hass.loop.run_until_complete(data.async_load()) 'type': 'homeassistant',
return data }))
hass.loop.run_until_complete(provider.async_initialize())
return provider
async def test_list_user(data, capsys): async def test_list_user(hass, provider, capsys):
"""Test we can list users.""" """Test we can list users."""
data.add_user('test-user', 'test-pass') data = provider.data
data.add_user('second-user', 'second-pass') data.add_auth('test-user', 'test-pass')
data.add_auth('second-user', 'second-pass')
await script_auth.list_users(data, None) await script_auth.list_users(hass, provider, None)
captured = capsys.readouterr() captured = capsys.readouterr()
@ -33,10 +38,11 @@ async def test_list_user(data, capsys):
]) ])
async def test_add_user(data, capsys, hass_storage): async def test_add_user(hass, provider, capsys, hass_storage):
"""Test we can add a user.""" """Test we can add a user."""
data = provider.data
await script_auth.add_user( await script_auth.add_user(
data, Mock(username='paulus', password='test-pass')) hass, provider, Mock(username='paulus', password='test-pass'))
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
@ -47,32 +53,34 @@ async def test_add_user(data, capsys, hass_storage):
data.validate_login('paulus', 'test-pass') data.validate_login('paulus', 'test-pass')
async def test_validate_login(data, capsys): async def test_validate_login(hass, provider, capsys):
"""Test we can validate a user login.""" """Test we can validate a user login."""
data.add_user('test-user', 'test-pass') data = provider.data
data.add_auth('test-user', 'test-pass')
await script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='test-user', password='test-pass')) hass, provider, Mock(username='test-user', password='test-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth valid\n' assert captured.out == 'Auth valid\n'
await script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='test-user', password='invalid-pass')) hass, provider, Mock(username='test-user', password='invalid-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n' assert captured.out == 'Auth invalid\n'
await script_auth.validate_login( await script_auth.validate_login(
data, Mock(username='invalid-user', password='test-pass')) hass, provider, Mock(username='invalid-user', password='test-pass'))
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == 'Auth invalid\n' assert captured.out == 'Auth invalid\n'
async def test_change_password(data, capsys, hass_storage): async def test_change_password(hass, provider, capsys, hass_storage):
"""Test we can change a password.""" """Test we can change a password."""
data.add_user('test-user', 'test-pass') data = provider.data
data.add_auth('test-user', 'test-pass')
await script_auth.change_password( await script_auth.change_password(
data, Mock(username='test-user', new_password='new-pass')) hass, provider, Mock(username='test-user', new_password='new-pass'))
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1 assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr() captured = capsys.readouterr()
@ -82,12 +90,14 @@ async def test_change_password(data, capsys, hass_storage):
data.validate_login('test-user', 'test-pass') data.validate_login('test-user', 'test-pass')
async def test_change_password_invalid_user(data, capsys, hass_storage): async def test_change_password_invalid_user(hass, provider, capsys,
hass_storage):
"""Test changing password of non-existing user.""" """Test changing password of non-existing user."""
data.add_user('test-user', 'test-pass') data = provider.data
data.add_auth('test-user', 'test-pass')
await script_auth.change_password( await script_auth.change_password(
data, Mock(username='invalid-user', new_password='new-pass')) hass, provider, Mock(username='invalid-user', new_password='new-pass'))
assert hass_auth.STORAGE_KEY not in hass_storage assert hass_auth.STORAGE_KEY not in hass_storage
captured = capsys.readouterr() captured = capsys.readouterr()
@ -101,11 +111,11 @@ def test_parsing_args(loop):
"""Test we parse args correctly.""" """Test we parse args correctly."""
called = False called = False
async def mock_func(data, args2): async def mock_func(hass, provider, args2):
"""Mock function to be called.""" """Mock function to be called."""
nonlocal called nonlocal called
called = True called = True
assert data.hass.config.config_dir == '/somewhere/config' assert provider.hass.config.config_dir == '/somewhere/config'
assert args2 is args assert args2 is args
args = Mock(config='/somewhere/config', func=mock_func) args = Mock(config='/somewhere/config', func=mock_func)