mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 19:57:07 +00:00
Add type hints to homeassistant.auth (#15853)
* Always load users in auth store before use * Use namedtuple instead of dict for user meta * Ignore auth store tokens with invalid created_at * Add type hints to homeassistant.auth
This commit is contained in:
parent
e9e5bce10c
commit
649f17fe47
@ -2,7 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Awaitable
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
@ -10,15 +10,17 @@ from homeassistant import data_entry_flow
|
|||||||
from homeassistant.core import callback, HomeAssistant
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import auth_store
|
from . import auth_store, models
|
||||||
from .providers import auth_provider_from_config
|
from .providers import auth_provider_from_config, AuthProvider
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
_ProviderKey = Tuple[str, Optional[str]]
|
||||||
|
_ProviderDict = Dict[_ProviderKey, AuthProvider]
|
||||||
|
|
||||||
|
|
||||||
async def auth_manager_from_config(
|
async def auth_manager_from_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
provider_configs: List[dict]) -> Awaitable['AuthManager']:
|
provider_configs: List[Dict[str, Any]]) -> 'AuthManager':
|
||||||
"""Initialize an auth manager from config."""
|
"""Initialize an auth manager from config."""
|
||||||
store = auth_store.AuthStore(hass)
|
store = auth_store.AuthStore(hass)
|
||||||
if provider_configs:
|
if provider_configs:
|
||||||
@ -26,9 +28,9 @@ async def auth_manager_from_config(
|
|||||||
*[auth_provider_from_config(hass, store, config)
|
*[auth_provider_from_config(hass, store, config)
|
||||||
for config in provider_configs])
|
for config in provider_configs])
|
||||||
else:
|
else:
|
||||||
providers = []
|
providers = ()
|
||||||
# So returned auth providers are in same order as config
|
# So returned auth providers are in same order as config
|
||||||
provider_hash = OrderedDict()
|
provider_hash = OrderedDict() # type: _ProviderDict
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if provider is None:
|
if provider is None:
|
||||||
continue
|
continue
|
||||||
@ -49,7 +51,8 @@ async def auth_manager_from_config(
|
|||||||
class AuthManager:
|
class AuthManager:
|
||||||
"""Manage the authentication for Home Assistant."""
|
"""Manage the authentication for Home Assistant."""
|
||||||
|
|
||||||
def __init__(self, hass, store, providers):
|
def __init__(self, hass: HomeAssistant, store: auth_store.AuthStore,
|
||||||
|
providers: _ProviderDict) -> None:
|
||||||
"""Initialize the auth manager."""
|
"""Initialize the auth manager."""
|
||||||
self._store = store
|
self._store = store
|
||||||
self._providers = providers
|
self._providers = providers
|
||||||
@ -58,12 +61,12 @@ class AuthManager:
|
|||||||
self._async_finish_login_flow)
|
self._async_finish_login_flow)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active(self):
|
def active(self) -> bool:
|
||||||
"""Return if any auth providers are registered."""
|
"""Return if any auth providers are registered."""
|
||||||
return bool(self._providers)
|
return bool(self._providers)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def support_legacy(self):
|
def support_legacy(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Return if legacy_api_password auth providers are registered.
|
Return if legacy_api_password auth providers are registered.
|
||||||
|
|
||||||
@ -75,19 +78,19 @@ class AuthManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_providers(self):
|
def auth_providers(self) -> List[AuthProvider]:
|
||||||
"""Return a list of available auth providers."""
|
"""Return a list of available auth providers."""
|
||||||
return list(self._providers.values())
|
return list(self._providers.values())
|
||||||
|
|
||||||
async def async_get_users(self):
|
async def async_get_users(self) -> List[models.User]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
return await self._store.async_get_users()
|
return await self._store.async_get_users()
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||||
"""Retrieve a user."""
|
"""Retrieve a user."""
|
||||||
return await self._store.async_get_user(user_id)
|
return await self._store.async_get_user(user_id)
|
||||||
|
|
||||||
async def async_create_system_user(self, name):
|
async def async_create_system_user(self, name: str) -> models.User:
|
||||||
"""Create a system user."""
|
"""Create a system user."""
|
||||||
return await self._store.async_create_user(
|
return await self._store.async_create_user(
|
||||||
name=name,
|
name=name,
|
||||||
@ -95,19 +98,20 @@ class AuthManager:
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_create_user(self, name):
|
async def async_create_user(self, name: str) -> models.User:
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'name': name,
|
'name': name,
|
||||||
'is_active': True,
|
'is_active': True,
|
||||||
}
|
} # type: Dict[str, Any]
|
||||||
|
|
||||||
if await self._user_should_be_owner():
|
if await self._user_should_be_owner():
|
||||||
kwargs['is_owner'] = True
|
kwargs['is_owner'] = True
|
||||||
|
|
||||||
return await self._store.async_create_user(**kwargs)
|
return await self._store.async_create_user(**kwargs)
|
||||||
|
|
||||||
async def async_get_or_create_user(self, credentials):
|
async def async_get_or_create_user(self, credentials: models.Credentials) \
|
||||||
|
-> models.User:
|
||||||
"""Get or create a user."""
|
"""Get or create a user."""
|
||||||
if not credentials.is_new:
|
if not credentials.is_new:
|
||||||
for user in await self._store.async_get_users():
|
for user in await self._store.async_get_users():
|
||||||
@ -127,15 +131,16 @@ class AuthManager:
|
|||||||
|
|
||||||
return await self._store.async_create_user(
|
return await self._store.async_create_user(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
name=info.get('name'),
|
name=info.name,
|
||||||
is_active=info.get('is_active', False)
|
is_active=info.is_active,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
async def async_link_user(self, user: models.User,
|
||||||
|
credentials: models.Credentials) -> None:
|
||||||
"""Link credentials to an existing user."""
|
"""Link credentials to an existing user."""
|
||||||
await self._store.async_link_user(user, credentials)
|
await self._store.async_link_user(user, credentials)
|
||||||
|
|
||||||
async def async_remove_user(self, user):
|
async def async_remove_user(self, user: models.User) -> None:
|
||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
tasks = [
|
tasks = [
|
||||||
self.async_remove_credentials(credentials)
|
self.async_remove_credentials(credentials)
|
||||||
@ -147,27 +152,32 @@ class AuthManager:
|
|||||||
|
|
||||||
await self._store.async_remove_user(user)
|
await self._store.async_remove_user(user)
|
||||||
|
|
||||||
async def async_activate_user(self, user):
|
async def async_activate_user(self, user: models.User) -> None:
|
||||||
"""Activate a user."""
|
"""Activate a user."""
|
||||||
await self._store.async_activate_user(user)
|
await self._store.async_activate_user(user)
|
||||||
|
|
||||||
async def async_deactivate_user(self, user):
|
async def async_deactivate_user(self, user: models.User) -> None:
|
||||||
"""Deactivate a user."""
|
"""Deactivate a user."""
|
||||||
if user.is_owner:
|
if user.is_owner:
|
||||||
raise ValueError('Unable to deactive the owner')
|
raise ValueError('Unable to deactive the owner')
|
||||||
await self._store.async_deactivate_user(user)
|
await self._store.async_deactivate_user(user)
|
||||||
|
|
||||||
async def async_remove_credentials(self, credentials):
|
async def async_remove_credentials(
|
||||||
|
self, credentials: models.Credentials) -> None:
|
||||||
"""Remove credentials."""
|
"""Remove credentials."""
|
||||||
provider = self._async_get_auth_provider(credentials)
|
provider = self._async_get_auth_provider(credentials)
|
||||||
|
|
||||||
if (provider is not None and
|
if (provider is not None and
|
||||||
hasattr(provider, 'async_will_remove_credentials')):
|
hasattr(provider, 'async_will_remove_credentials')):
|
||||||
await provider.async_will_remove_credentials(credentials)
|
# https://github.com/python/mypy/issues/1424
|
||||||
|
await provider.async_will_remove_credentials( # type: ignore
|
||||||
|
credentials)
|
||||||
|
|
||||||
await self._store.async_remove_credentials(credentials)
|
await self._store.async_remove_credentials(credentials)
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id=None):
|
async def async_create_refresh_token(self, user: models.User,
|
||||||
|
client_id: Optional[str] = None) \
|
||||||
|
-> models.RefreshToken:
|
||||||
"""Create a new refresh token for a user."""
|
"""Create a new refresh token for a user."""
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
raise ValueError('User is not active')
|
raise ValueError('User is not active')
|
||||||
@ -182,16 +192,19 @@ class AuthManager:
|
|||||||
|
|
||||||
return await self._store.async_create_refresh_token(user, client_id)
|
return await self._store.async_create_refresh_token(user, client_id)
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token_id):
|
async def async_get_refresh_token(
|
||||||
|
self, token_id: str) -> Optional[models.RefreshToken]:
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
return await self._store.async_get_refresh_token(token_id)
|
return await self._store.async_get_refresh_token(token_id)
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(self, token):
|
async def async_get_refresh_token_by_token(
|
||||||
|
self, token: str) -> Optional[models.RefreshToken]:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
return await self._store.async_get_refresh_token_by_token(token)
|
return await self._store.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_access_token(self, refresh_token):
|
def async_create_access_token(self,
|
||||||
|
refresh_token: models.RefreshToken) -> str:
|
||||||
"""Create a new access token."""
|
"""Create a new access token."""
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
return jwt.encode({
|
return jwt.encode({
|
||||||
@ -200,7 +213,8 @@ class AuthManager:
|
|||||||
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
|
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
|
||||||
}, refresh_token.jwt_key, algorithm='HS256').decode()
|
}, refresh_token.jwt_key, algorithm='HS256').decode()
|
||||||
|
|
||||||
async def async_validate_access_token(self, token):
|
async def async_validate_access_token(
|
||||||
|
self, token: str) -> Optional[models.RefreshToken]:
|
||||||
"""Return if an access token is valid."""
|
"""Return if an access token is valid."""
|
||||||
try:
|
try:
|
||||||
unverif_claims = jwt.decode(token, verify=False)
|
unverif_claims = jwt.decode(token, verify=False)
|
||||||
@ -208,7 +222,7 @@ class AuthManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
refresh_token = await self.async_get_refresh_token(
|
refresh_token = await self.async_get_refresh_token(
|
||||||
unverif_claims.get('iss'))
|
cast(str, unverif_claims.get('iss')))
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
jwt_key = ''
|
jwt_key = ''
|
||||||
@ -228,18 +242,22 @@ class AuthManager:
|
|||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not refresh_token.user.is_active:
|
if refresh_token is None or not refresh_token.user.is_active:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
async def _async_create_login_flow(self, handler, *, context, data):
|
async def _async_create_login_flow(
|
||||||
|
self, handler: _ProviderKey, *, context: Optional[Dict],
|
||||||
|
data: Optional[Any]) -> data_entry_flow.FlowHandler:
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
auth_provider = self._providers[handler]
|
auth_provider = self._providers[handler]
|
||||||
|
|
||||||
return await auth_provider.async_credential_flow(context)
|
return await auth_provider.async_credential_flow(context)
|
||||||
|
|
||||||
async def _async_finish_login_flow(self, context, result):
|
async def _async_finish_login_flow(
|
||||||
|
self, context: Optional[Dict], result: Dict[str, Any]) \
|
||||||
|
-> Optional[models.Credentials]:
|
||||||
"""Result of a credential login flow."""
|
"""Result of a credential login flow."""
|
||||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
return None
|
return None
|
||||||
@ -249,13 +267,14 @@ class AuthManager:
|
|||||||
result['data'])
|
result['data'])
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_auth_provider(self, credentials):
|
def _async_get_auth_provider(
|
||||||
|
self, credentials: models.Credentials) -> Optional[AuthProvider]:
|
||||||
"""Helper to get auth provider from a set of credentials."""
|
"""Helper to get auth provider from a set of credentials."""
|
||||||
auth_provider_key = (credentials.auth_provider_type,
|
auth_provider_key = (credentials.auth_provider_type,
|
||||||
credentials.auth_provider_id)
|
credentials.auth_provider_id)
|
||||||
return self._providers.get(auth_provider_key)
|
return self._providers.get(auth_provider_key)
|
||||||
|
|
||||||
async def _user_should_be_owner(self):
|
async def _user_should_be_owner(self) -> bool:
|
||||||
"""Determine if user should be owner.
|
"""Determine if user should be owner.
|
||||||
|
|
||||||
A user should be an owner if it is the first non-system user that is
|
A user should be an owner if it is the first non-system user that is
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
"""Storage for auth models."""
|
"""Storage for auth models."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Dict, List, Optional # noqa: F401
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import models
|
from . import models
|
||||||
@ -20,35 +23,41 @@ class AuthStore:
|
|||||||
called that needs it.
|
called that needs it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the auth store."""
|
"""Initialize the auth store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._users = None
|
self._users = None # type: Optional[Dict[str, models.User]]
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
|
||||||
async def async_get_users(self):
|
async def async_get_users(self) -> List[models.User]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
return list(self._users.values())
|
return list(self._users.values())
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||||
"""Retrieve a user by id."""
|
"""Retrieve a user by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
return self._users.get(user_id)
|
return self._users.get(user_id)
|
||||||
|
|
||||||
async def async_create_user(self, name, is_owner=None, is_active=None,
|
async def async_create_user(
|
||||||
system_generated=None, credentials=None):
|
self, name: Optional[str], is_owner: Optional[bool] = None,
|
||||||
|
is_active: Optional[bool] = None,
|
||||||
|
system_generated: Optional[bool] = None,
|
||||||
|
credentials: Optional[models.Credentials] = None) -> models.User:
|
||||||
"""Create a new user."""
|
"""Create a new user."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'name': name
|
'name': name
|
||||||
}
|
} # type: Dict[str, Any]
|
||||||
|
|
||||||
if is_owner is not None:
|
if is_owner is not None:
|
||||||
kwargs['is_owner'] = is_owner
|
kwargs['is_owner'] = is_owner
|
||||||
@ -71,29 +80,39 @@ class AuthStore:
|
|||||||
await self.async_link_user(new_user, credentials)
|
await self.async_link_user(new_user, credentials)
|
||||||
return new_user
|
return new_user
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
async def async_link_user(self, user: models.User,
|
||||||
|
credentials: models.Credentials) -> None:
|
||||||
"""Add credentials to an existing user."""
|
"""Add credentials to an existing user."""
|
||||||
user.credentials.append(credentials)
|
user.credentials.append(credentials)
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
credentials.is_new = False
|
credentials.is_new = False
|
||||||
|
|
||||||
async def async_remove_user(self, user):
|
async def async_remove_user(self, user: models.User) -> None:
|
||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
self._users.pop(user.id)
|
self._users.pop(user.id)
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_activate_user(self, user):
|
async def async_activate_user(self, user: models.User) -> None:
|
||||||
"""Activate a user."""
|
"""Activate a user."""
|
||||||
user.is_active = True
|
user.is_active = True
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_deactivate_user(self, user):
|
async def async_deactivate_user(self, user: models.User) -> None:
|
||||||
"""Activate a user."""
|
"""Activate a user."""
|
||||||
user.is_active = False
|
user.is_active = False
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_remove_credentials(self, credentials):
|
async def async_remove_credentials(
|
||||||
|
self, credentials: models.Credentials) -> None:
|
||||||
"""Remove credentials."""
|
"""Remove credentials."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
for user in self._users.values():
|
for user in self._users.values():
|
||||||
found = None
|
found = None
|
||||||
|
|
||||||
@ -108,17 +127,21 @@ class AuthStore:
|
|||||||
|
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id=None):
|
async def async_create_refresh_token(
|
||||||
|
self, user: models.User, client_id: Optional[str] = None) \
|
||||||
|
-> models.RefreshToken:
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token_id):
|
async def async_get_refresh_token(
|
||||||
|
self, token_id: str) -> Optional[models.RefreshToken]:
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
for user in self._users.values():
|
for user in self._users.values():
|
||||||
refresh_token = user.refresh_tokens.get(token_id)
|
refresh_token = user.refresh_tokens.get(token_id)
|
||||||
@ -127,10 +150,12 @@ class AuthStore:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(self, token):
|
async def async_get_refresh_token_by_token(
|
||||||
|
self, token: str) -> Optional[models.RefreshToken]:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
found = None
|
found = None
|
||||||
|
|
||||||
@ -141,7 +166,7 @@ class AuthStore:
|
|||||||
|
|
||||||
return found
|
return found
|
||||||
|
|
||||||
async def async_load(self):
|
async def async_load(self) -> None:
|
||||||
"""Load the users."""
|
"""Load the users."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
|
||||||
@ -150,7 +175,7 @@ class AuthStore:
|
|||||||
if self._users is not None:
|
if self._users is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
users = OrderedDict()
|
users = OrderedDict() # type: Dict[str, models.User]
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
self._users = users
|
self._users = users
|
||||||
@ -173,11 +198,17 @@ class AuthStore:
|
|||||||
if 'jwt_key' not in rt_dict:
|
if 'jwt_key' not in rt_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
created_at = dt_util.parse_datetime(rt_dict['created_at'])
|
||||||
|
if created_at is None:
|
||||||
|
getLogger(__name__).error(
|
||||||
|
'Ignoring refresh token %(id)s with invalid created_at '
|
||||||
|
'%(created_at)s for user_id %(user_id)s', rt_dict)
|
||||||
|
continue
|
||||||
token = models.RefreshToken(
|
token = models.RefreshToken(
|
||||||
id=rt_dict['id'],
|
id=rt_dict['id'],
|
||||||
user=users[rt_dict['user_id']],
|
user=users[rt_dict['user_id']],
|
||||||
client_id=rt_dict['client_id'],
|
client_id=rt_dict['client_id'],
|
||||||
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
created_at=created_at,
|
||||||
access_token_expiration=timedelta(
|
access_token_expiration=timedelta(
|
||||||
seconds=rt_dict['access_token_expiration']),
|
seconds=rt_dict['access_token_expiration']),
|
||||||
token=rt_dict['token'],
|
token=rt_dict['token'],
|
||||||
@ -187,8 +218,12 @@ class AuthStore:
|
|||||||
|
|
||||||
self._users = users
|
self._users = users
|
||||||
|
|
||||||
async def async_save(self):
|
async def async_save(self) -> None:
|
||||||
"""Save users."""
|
"""Save users."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
assert self._users is not None
|
||||||
|
|
||||||
users = [
|
users = [
|
||||||
{
|
{
|
||||||
'id': user.id,
|
'id': user.id,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Auth models."""
|
"""Auth models."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, NamedTuple, Optional # noqa: F401
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
@ -14,17 +15,21 @@ from .util import generate_secret
|
|||||||
class User:
|
class User:
|
||||||
"""A user."""
|
"""A user."""
|
||||||
|
|
||||||
name = attr.ib(type=str)
|
name = attr.ib(type=str) # type: Optional[str]
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
is_owner = attr.ib(type=bool, default=False)
|
is_owner = attr.ib(type=bool, default=False)
|
||||||
is_active = attr.ib(type=bool, default=False)
|
is_active = attr.ib(type=bool, default=False)
|
||||||
system_generated = attr.ib(type=bool, default=False)
|
system_generated = attr.ib(type=bool, default=False)
|
||||||
|
|
||||||
# List of credentials of a user.
|
# List of credentials of a user.
|
||||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
credentials = attr.ib(
|
||||||
|
type=list, default=attr.Factory(list), cmp=False
|
||||||
|
) # type: List[Credentials]
|
||||||
|
|
||||||
# Tokens associated with a user.
|
# Tokens associated with a user.
|
||||||
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
refresh_tokens = attr.ib(
|
||||||
|
type=dict, default=attr.Factory(dict), cmp=False
|
||||||
|
) # type: Dict[str, RefreshToken]
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
@ -32,7 +37,7 @@ class RefreshToken:
|
|||||||
"""RefreshToken for a user to grant new access tokens."""
|
"""RefreshToken for a user to grant new access tokens."""
|
||||||
|
|
||||||
user = attr.ib(type=User)
|
user = attr.ib(type=User)
|
||||||
client_id = attr.ib(type=str)
|
client_id = attr.ib(type=str) # type: Optional[str]
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||||
access_token_expiration = attr.ib(type=timedelta,
|
access_token_expiration = attr.ib(type=timedelta,
|
||||||
@ -48,10 +53,14 @@ class Credentials:
|
|||||||
"""Credentials for a user on an auth provider."""
|
"""Credentials for a user on an auth provider."""
|
||||||
|
|
||||||
auth_provider_type = attr.ib(type=str)
|
auth_provider_type = attr.ib(type=str)
|
||||||
auth_provider_id = attr.ib(type=str)
|
auth_provider_id = attr.ib(type=str) # type: Optional[str]
|
||||||
|
|
||||||
# Allow the auth provider to store data to represent their auth.
|
# Allow the auth provider to store data to represent their auth.
|
||||||
data = attr.ib(type=dict)
|
data = attr.ib(type=dict)
|
||||||
|
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
is_new = attr.ib(type=bool, default=True)
|
is_new = attr.ib(type=bool, default=True)
|
||||||
|
|
||||||
|
|
||||||
|
UserMeta = NamedTuple("UserMeta",
|
||||||
|
[('name', Optional[str]), ('is_active', bool)])
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
"""Auth providers for Home Assistant."""
|
"""Auth providers for Home Assistant."""
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import types
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant import requirements
|
from homeassistant import data_entry_flow, requirements
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.auth_store import AuthStore
|
||||||
|
from homeassistant.auth.models import Credentials, UserMeta
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DATA_REQS = 'auth_prov_reqs_processed'
|
DATA_REQS = 'auth_prov_reqs_processed'
|
||||||
@ -25,7 +28,80 @@ AUTH_PROVIDER_SCHEMA = vol.Schema({
|
|||||||
}, extra=vol.ALLOW_EXTRA)
|
}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
async def auth_provider_from_config(hass, store, config):
|
class AuthProvider:
|
||||||
|
"""Provider of user authentication."""
|
||||||
|
|
||||||
|
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, store: AuthStore,
|
||||||
|
config: Dict[str, Any]) -> None:
|
||||||
|
"""Initialize an auth provider."""
|
||||||
|
self.hass = hass
|
||||||
|
self.store = store
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> Optional[str]: # pylint: disable=invalid-name
|
||||||
|
"""Return id of the auth provider.
|
||||||
|
|
||||||
|
Optional, can be None.
|
||||||
|
"""
|
||||||
|
return self.config.get(CONF_ID)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Return type of the provider."""
|
||||||
|
return self.config[CONF_TYPE] # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of the auth provider."""
|
||||||
|
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||||
|
|
||||||
|
async def async_credentials(self) -> List[Credentials]:
|
||||||
|
"""Return all credentials of this provider."""
|
||||||
|
users = await self.store.async_get_users()
|
||||||
|
return [
|
||||||
|
credentials
|
||||||
|
for user in users
|
||||||
|
for credentials in user.credentials
|
||||||
|
if (credentials.auth_provider_type == self.type and
|
||||||
|
credentials.auth_provider_id == self.id)
|
||||||
|
]
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
|
||||||
|
"""Create credentials."""
|
||||||
|
return Credentials(
|
||||||
|
auth_provider_type=self.type,
|
||||||
|
auth_provider_id=self.id,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Implement by extending class
|
||||||
|
|
||||||
|
async def async_credential_flow(
|
||||||
|
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
|
||||||
|
"""Return the data flow for logging in with auth provider."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_get_or_create_credentials(
|
||||||
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
|
"""Get credentials based on the flow result."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_user_meta_for_credentials(
|
||||||
|
self, credentials: Credentials) -> UserMeta:
|
||||||
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
|
Will be used to populate info when creating a new user.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_provider_from_config(
|
||||||
|
hass: HomeAssistant, store: AuthStore,
|
||||||
|
config: Dict[str, Any]) -> Optional[AuthProvider]:
|
||||||
"""Initialize an auth provider from a config."""
|
"""Initialize an auth provider from a config."""
|
||||||
provider_name = config[CONF_TYPE]
|
provider_name = config[CONF_TYPE]
|
||||||
module = await load_auth_provider_module(hass, provider_name)
|
module = await load_auth_provider_module(hass, provider_name)
|
||||||
@ -34,16 +110,17 @@ async def auth_provider_from_config(hass, store, config):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = module.CONFIG_SCHEMA(config)
|
config = module.CONFIG_SCHEMA(config) # type: ignore
|
||||||
except vol.Invalid as err:
|
except vol.Invalid as err:
|
||||||
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
||||||
provider_name, humanize_error(config, err))
|
provider_name, humanize_error(config, err))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def load_auth_provider_module(hass, provider):
|
async def load_auth_provider_module(
|
||||||
|
hass: HomeAssistant, provider: str) -> Optional[types.ModuleType]:
|
||||||
"""Load an auth provider."""
|
"""Load an auth provider."""
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(
|
module = importlib.import_module(
|
||||||
@ -62,82 +139,13 @@ async def load_auth_provider_module(hass, provider):
|
|||||||
elif provider in processed:
|
elif provider in processed:
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
# https://github.com/python/mypy/issues/1424
|
||||||
|
reqs = module.REQUIREMENTS # type: ignore
|
||||||
req_success = await requirements.async_process_requirements(
|
req_success = await requirements.async_process_requirements(
|
||||||
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
hass, 'auth provider {}'.format(provider), reqs)
|
||||||
|
|
||||||
if not req_success:
|
if not req_success:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
processed.add(provider)
|
processed.add(provider)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider:
|
|
||||||
"""Provider of user authentication."""
|
|
||||||
|
|
||||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
|
||||||
|
|
||||||
def __init__(self, hass, store, config):
|
|
||||||
"""Initialize an auth provider."""
|
|
||||||
self.hass = hass
|
|
||||||
self.store = store
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self): # pylint: disable=invalid-name
|
|
||||||
"""Return id of the auth provider.
|
|
||||||
|
|
||||||
Optional, can be None.
|
|
||||||
"""
|
|
||||||
return self.config.get(CONF_ID)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self):
|
|
||||||
"""Return type of the provider."""
|
|
||||||
return self.config[CONF_TYPE]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
"""Return the name of the auth provider."""
|
|
||||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
|
||||||
|
|
||||||
async def async_credentials(self):
|
|
||||||
"""Return all credentials of this provider."""
|
|
||||||
users = await self.store.async_get_users()
|
|
||||||
return [
|
|
||||||
credentials
|
|
||||||
for user in users
|
|
||||||
for credentials in user.credentials
|
|
||||||
if (credentials.auth_provider_type == self.type and
|
|
||||||
credentials.auth_provider_id == self.id)
|
|
||||||
]
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_create_credentials(self, data):
|
|
||||||
"""Create credentials."""
|
|
||||||
return Credentials(
|
|
||||||
auth_provider_type=self.type,
|
|
||||||
auth_provider_id=self.id,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Implement by extending class
|
|
||||||
|
|
||||||
async def async_credential_flow(self, context):
|
|
||||||
"""Return the data flow for logging in with auth provider."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
|
||||||
"""Get credentials based on the flow result."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
|
||||||
"""Return extra user metadata for credentials.
|
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
|
||||||
|
|
||||||
Values to populate:
|
|
||||||
- name: string
|
|
||||||
- is_active: boolean
|
|
||||||
"""
|
|
||||||
return {}
|
|
||||||
|
@ -3,24 +3,25 @@ import base64
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Dict # noqa: F401 pylint: disable=unused-import
|
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.const import CONF_ID
|
from homeassistant.const import CONF_ID
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from homeassistant.auth.util import generate_secret
|
from homeassistant.auth.util import generate_secret
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||||
|
|
||||||
|
|
||||||
def _disallow_id(conf):
|
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Disallow ID in config."""
|
"""Disallow ID in config."""
|
||||||
if CONF_ID in conf:
|
if CONF_ID in conf:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
@ -46,13 +47,13 @@ class InvalidUser(HomeAssistantError):
|
|||||||
class Data:
|
class Data:
|
||||||
"""Hold the user data."""
|
"""Hold the user data."""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
self._data = None
|
self._data = None # type: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
async def async_load(self):
|
async def async_load(self) -> None:
|
||||||
"""Load stored data."""
|
"""Load stored data."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
|
||||||
@ -65,9 +66,9 @@ class Data:
|
|||||||
self._data = data
|
self._data = data
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def users(self):
|
def users(self) -> List[Dict[str, str]]:
|
||||||
"""Return users."""
|
"""Return users."""
|
||||||
return self._data['users']
|
return self._data['users'] # type: ignore
|
||||||
|
|
||||||
def validate_login(self, username: str, password: str) -> None:
|
def validate_login(self, username: str, password: str) -> None:
|
||||||
"""Validate a username and password.
|
"""Validate a username and password.
|
||||||
@ -79,7 +80,7 @@ class Data:
|
|||||||
found = None
|
found = None
|
||||||
|
|
||||||
# Compare all users to avoid timing attacks.
|
# Compare all users to avoid timing attacks.
|
||||||
for user in self._data['users']:
|
for user in self.users:
|
||||||
if username == user['username']:
|
if username == user['username']:
|
||||||
found = user
|
found = user
|
||||||
|
|
||||||
@ -94,8 +95,8 @@ class Data:
|
|||||||
|
|
||||||
def hash_password(self, password: str, for_storage: bool = False) -> bytes:
|
def hash_password(self, password: str, for_storage: bool = False) -> bytes:
|
||||||
"""Encode a password."""
|
"""Encode a password."""
|
||||||
hashed = hashlib.pbkdf2_hmac(
|
salt = self._data['salt'].encode() # type: ignore
|
||||||
'sha512', password.encode(), self._data['salt'].encode(), 100000)
|
hashed = hashlib.pbkdf2_hmac('sha512', password.encode(), salt, 100000)
|
||||||
if for_storage:
|
if for_storage:
|
||||||
hashed = base64.b64encode(hashed)
|
hashed = base64.b64encode(hashed)
|
||||||
return hashed
|
return hashed
|
||||||
@ -137,7 +138,7 @@ class Data:
|
|||||||
else:
|
else:
|
||||||
raise InvalidUser
|
raise InvalidUser
|
||||||
|
|
||||||
async def async_save(self):
|
async def async_save(self) -> None:
|
||||||
"""Save data."""
|
"""Save data."""
|
||||||
await self._store.async_save(self._data)
|
await self._store.async_save(self._data)
|
||||||
|
|
||||||
@ -150,7 +151,7 @@ class HassAuthProvider(AuthProvider):
|
|||||||
|
|
||||||
data = None
|
data = None
|
||||||
|
|
||||||
async def async_initialize(self):
|
async def async_initialize(self) -> None:
|
||||||
"""Initialize the auth provider."""
|
"""Initialize the auth provider."""
|
||||||
if self.data is not None:
|
if self.data is not None:
|
||||||
return
|
return
|
||||||
@ -158,19 +159,22 @@ class HassAuthProvider(AuthProvider):
|
|||||||
self.data = Data(self.hass)
|
self.data = Data(self.hass)
|
||||||
await self.data.async_load()
|
await self.data.async_load()
|
||||||
|
|
||||||
async def async_credential_flow(self, context):
|
async def async_credential_flow(
|
||||||
|
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return LoginFlow(self)
|
||||||
|
|
||||||
async def async_validate_login(self, username: str, password: str):
|
async def async_validate_login(self, username: str, password: str) -> None:
|
||||||
"""Helper to validate a username and password."""
|
"""Helper to validate a username and password."""
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
await self.async_initialize()
|
await self.async_initialize()
|
||||||
|
assert self.data is not None
|
||||||
|
|
||||||
await self.hass.async_add_executor_job(
|
await self.hass.async_add_executor_job(
|
||||||
self.data.validate_login, username, password)
|
self.data.validate_login, username, password)
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(
|
||||||
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
username = flow_result['username']
|
username = flow_result['username']
|
||||||
|
|
||||||
@ -183,17 +187,17 @@ class HassAuthProvider(AuthProvider):
|
|||||||
'username': username
|
'username': username
|
||||||
})
|
})
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
async def async_user_meta_for_credentials(
|
||||||
|
self, credentials: Credentials) -> UserMeta:
|
||||||
"""Get extra info for this credential."""
|
"""Get extra info for this credential."""
|
||||||
return {
|
return UserMeta(name=credentials.data['username'], is_active=True)
|
||||||
'name': credentials.data['username'],
|
|
||||||
'is_active': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def async_will_remove_credentials(self, credentials):
|
async def async_will_remove_credentials(
|
||||||
|
self, credentials: Credentials) -> None:
|
||||||
"""When credentials get removed, also remove the auth."""
|
"""When credentials get removed, also remove the auth."""
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
await self.async_initialize()
|
await self.async_initialize()
|
||||||
|
assert self.data is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.data.async_remove_auth(credentials.data['username'])
|
self.data.async_remove_auth(credentials.data['username'])
|
||||||
@ -206,11 +210,12 @@ class HassAuthProvider(AuthProvider):
|
|||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider):
|
def __init__(self, auth_provider: HassAuthProvider) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
|
|
||||||
async def async_step_init(self, user_input=None):
|
async def async_step_init(
|
||||||
|
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Example auth provider."""
|
"""Example auth provider."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hmac
|
import hmac
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from homeassistant import data_entry_flow
|
|||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
|
|
||||||
USER_SCHEMA = vol.Schema({
|
USER_SCHEMA = vol.Schema({
|
||||||
@ -31,12 +33,13 @@ class InvalidAuthError(HomeAssistantError):
|
|||||||
class ExampleAuthProvider(AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_credential_flow(self, context):
|
async def async_credential_flow(
|
||||||
|
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return LoginFlow(self)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_login(self, username, password):
|
def async_validate_login(self, username: str, password: str) -> None:
|
||||||
"""Helper to validate a username and password."""
|
"""Helper to validate a username and password."""
|
||||||
user = None
|
user = None
|
||||||
|
|
||||||
@ -56,7 +59,8 @@ class ExampleAuthProvider(AuthProvider):
|
|||||||
password.encode('utf-8')):
|
password.encode('utf-8')):
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(
|
||||||
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
username = flow_result['username']
|
username = flow_result['username']
|
||||||
|
|
||||||
@ -69,32 +73,32 @@ class ExampleAuthProvider(AuthProvider):
|
|||||||
'username': username
|
'username': username
|
||||||
})
|
})
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
async def async_user_meta_for_credentials(
|
||||||
|
self, credentials: Credentials) -> UserMeta:
|
||||||
"""Return extra user metadata for credentials.
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
Will be used to populate info when creating a new user.
|
||||||
"""
|
"""
|
||||||
username = credentials.data['username']
|
username = credentials.data['username']
|
||||||
info = {
|
name = None
|
||||||
'is_active': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
for user in self.config['users']:
|
for user in self.config['users']:
|
||||||
if user['username'] == username:
|
if user['username'] == username:
|
||||||
info['name'] = user.get('name')
|
name = user.get('name')
|
||||||
break
|
break
|
||||||
|
|
||||||
return info
|
return UserMeta(name=name, is_active=True)
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider):
|
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
|
|
||||||
async def async_step_init(self, user_input=None):
|
async def async_step_init(
|
||||||
|
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
@ -111,7 +115,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||||||
data=user_input
|
data=user_input
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = OrderedDict()
|
schema = OrderedDict() # type: Dict[str, type]
|
||||||
schema['username'] = str
|
schema['username'] = str
|
||||||
schema['password'] = str
|
schema['password'] = str
|
||||||
|
|
||||||
|
@ -5,14 +5,17 @@ It will be removed when auth system production ready
|
|||||||
"""
|
"""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hmac
|
import hmac
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
|
|
||||||
USER_SCHEMA = vol.Schema({
|
USER_SCHEMA = vol.Schema({
|
||||||
@ -36,25 +39,29 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||||||
|
|
||||||
DEFAULT_TITLE = 'Legacy API Password'
|
DEFAULT_TITLE = 'Legacy API Password'
|
||||||
|
|
||||||
async def async_credential_flow(self, context):
|
async def async_credential_flow(
|
||||||
|
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LoginFlow(self)
|
return LoginFlow(self)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_login(self, password):
|
def async_validate_login(self, password: str) -> None:
|
||||||
"""Helper to validate a username and password."""
|
"""Helper to validate a username and password."""
|
||||||
if not hasattr(self.hass, 'http'):
|
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
|
||||||
|
|
||||||
|
if not hass_http:
|
||||||
raise ValueError('http component is not loaded')
|
raise ValueError('http component is not loaded')
|
||||||
|
|
||||||
if self.hass.http.api_password is None:
|
if hass_http.api_password is None:
|
||||||
raise ValueError('http component is not configured using'
|
raise ValueError('http component is not configured using'
|
||||||
' api_password')
|
' api_password')
|
||||||
|
|
||||||
if not hmac.compare_digest(self.hass.http.api_password.encode('utf-8'),
|
if not hmac.compare_digest(hass_http.api_password.encode('utf-8'),
|
||||||
password.encode('utf-8')):
|
password.encode('utf-8')):
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(
|
||||||
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
"""Return LEGACY_USER always."""
|
"""Return LEGACY_USER always."""
|
||||||
for credential in await self.async_credentials():
|
for credential in await self.async_credentials():
|
||||||
if credential.data['username'] == LEGACY_USER:
|
if credential.data['username'] == LEGACY_USER:
|
||||||
@ -64,26 +71,25 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||||||
'username': LEGACY_USER
|
'username': LEGACY_USER
|
||||||
})
|
})
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
async def async_user_meta_for_credentials(
|
||||||
|
self, credentials: Credentials) -> UserMeta:
|
||||||
"""
|
"""
|
||||||
Set name as LEGACY_USER always.
|
Set name as LEGACY_USER always.
|
||||||
|
|
||||||
Will be used to populate info when creating a new user.
|
Will be used to populate info when creating a new user.
|
||||||
"""
|
"""
|
||||||
return {
|
return UserMeta(name=LEGACY_USER, is_active=True)
|
||||||
'name': LEGACY_USER,
|
|
||||||
'is_active': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider):
|
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
|
|
||||||
async def async_step_init(self, user_input=None):
|
async def async_step_init(
|
||||||
|
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
@ -100,7 +106,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||||||
data={}
|
data={}
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = OrderedDict()
|
schema = OrderedDict() # type: Dict[str, type]
|
||||||
schema['password'] = str
|
schema['password'] = str
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
|
@ -3,12 +3,16 @@
|
|||||||
It shows list of users if access from trusted network.
|
It shows list of users if access from trusted network.
|
||||||
Abort login flow if not access from trusted network.
|
Abort login flow if not access from trusted network.
|
||||||
"""
|
"""
|
||||||
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
from ..models import Credentials, UserMeta
|
||||||
|
|
||||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
}, extra=vol.PREVENT_EXTRA)
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
@ -31,16 +35,20 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
|
|
||||||
DEFAULT_TITLE = 'Trusted Networks'
|
DEFAULT_TITLE = 'Trusted Networks'
|
||||||
|
|
||||||
async def async_credential_flow(self, context):
|
async def async_credential_flow(
|
||||||
|
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
|
assert context is not None
|
||||||
users = await self.store.async_get_users()
|
users = await self.store.async_get_users()
|
||||||
available_users = {user.id: user.name
|
available_users = {user.id: user.name
|
||||||
for user in users
|
for user in users
|
||||||
if not user.system_generated and user.is_active}
|
if not user.system_generated and user.is_active}
|
||||||
|
|
||||||
return LoginFlow(self, context.get('ip_address'), available_users)
|
return LoginFlow(self, cast(str, context.get('ip_address')),
|
||||||
|
available_users)
|
||||||
|
|
||||||
async def async_get_or_create_credentials(self, flow_result):
|
async def async_get_or_create_credentials(
|
||||||
|
self, flow_result: Dict[str, str]) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
user_id = flow_result['user']
|
user_id = flow_result['user']
|
||||||
|
|
||||||
@ -59,7 +67,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
# We only allow login as exist user
|
# We only allow login as exist user
|
||||||
raise InvalidUserError
|
raise InvalidUserError
|
||||||
|
|
||||||
async def async_user_meta_for_credentials(self, credentials):
|
async def async_user_meta_for_credentials(
|
||||||
|
self, credentials: Credentials) -> UserMeta:
|
||||||
"""Return extra user metadata for credentials.
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
Trusted network auth provider should never create new user.
|
Trusted network auth provider should never create new user.
|
||||||
@ -67,31 +76,36 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_access(self, ip_address):
|
def async_validate_access(self, ip_address: str) -> None:
|
||||||
"""Make sure the access from trusted networks.
|
"""Make sure the access from trusted networks.
|
||||||
|
|
||||||
Raise InvalidAuthError if not.
|
Raise InvalidAuthError if not.
|
||||||
Raise InvalidAuthError if trusted_networks is not config
|
Raise InvalidAuthError if trusted_networks is not configured.
|
||||||
"""
|
"""
|
||||||
if (not hasattr(self.hass, 'http') or
|
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
|
||||||
not self.hass.http or not self.hass.http.trusted_networks):
|
|
||||||
|
if not hass_http or not hass_http.trusted_networks:
|
||||||
raise InvalidAuthError('trusted_networks is not configured')
|
raise InvalidAuthError('trusted_networks is not configured')
|
||||||
|
|
||||||
if not any(ip_address in trusted_network for trusted_network
|
if not any(ip_address in trusted_network for trusted_network
|
||||||
in self.hass.http.trusted_networks):
|
in hass_http.trusted_networks):
|
||||||
raise InvalidAuthError('Not in trusted_networks')
|
raise InvalidAuthError('Not in trusted_networks')
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler):
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(self, auth_provider, ip_address, available_users):
|
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
||||||
|
ip_address: str, available_users: Dict[str, Optional[str]]) \
|
||||||
|
-> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
self._available_users = available_users
|
self._available_users = available_users
|
||||||
self._ip_address = ip_address
|
self._ip_address = ip_address
|
||||||
|
|
||||||
async def async_step_init(self, user_input=None):
|
async def async_step_init(
|
||||||
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
|
-> Dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
try:
|
try:
|
||||||
|
2
tox.ini
2
tox.ini
@ -58,4 +58,4 @@ whitelist_externals=/bin/bash
|
|||||||
deps =
|
deps =
|
||||||
-r{toxinidir}/requirements_test.txt
|
-r{toxinidir}/requirements_test.txt
|
||||||
commands =
|
commands =
|
||||||
/bin/bash -c 'mypy homeassistant/*.py homeassistant/util/'
|
/bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user