mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 19:57:07 +00:00
Reorg auth (#15443)
This commit is contained in:
parent
23f1b49e55
commit
b6ca03ce47
@ -1,613 +0,0 @@
|
|||||||
"""Provide an authentication layer for Home Assistant."""
|
|
||||||
import asyncio
|
|
||||||
import binascii
|
|
||||||
import importlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from collections import OrderedDict
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import attr
|
|
||||||
import voluptuous as vol
|
|
||||||
from voluptuous.humanize import humanize_error
|
|
||||||
|
|
||||||
from homeassistant import data_entry_flow, requirements
|
|
||||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
|
||||||
from homeassistant.core import callback
|
|
||||||
from homeassistant.util import dt as dt_util
|
|
||||||
from homeassistant.util.decorator import Registry
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
|
||||||
STORAGE_KEY = 'auth'
|
|
||||||
|
|
||||||
AUTH_PROVIDERS = Registry()
|
|
||||||
|
|
||||||
AUTH_PROVIDER_SCHEMA = vol.Schema({
|
|
||||||
vol.Required(CONF_TYPE): str,
|
|
||||||
vol.Optional(CONF_NAME): str,
|
|
||||||
# Specify ID if you have two auth providers for same type.
|
|
||||||
vol.Optional(CONF_ID): str,
|
|
||||||
}, extra=vol.ALLOW_EXTRA)
|
|
||||||
|
|
||||||
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
|
||||||
DATA_REQS = 'auth_reqs_processed'
|
|
||||||
|
|
||||||
|
|
||||||
def generate_secret(entropy: int = 32) -> str:
|
|
||||||
"""Generate a secret.
|
|
||||||
|
|
||||||
Backport of secrets.token_hex from Python 3.6
|
|
||||||
|
|
||||||
Event loop friendly.
|
|
||||||
"""
|
|
||||||
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
|
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider:
|
|
||||||
"""Provider of user authentication."""
|
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
|
||||||
|
|
||||||
initialized = False
|
|
||||||
|
|
||||||
def __init__(self, hass, store, config):
|
|
||||||
"""Initialize an auth provider."""
|
|
||||||
self.hass = hass
|
|
||||||
self.store = store
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self): # pylint: disable=invalid-name
|
|
||||||
"""Return id of the auth provider.
|
|
||||||
|
|
||||||
Optional, can be None.
|
|
||||||
"""
|
|
||||||
return self.config.get(CONF_ID)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self):
|
|
||||||
"""Return type of the provider."""
|
|
||||||
return self.config[CONF_TYPE]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
"""Return the name of the auth provider."""
|
|
||||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
|
||||||
|
|
||||||
async def async_credentials(self):
|
|
||||||
"""Return all credentials of this provider."""
|
|
||||||
users = await self.store.async_get_users()
|
|
||||||
return [
|
|
||||||
credentials
|
|
||||||
for user in users
|
|
||||||
for credentials in user.credentials
|
|
||||||
if (credentials.auth_provider_type == self.type and
|
|
||||||
credentials.auth_provider_id == self.id)
|
|
||||||
]
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_create_credentials(self, data):
|
|
||||||
"""Create credentials."""
|
|
||||||
return Credentials(
|
|
||||||
auth_provider_type=self.type,
|
|
||||||
auth_provider_id=self.id,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Implement by extending class
|
|
||||||
|
|
||||||
async def async_initialize(self):
|
|
||||||
"""Initialize the auth provider.
|
|
||||||
|
|
||||||
Optional.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def async_credential_flow(self):
|
|
||||||
"""Return the data flow for logging in with auth provider."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
|
||||||
"""Get credentials based on the flow result."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
|
||||||
"""Return extra user metadata for credentials.
|
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
|
||||||
"""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
|
||||||
class User:
|
|
||||||
"""A user."""
|
|
||||||
|
|
||||||
name = attr.ib(type=str)
|
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
|
||||||
is_owner = attr.ib(type=bool, default=False)
|
|
||||||
is_active = attr.ib(type=bool, default=False)
|
|
||||||
system_generated = attr.ib(type=bool, default=False)
|
|
||||||
|
|
||||||
# List of credentials of a user.
|
|
||||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
|
||||||
|
|
||||||
# Tokens associated with a user.
|
|
||||||
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
|
||||||
class RefreshToken:
|
|
||||||
"""RefreshToken for a user to grant new access tokens."""
|
|
||||||
|
|
||||||
user = attr.ib(type=User)
|
|
||||||
client_id = attr.ib(type=str)
|
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
|
||||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
|
||||||
access_token_expiration = attr.ib(type=timedelta,
|
|
||||||
default=ACCESS_TOKEN_EXPIRATION)
|
|
||||||
token = attr.ib(type=str,
|
|
||||||
default=attr.Factory(lambda: generate_secret(64)))
|
|
||||||
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
|
||||||
class AccessToken:
|
|
||||||
"""Access token to access the API.
|
|
||||||
|
|
||||||
These will only ever be stored in memory and not be persisted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
refresh_token = attr.ib(type=RefreshToken)
|
|
||||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
|
||||||
token = attr.ib(type=str,
|
|
||||||
default=attr.Factory(generate_secret))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def expired(self):
|
|
||||||
"""Return if this token has expired."""
|
|
||||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
|
||||||
return dt_util.utcnow() > expires
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
|
||||||
class Credentials:
|
|
||||||
"""Credentials for a user on an auth provider."""
|
|
||||||
|
|
||||||
auth_provider_type = attr.ib(type=str)
|
|
||||||
auth_provider_id = attr.ib(type=str)
|
|
||||||
|
|
||||||
# Allow the auth provider to store data to represent their auth.
|
|
||||||
data = attr.ib(type=dict)
|
|
||||||
|
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
|
||||||
is_new = attr.ib(type=bool, default=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_auth_provider_module(hass, provider):
|
|
||||||
"""Load an auth provider."""
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(
|
|
||||||
'homeassistant.auth_providers.{}'.format(provider))
|
|
||||||
except ImportError:
|
|
||||||
_LOGGER.warning('Unable to find auth provider %s', provider)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
|
|
||||||
return module
|
|
||||||
|
|
||||||
processed = hass.data.get(DATA_REQS)
|
|
||||||
|
|
||||||
if processed is None:
|
|
||||||
processed = hass.data[DATA_REQS] = set()
|
|
||||||
elif provider in processed:
|
|
||||||
return module
|
|
||||||
|
|
||||||
req_success = await requirements.async_process_requirements(
|
|
||||||
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
|
||||||
|
|
||||||
if not req_success:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
async def auth_manager_from_config(hass, provider_configs):
|
|
||||||
"""Initialize an auth manager from config."""
|
|
||||||
store = AuthStore(hass)
|
|
||||||
if provider_configs:
|
|
||||||
providers = await asyncio.gather(
|
|
||||||
*[_auth_provider_from_config(hass, store, config)
|
|
||||||
for config in provider_configs])
|
|
||||||
else:
|
|
||||||
providers = []
|
|
||||||
# So returned auth providers are in same order as config
|
|
||||||
provider_hash = OrderedDict()
|
|
||||||
for provider in providers:
|
|
||||||
if provider is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
key = (provider.type, provider.id)
|
|
||||||
|
|
||||||
if key in provider_hash:
|
|
||||||
_LOGGER.error(
|
|
||||||
'Found duplicate provider: %s. Please add unique IDs if you '
|
|
||||||
'want to have the same provider twice.', key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
provider_hash[key] = provider
|
|
||||||
manager = AuthManager(hass, store, provider_hash)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
|
|
||||||
async def _auth_provider_from_config(hass, store, config):
|
|
||||||
"""Initialize an auth provider from a config."""
|
|
||||||
provider_name = config[CONF_TYPE]
|
|
||||||
module = await load_auth_provider_module(hass, provider_name)
|
|
||||||
|
|
||||||
if module is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = module.CONFIG_SCHEMA(config)
|
|
||||||
except vol.Invalid as err:
|
|
||||||
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
|
||||||
provider_name, humanize_error(config, err))
|
|
||||||
return None
|
|
||||||
|
|
||||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthManager:
|
|
||||||
"""Manage the authentication for Home Assistant."""
|
|
||||||
|
|
||||||
def __init__(self, hass, store, providers):
|
|
||||||
"""Initialize the auth manager."""
|
|
||||||
self._store = store
|
|
||||||
self._providers = providers
|
|
||||||
self.login_flow = data_entry_flow.FlowManager(
|
|
||||||
hass, self._async_create_login_flow,
|
|
||||||
self._async_finish_login_flow)
|
|
||||||
self._access_tokens = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def active(self):
|
|
||||||
"""Return if any auth providers are registered."""
|
|
||||||
return bool(self._providers)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def support_legacy(self):
|
|
||||||
"""
|
|
||||||
Return if legacy_api_password auth providers are registered.
|
|
||||||
|
|
||||||
Should be removed when we removed legacy_api_password auth providers.
|
|
||||||
"""
|
|
||||||
for provider_type, _ in self._providers:
|
|
||||||
if provider_type == 'legacy_api_password':
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def async_auth_providers(self):
|
|
||||||
"""Return a list of available auth providers."""
|
|
||||||
return self._providers.values()
|
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
|
||||||
"""Retrieve a user."""
|
|
||||||
return await self._store.async_get_user(user_id)
|
|
||||||
|
|
||||||
async def async_create_system_user(self, name):
|
|
||||||
"""Create a system user."""
|
|
||||||
return await self._store.async_create_user(
|
|
||||||
name=name,
|
|
||||||
system_generated=True,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_get_or_create_user(self, credentials):
|
|
||||||
"""Get or create a user."""
|
|
||||||
if not credentials.is_new:
|
|
||||||
for user in await self._store.async_get_users():
|
|
||||||
for creds in user.credentials:
|
|
||||||
if creds.id == credentials.id:
|
|
||||||
return user
|
|
||||||
|
|
||||||
raise ValueError('Unable to find the user.')
|
|
||||||
|
|
||||||
auth_provider = self._async_get_auth_provider(credentials)
|
|
||||||
info = await auth_provider.async_user_meta_for_credentials(
|
|
||||||
credentials)
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
'credentials': credentials,
|
|
||||||
'name': info.get('name')
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make owner and activate user if it's the first user.
|
|
||||||
if await self._store.async_get_users():
|
|
||||||
kwargs['is_owner'] = False
|
|
||||||
kwargs['is_active'] = False
|
|
||||||
else:
|
|
||||||
kwargs['is_owner'] = True
|
|
||||||
kwargs['is_active'] = True
|
|
||||||
|
|
||||||
return await self._store.async_create_user(**kwargs)
|
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
|
||||||
"""Link credentials to an existing user."""
|
|
||||||
await self._store.async_link_user(user, credentials)
|
|
||||||
|
|
||||||
async def async_remove_user(self, user):
|
|
||||||
"""Remove a user."""
|
|
||||||
await self._store.async_remove_user(user)
|
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id=None):
|
|
||||||
"""Create a new refresh token for a user."""
|
|
||||||
if not user.is_active:
|
|
||||||
raise ValueError('User is not active')
|
|
||||||
|
|
||||||
if user.system_generated and client_id is not None:
|
|
||||||
raise ValueError(
|
|
||||||
'System generated users cannot have refresh tokens connected '
|
|
||||||
'to a client.')
|
|
||||||
|
|
||||||
if not user.system_generated and client_id is None:
|
|
||||||
raise ValueError('Client is required to generate a refresh token.')
|
|
||||||
|
|
||||||
return await self._store.async_create_refresh_token(user, client_id)
|
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
|
||||||
"""Get refresh token by token."""
|
|
||||||
return await self._store.async_get_refresh_token(token)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_create_access_token(self, refresh_token):
|
|
||||||
"""Create a new access token."""
|
|
||||||
access_token = AccessToken(refresh_token=refresh_token)
|
|
||||||
self._access_tokens[access_token.token] = access_token
|
|
||||||
return access_token
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_get_access_token(self, token):
|
|
||||||
"""Get an access token."""
|
|
||||||
tkn = self._access_tokens.get(token)
|
|
||||||
|
|
||||||
if tkn is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if tkn.expired:
|
|
||||||
self._access_tokens.pop(token)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return tkn
|
|
||||||
|
|
||||||
async def _async_create_login_flow(self, handler, *, source, data):
|
|
||||||
"""Create a login flow."""
|
|
||||||
auth_provider = self._providers[handler]
|
|
||||||
|
|
||||||
if not auth_provider.initialized:
|
|
||||||
auth_provider.initialized = True
|
|
||||||
await auth_provider.async_initialize()
|
|
||||||
|
|
||||||
return await auth_provider.async_credential_flow()
|
|
||||||
|
|
||||||
async def _async_finish_login_flow(self, result):
|
|
||||||
"""Result of a credential login flow."""
|
|
||||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
|
||||||
return None
|
|
||||||
|
|
||||||
auth_provider = self._providers[result['handler']]
|
|
||||||
return await auth_provider.async_get_or_create_credentials(
|
|
||||||
result['data'])
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_get_auth_provider(self, credentials):
|
|
||||||
"""Helper to get auth provider from a set of credentials."""
|
|
||||||
auth_provider_key = (credentials.auth_provider_type,
|
|
||||||
credentials.auth_provider_id)
|
|
||||||
return self._providers[auth_provider_key]
|
|
||||||
|
|
||||||
|
|
||||||
class AuthStore:
|
|
||||||
"""Stores authentication info.
|
|
||||||
|
|
||||||
Any mutation to an object should happen inside the auth store.
|
|
||||||
|
|
||||||
The auth store is lazy. It won't load the data from disk until a method is
|
|
||||||
called that needs it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Initialize the auth store."""
|
|
||||||
self.hass = hass
|
|
||||||
self._users = None
|
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
|
||||||
|
|
||||||
async def async_get_users(self):
|
|
||||||
"""Retrieve all users."""
|
|
||||||
if self._users is None:
|
|
||||||
await self.async_load()
|
|
||||||
|
|
||||||
return list(self._users.values())
|
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
|
||||||
"""Retrieve a user by id."""
|
|
||||||
if self._users is None:
|
|
||||||
await self.async_load()
|
|
||||||
|
|
||||||
return self._users.get(user_id)
|
|
||||||
|
|
||||||
async def async_create_user(self, name, is_owner=None, is_active=None,
|
|
||||||
system_generated=None, credentials=None):
|
|
||||||
"""Create a new user."""
|
|
||||||
if self._users is None:
|
|
||||||
await self.async_load()
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
'name': name
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_owner is not None:
|
|
||||||
kwargs['is_owner'] = is_owner
|
|
||||||
|
|
||||||
if is_active is not None:
|
|
||||||
kwargs['is_active'] = is_active
|
|
||||||
|
|
||||||
if system_generated is not None:
|
|
||||||
kwargs['system_generated'] = system_generated
|
|
||||||
|
|
||||||
new_user = User(**kwargs)
|
|
||||||
|
|
||||||
self._users[new_user.id] = new_user
|
|
||||||
|
|
||||||
if credentials is None:
|
|
||||||
await self.async_save()
|
|
||||||
return new_user
|
|
||||||
|
|
||||||
# Saving is done inside the link.
|
|
||||||
await self.async_link_user(new_user, credentials)
|
|
||||||
return new_user
|
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
|
||||||
"""Add credentials to an existing user."""
|
|
||||||
user.credentials.append(credentials)
|
|
||||||
await self.async_save()
|
|
||||||
credentials.is_new = False
|
|
||||||
|
|
||||||
async def async_remove_user(self, user):
|
|
||||||
"""Remove a user."""
|
|
||||||
self._users.pop(user.id)
|
|
||||||
await self.async_save()
|
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id=None):
|
|
||||||
"""Create a new token for a user."""
|
|
||||||
refresh_token = RefreshToken(user=user, client_id=client_id)
|
|
||||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
|
||||||
await self.async_save()
|
|
||||||
return refresh_token
|
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
|
||||||
"""Get refresh token by token."""
|
|
||||||
if self._users is None:
|
|
||||||
await self.async_load()
|
|
||||||
|
|
||||||
for user in self._users.values():
|
|
||||||
refresh_token = user.refresh_tokens.get(token)
|
|
||||||
if refresh_token is not None:
|
|
||||||
return refresh_token
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def async_load(self):
|
|
||||||
"""Load the users."""
|
|
||||||
data = await self._store.async_load()
|
|
||||||
|
|
||||||
# Make sure that we're not overriding data if 2 loads happened at the
|
|
||||||
# same time
|
|
||||||
if self._users is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
self._users = {}
|
|
||||||
return
|
|
||||||
|
|
||||||
users = {
|
|
||||||
user_dict['id']: User(**user_dict) for user_dict in data['users']
|
|
||||||
}
|
|
||||||
|
|
||||||
for cred_dict in data['credentials']:
|
|
||||||
users[cred_dict['user_id']].credentials.append(Credentials(
|
|
||||||
id=cred_dict['id'],
|
|
||||||
is_new=False,
|
|
||||||
auth_provider_type=cred_dict['auth_provider_type'],
|
|
||||||
auth_provider_id=cred_dict['auth_provider_id'],
|
|
||||||
data=cred_dict['data'],
|
|
||||||
))
|
|
||||||
|
|
||||||
refresh_tokens = {}
|
|
||||||
|
|
||||||
for rt_dict in data['refresh_tokens']:
|
|
||||||
token = RefreshToken(
|
|
||||||
id=rt_dict['id'],
|
|
||||||
user=users[rt_dict['user_id']],
|
|
||||||
client_id=rt_dict['client_id'],
|
|
||||||
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
|
||||||
access_token_expiration=timedelta(
|
|
||||||
seconds=rt_dict['access_token_expiration']),
|
|
||||||
token=rt_dict['token'],
|
|
||||||
)
|
|
||||||
refresh_tokens[token.id] = token
|
|
||||||
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
|
||||||
|
|
||||||
for ac_dict in data['access_tokens']:
|
|
||||||
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
|
||||||
token = AccessToken(
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
|
||||||
token=ac_dict['token'],
|
|
||||||
)
|
|
||||||
refresh_token.access_tokens.append(token)
|
|
||||||
|
|
||||||
self._users = users
|
|
||||||
|
|
||||||
async def async_save(self):
|
|
||||||
"""Save users."""
|
|
||||||
users = [
|
|
||||||
{
|
|
||||||
'id': user.id,
|
|
||||||
'is_owner': user.is_owner,
|
|
||||||
'is_active': user.is_active,
|
|
||||||
'name': user.name,
|
|
||||||
'system_generated': user.system_generated,
|
|
||||||
}
|
|
||||||
for user in self._users.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
credentials = [
|
|
||||||
{
|
|
||||||
'id': credential.id,
|
|
||||||
'user_id': user.id,
|
|
||||||
'auth_provider_type': credential.auth_provider_type,
|
|
||||||
'auth_provider_id': credential.auth_provider_id,
|
|
||||||
'data': credential.data,
|
|
||||||
}
|
|
||||||
for user in self._users.values()
|
|
||||||
for credential in user.credentials
|
|
||||||
]
|
|
||||||
|
|
||||||
refresh_tokens = [
|
|
||||||
{
|
|
||||||
'id': refresh_token.id,
|
|
||||||
'user_id': user.id,
|
|
||||||
'client_id': refresh_token.client_id,
|
|
||||||
'created_at': refresh_token.created_at.isoformat(),
|
|
||||||
'access_token_expiration':
|
|
||||||
refresh_token.access_token_expiration.total_seconds(),
|
|
||||||
'token': refresh_token.token,
|
|
||||||
}
|
|
||||||
for user in self._users.values()
|
|
||||||
for refresh_token in user.refresh_tokens.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
access_tokens = [
|
|
||||||
{
|
|
||||||
'id': user.id,
|
|
||||||
'refresh_token_id': refresh_token.id,
|
|
||||||
'created_at': access_token.created_at.isoformat(),
|
|
||||||
'token': access_token.token,
|
|
||||||
}
|
|
||||||
for user in self._users.values()
|
|
||||||
for refresh_token in user.refresh_tokens.values()
|
|
||||||
for access_token in refresh_token.access_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
'users': users,
|
|
||||||
'credentials': credentials,
|
|
||||||
'access_tokens': access_tokens,
|
|
||||||
'refresh_tokens': refresh_tokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._store.async_save(data, delay=1)
|
|
191
homeassistant/auth/__init__.py
Normal file
191
homeassistant/auth/__init__.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
"""Provide an authentication layer for Home Assistant."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
from . import models
|
||||||
|
from . import auth_store
|
||||||
|
from .providers import auth_provider_from_config
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_manager_from_config(hass, provider_configs):
|
||||||
|
"""Initialize an auth manager from config."""
|
||||||
|
store = auth_store.AuthStore(hass)
|
||||||
|
if provider_configs:
|
||||||
|
providers = await asyncio.gather(
|
||||||
|
*[auth_provider_from_config(hass, store, config)
|
||||||
|
for config in provider_configs])
|
||||||
|
else:
|
||||||
|
providers = []
|
||||||
|
# So returned auth providers are in same order as config
|
||||||
|
provider_hash = OrderedDict()
|
||||||
|
for provider in providers:
|
||||||
|
if provider is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = (provider.type, provider.id)
|
||||||
|
|
||||||
|
if key in provider_hash:
|
||||||
|
_LOGGER.error(
|
||||||
|
'Found duplicate provider: %s. Please add unique IDs if you '
|
||||||
|
'want to have the same provider twice.', key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
provider_hash[key] = provider
|
||||||
|
manager = AuthManager(hass, store, provider_hash)
|
||||||
|
return manager
|
||||||
|
|
||||||
|
|
||||||
|
class AuthManager:
|
||||||
|
"""Manage the authentication for Home Assistant."""
|
||||||
|
|
||||||
|
def __init__(self, hass, store, providers):
|
||||||
|
"""Initialize the auth manager."""
|
||||||
|
self._store = store
|
||||||
|
self._providers = providers
|
||||||
|
self.login_flow = data_entry_flow.FlowManager(
|
||||||
|
hass, self._async_create_login_flow,
|
||||||
|
self._async_finish_login_flow)
|
||||||
|
self._access_tokens = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active(self):
|
||||||
|
"""Return if any auth providers are registered."""
|
||||||
|
return bool(self._providers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def support_legacy(self):
|
||||||
|
"""
|
||||||
|
Return if legacy_api_password auth providers are registered.
|
||||||
|
|
||||||
|
Should be removed when we removed legacy_api_password auth providers.
|
||||||
|
"""
|
||||||
|
for provider_type, _ in self._providers:
|
||||||
|
if provider_type == 'legacy_api_password':
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def async_auth_providers(self):
|
||||||
|
"""Return a list of available auth providers."""
|
||||||
|
return self._providers.values()
|
||||||
|
|
||||||
|
async def async_get_user(self, user_id):
|
||||||
|
"""Retrieve a user."""
|
||||||
|
return await self._store.async_get_user(user_id)
|
||||||
|
|
||||||
|
async def async_create_system_user(self, name):
|
||||||
|
"""Create a system user."""
|
||||||
|
return await self._store.async_create_user(
|
||||||
|
name=name,
|
||||||
|
system_generated=True,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_get_or_create_user(self, credentials):
|
||||||
|
"""Get or create a user."""
|
||||||
|
if not credentials.is_new:
|
||||||
|
for user in await self._store.async_get_users():
|
||||||
|
for creds in user.credentials:
|
||||||
|
if creds.id == credentials.id:
|
||||||
|
return user
|
||||||
|
|
||||||
|
raise ValueError('Unable to find the user.')
|
||||||
|
|
||||||
|
auth_provider = self._async_get_auth_provider(credentials)
|
||||||
|
info = await auth_provider.async_user_meta_for_credentials(
|
||||||
|
credentials)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
'credentials': credentials,
|
||||||
|
'name': info.get('name')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make owner and activate user if it's the first user.
|
||||||
|
if await self._store.async_get_users():
|
||||||
|
kwargs['is_owner'] = False
|
||||||
|
kwargs['is_active'] = False
|
||||||
|
else:
|
||||||
|
kwargs['is_owner'] = True
|
||||||
|
kwargs['is_active'] = True
|
||||||
|
|
||||||
|
return await self._store.async_create_user(**kwargs)
|
||||||
|
|
||||||
|
async def async_link_user(self, user, credentials):
|
||||||
|
"""Link credentials to an existing user."""
|
||||||
|
await self._store.async_link_user(user, credentials)
|
||||||
|
|
||||||
|
async def async_remove_user(self, user):
|
||||||
|
"""Remove a user."""
|
||||||
|
await self._store.async_remove_user(user)
|
||||||
|
|
||||||
|
async def async_create_refresh_token(self, user, client_id=None):
|
||||||
|
"""Create a new refresh token for a user."""
|
||||||
|
if not user.is_active:
|
||||||
|
raise ValueError('User is not active')
|
||||||
|
|
||||||
|
if user.system_generated and client_id is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'System generated users cannot have refresh tokens connected '
|
||||||
|
'to a client.')
|
||||||
|
|
||||||
|
if not user.system_generated and client_id is None:
|
||||||
|
raise ValueError('Client is required to generate a refresh token.')
|
||||||
|
|
||||||
|
return await self._store.async_create_refresh_token(user, client_id)
|
||||||
|
|
||||||
|
async def async_get_refresh_token(self, token):
|
||||||
|
"""Get refresh token by token."""
|
||||||
|
return await self._store.async_get_refresh_token(token)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_create_access_token(self, refresh_token):
|
||||||
|
"""Create a new access token."""
|
||||||
|
access_token = models.AccessToken(refresh_token=refresh_token)
|
||||||
|
self._access_tokens[access_token.token] = access_token
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_access_token(self, token):
|
||||||
|
"""Get an access token."""
|
||||||
|
tkn = self._access_tokens.get(token)
|
||||||
|
|
||||||
|
if tkn is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if tkn.expired:
|
||||||
|
self._access_tokens.pop(token)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return tkn
|
||||||
|
|
||||||
|
async def _async_create_login_flow(self, handler, *, source, data):
|
||||||
|
"""Create a login flow."""
|
||||||
|
auth_provider = self._providers[handler]
|
||||||
|
|
||||||
|
if not auth_provider.initialized:
|
||||||
|
auth_provider.initialized = True
|
||||||
|
await auth_provider.async_initialize()
|
||||||
|
|
||||||
|
return await auth_provider.async_credential_flow()
|
||||||
|
|
||||||
|
async def _async_finish_login_flow(self, result):
|
||||||
|
"""Result of a credential login flow."""
|
||||||
|
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
|
return None
|
||||||
|
|
||||||
|
auth_provider = self._providers[result['handler']]
|
||||||
|
return await auth_provider.async_get_or_create_credentials(
|
||||||
|
result['data'])
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_get_auth_provider(self, credentials):
|
||||||
|
"""Helper to get auth provider from a set of credentials."""
|
||||||
|
auth_provider_key = (credentials.auth_provider_type,
|
||||||
|
credentials.auth_provider_id)
|
||||||
|
return self._providers[auth_provider_key]
|
213
homeassistant/auth/auth_store.py
Normal file
213
homeassistant/auth/auth_store.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
"""Storage for auth models."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from . import models
|
||||||
|
|
||||||
|
STORAGE_VERSION = 1
|
||||||
|
STORAGE_KEY = 'auth'
|
||||||
|
|
||||||
|
|
||||||
|
class AuthStore:
|
||||||
|
"""Stores authentication info.
|
||||||
|
|
||||||
|
Any mutation to an object should happen inside the auth store.
|
||||||
|
|
||||||
|
The auth store is lazy. It won't load the data from disk until a method is
|
||||||
|
called that needs it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hass):
|
||||||
|
"""Initialize the auth store."""
|
||||||
|
self.hass = hass
|
||||||
|
self._users = None
|
||||||
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
|
||||||
|
async def async_get_users(self):
|
||||||
|
"""Retrieve all users."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
return list(self._users.values())
|
||||||
|
|
||||||
|
async def async_get_user(self, user_id):
|
||||||
|
"""Retrieve a user by id."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
return self._users.get(user_id)
|
||||||
|
|
||||||
|
async def async_create_user(self, name, is_owner=None, is_active=None,
|
||||||
|
system_generated=None, credentials=None):
|
||||||
|
"""Create a new user."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
'name': name
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_owner is not None:
|
||||||
|
kwargs['is_owner'] = is_owner
|
||||||
|
|
||||||
|
if is_active is not None:
|
||||||
|
kwargs['is_active'] = is_active
|
||||||
|
|
||||||
|
if system_generated is not None:
|
||||||
|
kwargs['system_generated'] = system_generated
|
||||||
|
|
||||||
|
new_user = models.User(**kwargs)
|
||||||
|
|
||||||
|
self._users[new_user.id] = new_user
|
||||||
|
|
||||||
|
if credentials is None:
|
||||||
|
await self.async_save()
|
||||||
|
return new_user
|
||||||
|
|
||||||
|
# Saving is done inside the link.
|
||||||
|
await self.async_link_user(new_user, credentials)
|
||||||
|
return new_user
|
||||||
|
|
||||||
|
async def async_link_user(self, user, credentials):
|
||||||
|
"""Add credentials to an existing user."""
|
||||||
|
user.credentials.append(credentials)
|
||||||
|
await self.async_save()
|
||||||
|
credentials.is_new = False
|
||||||
|
|
||||||
|
async def async_remove_user(self, user):
|
||||||
|
"""Remove a user."""
|
||||||
|
self._users.pop(user.id)
|
||||||
|
await self.async_save()
|
||||||
|
|
||||||
|
async def async_create_refresh_token(self, user, client_id=None):
|
||||||
|
"""Create a new token for a user."""
|
||||||
|
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||||
|
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||||
|
await self.async_save()
|
||||||
|
return refresh_token
|
||||||
|
|
||||||
|
async def async_get_refresh_token(self, token):
|
||||||
|
"""Get refresh token by token."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
for user in self._users.values():
|
||||||
|
refresh_token = user.refresh_tokens.get(token)
|
||||||
|
if refresh_token is not None:
|
||||||
|
return refresh_token
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_load(self):
|
||||||
|
"""Load the users."""
|
||||||
|
data = await self._store.async_load()
|
||||||
|
|
||||||
|
# Make sure that we're not overriding data if 2 loads happened at the
|
||||||
|
# same time
|
||||||
|
if self._users is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
self._users = {}
|
||||||
|
return
|
||||||
|
|
||||||
|
users = {
|
||||||
|
user_dict['id']: models.User(**user_dict)
|
||||||
|
for user_dict in data['users']
|
||||||
|
}
|
||||||
|
|
||||||
|
for cred_dict in data['credentials']:
|
||||||
|
users[cred_dict['user_id']].credentials.append(models.Credentials(
|
||||||
|
id=cred_dict['id'],
|
||||||
|
is_new=False,
|
||||||
|
auth_provider_type=cred_dict['auth_provider_type'],
|
||||||
|
auth_provider_id=cred_dict['auth_provider_id'],
|
||||||
|
data=cred_dict['data'],
|
||||||
|
))
|
||||||
|
|
||||||
|
refresh_tokens = {}
|
||||||
|
|
||||||
|
for rt_dict in data['refresh_tokens']:
|
||||||
|
token = models.RefreshToken(
|
||||||
|
id=rt_dict['id'],
|
||||||
|
user=users[rt_dict['user_id']],
|
||||||
|
client_id=rt_dict['client_id'],
|
||||||
|
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
||||||
|
access_token_expiration=timedelta(
|
||||||
|
seconds=rt_dict['access_token_expiration']),
|
||||||
|
token=rt_dict['token'],
|
||||||
|
)
|
||||||
|
refresh_tokens[token.id] = token
|
||||||
|
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
||||||
|
|
||||||
|
for ac_dict in data['access_tokens']:
|
||||||
|
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
||||||
|
token = models.AccessToken(
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
||||||
|
token=ac_dict['token'],
|
||||||
|
)
|
||||||
|
refresh_token.access_tokens.append(token)
|
||||||
|
|
||||||
|
self._users = users
|
||||||
|
|
||||||
|
async def async_save(self):
|
||||||
|
"""Save users."""
|
||||||
|
users = [
|
||||||
|
{
|
||||||
|
'id': user.id,
|
||||||
|
'is_owner': user.is_owner,
|
||||||
|
'is_active': user.is_active,
|
||||||
|
'name': user.name,
|
||||||
|
'system_generated': user.system_generated,
|
||||||
|
}
|
||||||
|
for user in self._users.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
credentials = [
|
||||||
|
{
|
||||||
|
'id': credential.id,
|
||||||
|
'user_id': user.id,
|
||||||
|
'auth_provider_type': credential.auth_provider_type,
|
||||||
|
'auth_provider_id': credential.auth_provider_id,
|
||||||
|
'data': credential.data,
|
||||||
|
}
|
||||||
|
for user in self._users.values()
|
||||||
|
for credential in user.credentials
|
||||||
|
]
|
||||||
|
|
||||||
|
refresh_tokens = [
|
||||||
|
{
|
||||||
|
'id': refresh_token.id,
|
||||||
|
'user_id': user.id,
|
||||||
|
'client_id': refresh_token.client_id,
|
||||||
|
'created_at': refresh_token.created_at.isoformat(),
|
||||||
|
'access_token_expiration':
|
||||||
|
refresh_token.access_token_expiration.total_seconds(),
|
||||||
|
'token': refresh_token.token,
|
||||||
|
}
|
||||||
|
for user in self._users.values()
|
||||||
|
for refresh_token in user.refresh_tokens.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
access_tokens = [
|
||||||
|
{
|
||||||
|
'id': user.id,
|
||||||
|
'refresh_token_id': refresh_token.id,
|
||||||
|
'created_at': access_token.created_at.isoformat(),
|
||||||
|
'token': access_token.token,
|
||||||
|
}
|
||||||
|
for user in self._users.values()
|
||||||
|
for refresh_token in user.refresh_tokens.values()
|
||||||
|
for access_token in refresh_token.access_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'users': users,
|
||||||
|
'credentials': credentials,
|
||||||
|
'access_tokens': access_tokens,
|
||||||
|
'refresh_tokens': refresh_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self._store.async_save(data, delay=1)
|
4
homeassistant/auth/const.py
Normal file
4
homeassistant/auth/const.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
"""Constants for the auth module."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
75
homeassistant/auth/models.py
Normal file
75
homeassistant/auth/models.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""Auth models."""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from .const import ACCESS_TOKEN_EXPIRATION
|
||||||
|
from .util import generate_secret
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class User:
|
||||||
|
"""A user."""
|
||||||
|
|
||||||
|
name = attr.ib(type=str)
|
||||||
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
|
is_owner = attr.ib(type=bool, default=False)
|
||||||
|
is_active = attr.ib(type=bool, default=False)
|
||||||
|
system_generated = attr.ib(type=bool, default=False)
|
||||||
|
|
||||||
|
# List of credentials of a user.
|
||||||
|
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||||
|
|
||||||
|
# Tokens associated with a user.
|
||||||
|
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class RefreshToken:
|
||||||
|
"""RefreshToken for a user to grant new access tokens."""
|
||||||
|
|
||||||
|
user = attr.ib(type=User)
|
||||||
|
client_id = attr.ib(type=str)
|
||||||
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
|
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||||
|
access_token_expiration = attr.ib(type=timedelta,
|
||||||
|
default=ACCESS_TOKEN_EXPIRATION)
|
||||||
|
token = attr.ib(type=str,
|
||||||
|
default=attr.Factory(lambda: generate_secret(64)))
|
||||||
|
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class AccessToken:
|
||||||
|
"""Access token to access the API.
|
||||||
|
|
||||||
|
These will only ever be stored in memory and not be persisted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
refresh_token = attr.ib(type=RefreshToken)
|
||||||
|
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||||
|
token = attr.ib(type=str,
|
||||||
|
default=attr.Factory(generate_secret))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
"""Return if this token has expired."""
|
||||||
|
expires = self.created_at + self.refresh_token.access_token_expiration
|
||||||
|
return dt_util.utcnow() > expires
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class Credentials:
|
||||||
|
"""Credentials for a user on an auth provider."""
|
||||||
|
|
||||||
|
auth_provider_type = attr.ib(type=str)
|
||||||
|
auth_provider_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
# Allow the auth provider to store data to represent their auth.
|
||||||
|
data = attr.ib(type=dict)
|
||||||
|
|
||||||
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
|
is_new = attr.ib(type=bool, default=True)
|
147
homeassistant/auth/providers/__init__.py
Normal file
147
homeassistant/auth/providers/__init__.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
"""Auth providers for Home Assistant."""
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
|
from homeassistant import requirements
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||||
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
|
from homeassistant.auth.models import Credentials
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
DATA_REQS = 'auth_prov_reqs_processed'
|
||||||
|
|
||||||
|
AUTH_PROVIDERS = Registry()
|
||||||
|
|
||||||
|
AUTH_PROVIDER_SCHEMA = vol.Schema({
|
||||||
|
vol.Required(CONF_TYPE): str,
|
||||||
|
vol.Optional(CONF_NAME): str,
|
||||||
|
# Specify ID if you have two auth providers for same type.
|
||||||
|
vol.Optional(CONF_ID): str,
|
||||||
|
}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_provider_from_config(hass, store, config):
|
||||||
|
"""Initialize an auth provider from a config."""
|
||||||
|
provider_name = config[CONF_TYPE]
|
||||||
|
module = await load_auth_provider_module(hass, provider_name)
|
||||||
|
|
||||||
|
if module is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = module.CONFIG_SCHEMA(config)
|
||||||
|
except vol.Invalid as err:
|
||||||
|
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
||||||
|
provider_name, humanize_error(config, err))
|
||||||
|
return None
|
||||||
|
|
||||||
|
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_auth_provider_module(hass, provider):
|
||||||
|
"""Load an auth provider."""
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(
|
||||||
|
'homeassistant.auth.providers.{}'.format(provider))
|
||||||
|
except ImportError:
|
||||||
|
_LOGGER.warning('Unable to find auth provider %s', provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
|
||||||
|
return module
|
||||||
|
|
||||||
|
processed = hass.data.get(DATA_REQS)
|
||||||
|
|
||||||
|
if processed is None:
|
||||||
|
processed = hass.data[DATA_REQS] = set()
|
||||||
|
elif provider in processed:
|
||||||
|
return module
|
||||||
|
|
||||||
|
req_success = await requirements.async_process_requirements(
|
||||||
|
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
||||||
|
|
||||||
|
if not req_success:
|
||||||
|
return None
|
||||||
|
|
||||||
|
processed.add(provider)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class AuthProvider:
|
||||||
|
"""Provider of user authentication."""
|
||||||
|
|
||||||
|
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||||
|
|
||||||
|
initialized = False
|
||||||
|
|
||||||
|
def __init__(self, hass, store, config):
|
||||||
|
"""Initialize an auth provider."""
|
||||||
|
self.hass = hass
|
||||||
|
self.store = store
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self): # pylint: disable=invalid-name
|
||||||
|
"""Return id of the auth provider.
|
||||||
|
|
||||||
|
Optional, can be None.
|
||||||
|
"""
|
||||||
|
return self.config.get(CONF_ID)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self):
|
||||||
|
"""Return type of the provider."""
|
||||||
|
return self.config[CONF_TYPE]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
"""Return the name of the auth provider."""
|
||||||
|
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||||
|
|
||||||
|
async def async_credentials(self):
|
||||||
|
"""Return all credentials of this provider."""
|
||||||
|
users = await self.store.async_get_users()
|
||||||
|
return [
|
||||||
|
credentials
|
||||||
|
for user in users
|
||||||
|
for credentials in user.credentials
|
||||||
|
if (credentials.auth_provider_type == self.type and
|
||||||
|
credentials.auth_provider_id == self.id)
|
||||||
|
]
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_create_credentials(self, data):
|
||||||
|
"""Create credentials."""
|
||||||
|
return Credentials(
|
||||||
|
auth_provider_type=self.type,
|
||||||
|
auth_provider_id=self.id,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Implement by extending class
|
||||||
|
|
||||||
|
async def async_initialize(self):
|
||||||
|
"""Initialize the auth provider.
|
||||||
|
|
||||||
|
Optional.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def async_credential_flow(self):
|
||||||
|
"""Return the data flow for logging in with auth provider."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_get_or_create_credentials(self, flow_result):
|
||||||
|
"""Get credentials based on the flow result."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_user_meta_for_credentials(self, credentials):
|
||||||
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
|
Will be used to populate info when creating a new user.
|
||||||
|
"""
|
||||||
|
return {}
|
@ -6,14 +6,17 @@ import hmac
|
|||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
from homeassistant.auth.util import generate_secret
|
||||||
|
|
||||||
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||||
|
|
||||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +46,7 @@ class Data:
|
|||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {
|
data = {
|
||||||
'salt': auth.generate_secret(),
|
'salt': generate_secret(),
|
||||||
'users': []
|
'users': []
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,8 +115,8 @@ class Data:
|
|||||||
await self._store.async_save(self._data)
|
await self._store.async_save(self._data)
|
||||||
|
|
||||||
|
|
||||||
@auth.AUTH_PROVIDERS.register('homeassistant')
|
@AUTH_PROVIDERS.register('homeassistant')
|
||||||
class HassAuthProvider(auth.AuthProvider):
|
class HassAuthProvider(AuthProvider):
|
||||||
"""Auth provider based on a local storage of users in HASS config dir."""
|
"""Auth provider based on a local storage of users in HASS config dir."""
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Home Assistant Local'
|
DEFAULT_TITLE = 'Home Assistant Local'
|
@ -5,9 +5,11 @@ import hmac
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
USER_SCHEMA = vol.Schema({
|
USER_SCHEMA = vol.Schema({
|
||||||
vol.Required('username'): str,
|
vol.Required('username'): str,
|
||||||
@ -16,7 +18,7 @@ USER_SCHEMA = vol.Schema({
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
vol.Required('users'): [USER_SCHEMA]
|
vol.Required('users'): [USER_SCHEMA]
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
@ -25,8 +27,8 @@ class InvalidAuthError(HomeAssistantError):
|
|||||||
"""Raised when submitting invalid authentication."""
|
"""Raised when submitting invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
@auth.AUTH_PROVIDERS.register('insecure_example')
|
@AUTH_PROVIDERS.register('insecure_example')
|
||||||
class ExampleAuthProvider(auth.AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_credential_flow(self):
|
async def async_credential_flow(self):
|
@ -9,15 +9,18 @@ import hmac
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
USER_SCHEMA = vol.Schema({
|
USER_SCHEMA = vol.Schema({
|
||||||
vol.Required('username'): str,
|
vol.Required('username'): str,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
LEGACY_USER = 'homeassistant'
|
LEGACY_USER = 'homeassistant'
|
||||||
@ -27,8 +30,8 @@ class InvalidAuthError(HomeAssistantError):
|
|||||||
"""Raised when submitting invalid authentication."""
|
"""Raised when submitting invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
@auth.AUTH_PROVIDERS.register('legacy_api_password')
|
@AUTH_PROVIDERS.register('legacy_api_password')
|
||||||
class LegacyApiPasswordAuthProvider(auth.AuthProvider):
|
class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Legacy API Password'
|
DEFAULT_TITLE = 'Legacy API Password'
|
13
homeassistant/auth/util.py
Normal file
13
homeassistant/auth/util.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
"""Auth utils."""
|
||||||
|
import binascii
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def generate_secret(entropy: int = 32) -> str:
|
||||||
|
"""Generate a secret.
|
||||||
|
|
||||||
|
Backport of secrets.token_hex from Python 3.6
|
||||||
|
|
||||||
|
Event loop friendly.
|
||||||
|
"""
|
||||||
|
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
|
@ -1 +0,0 @@
|
|||||||
"""Auth providers for Home Assistant."""
|
|
@ -10,7 +10,7 @@ import logging
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.auth import generate_secret
|
from homeassistant.auth.util import generate_secret
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
|
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
@ -13,6 +13,7 @@ import voluptuous as vol
|
|||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant import auth
|
from homeassistant import auth
|
||||||
|
from homeassistant.auth import providers as auth_providers
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ASSUMED_STATE,
|
ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ASSUMED_STATE,
|
||||||
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
|
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
|
||||||
@ -159,7 +160,7 @@ CORE_CONFIG_SCHEMA = CUSTOMIZE_CONFIG_SCHEMA.extend({
|
|||||||
vol.All(cv.ensure_list, [vol.IsDir()]),
|
vol.All(cv.ensure_list, [vol.IsDir()]),
|
||||||
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
|
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
|
||||||
vol.Optional(CONF_AUTH_PROVIDERS):
|
vol.Optional(CONF_AUTH_PROVIDERS):
|
||||||
vol.All(cv.ensure_list, [auth.AUTH_PROVIDER_SCHEMA])
|
vol.All(cv.ensure_list, [auth_providers.AUTH_PROVIDER_SCHEMA])
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import os
|
|||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.config import get_default_config_dir
|
from homeassistant.config import get_default_config_dir
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
|
1
tests/auth/__init__.py
Normal file
1
tests/auth/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for the Home Assistant auth module."""
|
@ -2,7 +2,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
@ -4,8 +4,8 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import auth
|
from homeassistant.auth import auth_store, models as auth_models
|
||||||
from homeassistant.auth_providers import insecure_example
|
from homeassistant.auth.providers import insecure_example
|
||||||
|
|
||||||
from tests.common import mock_coro
|
from tests.common import mock_coro
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from tests.common import mock_coro
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def store(hass):
|
def store(hass):
|
||||||
"""Mock store."""
|
"""Mock store."""
|
||||||
return auth.AuthStore(hass)
|
return auth_store.AuthStore(hass)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -45,7 +45,7 @@ async def test_create_new_credential(provider):
|
|||||||
|
|
||||||
async def test_match_existing_credentials(store, provider):
|
async def test_match_existing_credentials(store, provider):
|
||||||
"""See if we match existing users."""
|
"""See if we match existing users."""
|
||||||
existing = auth.Credentials(
|
existing = auth_models.Credentials(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
auth_provider_type='insecure_example',
|
auth_provider_type='insecure_example',
|
||||||
auth_provider_id=None,
|
auth_provider_id=None,
|
@ -4,13 +4,14 @@ from unittest.mock import Mock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import auth
|
from homeassistant import auth
|
||||||
from homeassistant.auth_providers import legacy_api_password
|
from homeassistant.auth import auth_store
|
||||||
|
from homeassistant.auth.providers import legacy_api_password
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def store(hass):
|
def store(hass):
|
||||||
"""Mock store."""
|
"""Mock store."""
|
||||||
return auth.AuthStore(hass)
|
return auth_store.AuthStore(hass)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
@ -5,6 +5,8 @@ from unittest.mock import Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import auth, data_entry_flow
|
||||||
|
from homeassistant.auth import (
|
||||||
|
models as auth_models, auth_store, const as auth_const)
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
|
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
|
||||||
@ -101,7 +103,7 @@ async def test_login_as_existing_user(mock_hass):
|
|||||||
is_active=False,
|
is_active=False,
|
||||||
name='Not user',
|
name='Not user',
|
||||||
).add_to_auth_manager(manager)
|
).add_to_auth_manager(manager)
|
||||||
user.credentials.append(auth.Credentials(
|
user.credentials.append(auth_models.Credentials(
|
||||||
id='mock-id2',
|
id='mock-id2',
|
||||||
auth_provider_type='insecure_example',
|
auth_provider_type='insecure_example',
|
||||||
auth_provider_id=None,
|
auth_provider_id=None,
|
||||||
@ -116,7 +118,7 @@ async def test_login_as_existing_user(mock_hass):
|
|||||||
is_active=False,
|
is_active=False,
|
||||||
name='Paulus',
|
name='Paulus',
|
||||||
).add_to_auth_manager(manager)
|
).add_to_auth_manager(manager)
|
||||||
user.credentials.append(auth.Credentials(
|
user.credentials.append(auth_models.Credentials(
|
||||||
id='mock-id',
|
id='mock-id',
|
||||||
auth_provider_type='insecure_example',
|
auth_provider_type='insecure_example',
|
||||||
auth_provider_id=None,
|
auth_provider_id=None,
|
||||||
@ -203,7 +205,7 @@ async def test_saving_loading(hass, hass_storage):
|
|||||||
|
|
||||||
await flush_store(manager._store._store)
|
await flush_store(manager._store._store)
|
||||||
|
|
||||||
store2 = auth.AuthStore(hass)
|
store2 = auth_store.AuthStore(hass)
|
||||||
users = await store2.async_get_users()
|
users = await store2.async_get_users()
|
||||||
assert len(users) == 1
|
assert len(users) == 1
|
||||||
assert users[0] == user
|
assert users[0] == user
|
||||||
@ -211,23 +213,25 @@ async def test_saving_loading(hass, hass_storage):
|
|||||||
|
|
||||||
def test_access_token_expired():
|
def test_access_token_expired():
|
||||||
"""Test that the expired property on access tokens work."""
|
"""Test that the expired property on access tokens work."""
|
||||||
refresh_token = auth.RefreshToken(
|
refresh_token = auth_models.RefreshToken(
|
||||||
user=None,
|
user=None,
|
||||||
client_id='bla'
|
client_id='bla'
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = auth.AccessToken(
|
access_token = auth_models.AccessToken(
|
||||||
refresh_token=refresh_token
|
refresh_token=refresh_token
|
||||||
)
|
)
|
||||||
|
|
||||||
assert access_token.expired is False
|
assert access_token.expired is False
|
||||||
|
|
||||||
with patch('homeassistant.auth.dt_util.utcnow',
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
return_value=dt_util.utcnow() +
|
||||||
|
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||||
assert access_token.expired is True
|
assert access_token.expired is True
|
||||||
|
|
||||||
almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
almost_exp = \
|
||||||
with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp):
|
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
||||||
|
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
|
||||||
assert access_token.expired is False
|
assert access_token.expired is False
|
||||||
|
|
||||||
|
|
||||||
@ -242,8 +246,9 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
|||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
assert manager.async_get_access_token(access_token.token) is access_token
|
assert manager.async_get_access_token(access_token.token) is access_token
|
||||||
|
|
||||||
with patch('homeassistant.auth.dt_util.utcnow',
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
return_value=dt_util.utcnow() +
|
||||||
|
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||||
assert manager.async_get_access_token(access_token.token) is None
|
assert manager.async_get_access_token(access_token.token) is None
|
||||||
|
|
||||||
# Even with unpatched time, it should have been removed from manager
|
# Even with unpatched time, it should have been removed from manager
|
@ -12,6 +12,7 @@ import threading
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from homeassistant import auth, core as ha, data_entry_flow, config_entries
|
from homeassistant import auth, core as ha, data_entry_flow, config_entries
|
||||||
|
from homeassistant.auth import models as auth_models, auth_store
|
||||||
from homeassistant.setup import setup_component, async_setup_component
|
from homeassistant.setup import setup_component, async_setup_component
|
||||||
from homeassistant.config import async_process_component_config
|
from homeassistant.config import async_process_component_config
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
@ -114,7 +115,7 @@ def async_test_home_assistant(loop):
|
|||||||
"""Return a Home Assistant object pointing at test config dir."""
|
"""Return a Home Assistant object pointing at test config dir."""
|
||||||
hass = ha.HomeAssistant(loop)
|
hass = ha.HomeAssistant(loop)
|
||||||
hass.config.async_load = Mock()
|
hass.config.async_load = Mock()
|
||||||
store = auth.AuthStore(hass)
|
store = auth_store.AuthStore(hass)
|
||||||
hass.auth = auth.AuthManager(hass, store, {})
|
hass.auth = auth.AuthManager(hass, store, {})
|
||||||
ensure_auth_manager_loaded(hass.auth)
|
ensure_auth_manager_loaded(hass.auth)
|
||||||
INSTANCES.append(hass)
|
INSTANCES.append(hass)
|
||||||
@ -308,7 +309,7 @@ def mock_registry(hass, mock_entries=None):
|
|||||||
return registry
|
return registry
|
||||||
|
|
||||||
|
|
||||||
class MockUser(auth.User):
|
class MockUser(auth_models.User):
|
||||||
"""Mock a user in Home Assistant."""
|
"""Mock a user in Home Assistant."""
|
||||||
|
|
||||||
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
||||||
|
@ -7,7 +7,7 @@ import pytest
|
|||||||
from aiohttp import BasicAuth, web
|
from aiohttp import BasicAuth, web
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
|
|
||||||
from homeassistant.auth import AccessToken, RefreshToken
|
from homeassistant.auth.models import AccessToken, RefreshToken
|
||||||
from homeassistant.components.http.auth import setup_auth
|
from homeassistant.components.http.auth import setup_auth
|
||||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||||
from homeassistant.components.http.real_ip import setup_real_ip
|
from homeassistant.components.http.real_ip import setup_real_ip
|
||||||
|
@ -4,7 +4,7 @@ from unittest.mock import Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.scripts import auth as script_auth
|
from homeassistant.scripts import auth as script_auth
|
||||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
Loading…
x
Reference in New Issue
Block a user