diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 148f97702e3..b710ca9999e 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -2,7 +2,7 @@ import asyncio import logging from collections import OrderedDict -from typing import List, Awaitable +from typing import Any, Dict, List, Optional, Tuple, cast import jwt @@ -10,15 +10,17 @@ from homeassistant import data_entry_flow from homeassistant.core import callback, HomeAssistant from homeassistant.util import dt as dt_util -from . import auth_store -from .providers import auth_provider_from_config +from . import auth_store, models +from .providers import auth_provider_from_config, AuthProvider _LOGGER = logging.getLogger(__name__) +_ProviderKey = Tuple[str, Optional[str]] +_ProviderDict = Dict[_ProviderKey, AuthProvider] async def auth_manager_from_config( hass: HomeAssistant, - provider_configs: List[dict]) -> Awaitable['AuthManager']: + provider_configs: List[Dict[str, Any]]) -> 'AuthManager': """Initialize an auth manager from config.""" store = auth_store.AuthStore(hass) if provider_configs: @@ -26,9 +28,9 @@ async def auth_manager_from_config( *[auth_provider_from_config(hass, store, config) for config in provider_configs]) else: - providers = [] + providers = () # So returned auth providers are in same order as config - provider_hash = OrderedDict() + provider_hash = OrderedDict() # type: _ProviderDict for provider in providers: if provider is None: continue @@ -49,7 +51,8 @@ async def auth_manager_from_config( class AuthManager: """Manage the authentication for Home Assistant.""" - def __init__(self, hass, store, providers): + def __init__(self, hass: HomeAssistant, store: auth_store.AuthStore, + providers: _ProviderDict) -> None: """Initialize the auth manager.""" self._store = store self._providers = providers @@ -58,12 +61,12 @@ class AuthManager: self._async_finish_login_flow) @property - def active(self): + def active(self) -> bool: """Return if any auth providers are registered.""" return bool(self._providers) @property - def support_legacy(self): + def support_legacy(self) -> bool: """ Return if legacy_api_password auth providers are registered. @@ -75,19 +78,19 @@ class AuthManager: return False @property - def auth_providers(self): + def auth_providers(self) -> List[AuthProvider]: """Return a list of available auth providers.""" return list(self._providers.values()) - async def async_get_users(self): + async def async_get_users(self) -> List[models.User]: """Retrieve all users.""" return await self._store.async_get_users() - async def async_get_user(self, user_id): + async def async_get_user(self, user_id: str) -> Optional[models.User]: """Retrieve a user.""" return await self._store.async_get_user(user_id) - async def async_create_system_user(self, name): + async def async_create_system_user(self, name: str) -> models.User: """Create a system user.""" return await self._store.async_create_user( name=name, @@ -95,19 +98,20 @@ class AuthManager: is_active=True, ) - async def async_create_user(self, name): + async def async_create_user(self, name: str) -> models.User: """Create a user.""" kwargs = { 'name': name, 'is_active': True, - } + } # type: Dict[str, Any] if await self._user_should_be_owner(): kwargs['is_owner'] = True return await self._store.async_create_user(**kwargs) - async def async_get_or_create_user(self, credentials): + async def async_get_or_create_user(self, credentials: models.Credentials) \ + -> models.User: """Get or create a user.""" if not credentials.is_new: for user in await self._store.async_get_users(): @@ -127,15 +131,16 @@ class AuthManager: return await self._store.async_create_user( credentials=credentials, - name=info.get('name'), - is_active=info.get('is_active', False) + name=info.name, + is_active=info.is_active, ) - async def async_link_user(self, user, credentials): + async def async_link_user(self, user: models.User, + credentials: models.Credentials) -> None: """Link credentials to an existing user.""" await self._store.async_link_user(user, credentials) - async def async_remove_user(self, user): + async def async_remove_user(self, user: models.User) -> None: """Remove a user.""" tasks = [ self.async_remove_credentials(credentials) @@ -147,27 +152,32 @@ class AuthManager: await self._store.async_remove_user(user) - async def async_activate_user(self, user): + async def async_activate_user(self, user: models.User) -> None: """Activate a user.""" await self._store.async_activate_user(user) - async def async_deactivate_user(self, user): + async def async_deactivate_user(self, user: models.User) -> None: """Deactivate a user.""" if user.is_owner: raise ValueError('Unable to deactive the owner') await self._store.async_deactivate_user(user) - async def async_remove_credentials(self, credentials): + async def async_remove_credentials( + self, credentials: models.Credentials) -> None: """Remove credentials.""" provider = self._async_get_auth_provider(credentials) if (provider is not None and hasattr(provider, 'async_will_remove_credentials')): - await provider.async_will_remove_credentials(credentials) + # https://github.com/python/mypy/issues/1424 + await provider.async_will_remove_credentials( # type: ignore + credentials) await self._store.async_remove_credentials(credentials) - async def async_create_refresh_token(self, user, client_id=None): + async def async_create_refresh_token(self, user: models.User, + client_id: Optional[str] = None) \ + -> models.RefreshToken: """Create a new refresh token for a user.""" if not user.is_active: raise ValueError('User is not active') @@ -182,16 +192,19 @@ class AuthManager: return await self._store.async_create_refresh_token(user, client_id) - async def async_get_refresh_token(self, token_id): + async def async_get_refresh_token( + self, token_id: str) -> Optional[models.RefreshToken]: """Get refresh token by id.""" return await self._store.async_get_refresh_token(token_id) - async def async_get_refresh_token_by_token(self, token): + async def async_get_refresh_token_by_token( + self, token: str) -> Optional[models.RefreshToken]: """Get refresh token by token.""" return await self._store.async_get_refresh_token_by_token(token) @callback - def async_create_access_token(self, refresh_token): + def async_create_access_token(self, + refresh_token: models.RefreshToken) -> str: """Create a new access token.""" # pylint: disable=no-self-use return jwt.encode({ @@ -200,7 +213,8 @@ class AuthManager: 'exp': dt_util.utcnow() + refresh_token.access_token_expiration, }, refresh_token.jwt_key, algorithm='HS256').decode() - async def async_validate_access_token(self, token): + async def async_validate_access_token( + self, token: str) -> Optional[models.RefreshToken]: """Return if an access token is valid.""" try: unverif_claims = jwt.decode(token, verify=False) @@ -208,7 +222,7 @@ class AuthManager: return None refresh_token = await self.async_get_refresh_token( - unverif_claims.get('iss')) + cast(str, unverif_claims.get('iss'))) if refresh_token is None: jwt_key = '' @@ -228,18 +242,22 @@ class AuthManager: except jwt.InvalidTokenError: return None - if not refresh_token.user.is_active: + if refresh_token is None or not refresh_token.user.is_active: return None return refresh_token - async def _async_create_login_flow(self, handler, *, context, data): + async def _async_create_login_flow( + self, handler: _ProviderKey, *, context: Optional[Dict], + data: Optional[Any]) -> data_entry_flow.FlowHandler: """Create a login flow.""" auth_provider = self._providers[handler] return await auth_provider.async_credential_flow(context) - async def _async_finish_login_flow(self, context, result): + async def _async_finish_login_flow( + self, context: Optional[Dict], result: Dict[str, Any]) \ + -> Optional[models.Credentials]: """Result of a credential login flow.""" if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: return None @@ -249,13 +267,14 @@ class AuthManager: result['data']) @callback - def _async_get_auth_provider(self, credentials): + def _async_get_auth_provider( + self, credentials: models.Credentials) -> Optional[AuthProvider]: """Helper to get auth provider from a set of credentials.""" auth_provider_key = (credentials.auth_provider_type, credentials.auth_provider_id) return self._providers.get(auth_provider_key) - async def _user_should_be_owner(self): + async def _user_should_be_owner(self) -> bool: """Determine if user should be owner. A user should be an owner if it is the first non-system user that is diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 806cd109d78..07ab40ceaea 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -1,8 +1,11 @@ """Storage for auth models.""" from collections import OrderedDict from datetime import timedelta +from logging import getLogger +from typing import Any, Dict, List, Optional # noqa: F401 import hmac +from homeassistant.core import HomeAssistant from homeassistant.util import dt as dt_util from . import models @@ -20,35 +23,41 @@ class AuthStore: called that needs it. """ - def __init__(self, hass): + def __init__(self, hass: HomeAssistant) -> None: """Initialize the auth store.""" self.hass = hass - self._users = None + self._users = None # type: Optional[Dict[str, models.User]] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) - async def async_get_users(self): + async def async_get_users(self) -> List[models.User]: """Retrieve all users.""" if self._users is None: await self.async_load() + assert self._users is not None return list(self._users.values()) - async def async_get_user(self, user_id): + async def async_get_user(self, user_id: str) -> Optional[models.User]: """Retrieve a user by id.""" if self._users is None: await self.async_load() + assert self._users is not None return self._users.get(user_id) - async def async_create_user(self, name, is_owner=None, is_active=None, - system_generated=None, credentials=None): + async def async_create_user( + self, name: Optional[str], is_owner: Optional[bool] = None, + is_active: Optional[bool] = None, + system_generated: Optional[bool] = None, + credentials: Optional[models.Credentials] = None) -> models.User: """Create a new user.""" if self._users is None: await self.async_load() + assert self._users is not None kwargs = { 'name': name - } + } # type: Dict[str, Any] if is_owner is not None: kwargs['is_owner'] = is_owner @@ -71,29 +80,39 @@ class AuthStore: await self.async_link_user(new_user, credentials) return new_user - async def async_link_user(self, user, credentials): + async def async_link_user(self, user: models.User, + credentials: models.Credentials) -> None: """Add credentials to an existing user.""" user.credentials.append(credentials) await self.async_save() credentials.is_new = False - async def async_remove_user(self, user): + async def async_remove_user(self, user: models.User) -> None: """Remove a user.""" + if self._users is None: + await self.async_load() + assert self._users is not None + self._users.pop(user.id) await self.async_save() - async def async_activate_user(self, user): + async def async_activate_user(self, user: models.User) -> None: """Activate a user.""" user.is_active = True await self.async_save() - async def async_deactivate_user(self, user): + async def async_deactivate_user(self, user: models.User) -> None: """Activate a user.""" user.is_active = False await self.async_save() - async def async_remove_credentials(self, credentials): + async def async_remove_credentials( + self, credentials: models.Credentials) -> None: """Remove credentials.""" + if self._users is None: + await self.async_load() + assert self._users is not None + for user in self._users.values(): found = None @@ -108,17 +127,21 @@ class AuthStore: await self.async_save() - async def async_create_refresh_token(self, user, client_id=None): + async def async_create_refresh_token( + self, user: models.User, client_id: Optional[str] = None) \ + -> models.RefreshToken: """Create a new token for a user.""" refresh_token = models.RefreshToken(user=user, client_id=client_id) user.refresh_tokens[refresh_token.id] = refresh_token await self.async_save() return refresh_token - async def async_get_refresh_token(self, token_id): + async def async_get_refresh_token( + self, token_id: str) -> Optional[models.RefreshToken]: """Get refresh token by id.""" if self._users is None: await self.async_load() + assert self._users is not None for user in self._users.values(): refresh_token = user.refresh_tokens.get(token_id) @@ -127,10 +150,12 @@ class AuthStore: return None - async def async_get_refresh_token_by_token(self, token): + async def async_get_refresh_token_by_token( + self, token: str) -> Optional[models.RefreshToken]: """Get refresh token by token.""" if self._users is None: await self.async_load() + assert self._users is not None found = None @@ -141,7 +166,7 @@ class AuthStore: return found - async def async_load(self): + async def async_load(self) -> None: """Load the users.""" data = await self._store.async_load() @@ -150,7 +175,7 @@ class AuthStore: if self._users is not None: return - users = OrderedDict() + users = OrderedDict() # type: Dict[str, models.User] if data is None: self._users = users @@ -173,11 +198,17 @@ class AuthStore: if 'jwt_key' not in rt_dict: continue + created_at = dt_util.parse_datetime(rt_dict['created_at']) + if created_at is None: + getLogger(__name__).error( + 'Ignoring refresh token %(id)s with invalid created_at ' + '%(created_at)s for user_id %(user_id)s', rt_dict) + continue token = models.RefreshToken( id=rt_dict['id'], user=users[rt_dict['user_id']], client_id=rt_dict['client_id'], - created_at=dt_util.parse_datetime(rt_dict['created_at']), + created_at=created_at, access_token_expiration=timedelta( seconds=rt_dict['access_token_expiration']), token=rt_dict['token'], @@ -187,8 +218,12 @@ class AuthStore: self._users = users - async def async_save(self): + async def async_save(self) -> None: """Save users.""" + if self._users is None: + await self.async_load() + assert self._users is not None + users = [ { 'id': user.id, diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 3f49c56bce6..a6500510e0d 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -1,5 +1,6 @@ """Auth models.""" from datetime import datetime, timedelta +from typing import Dict, List, NamedTuple, Optional # noqa: F401 import uuid import attr @@ -14,17 +15,21 @@ from .util import generate_secret class User: """A user.""" - name = attr.ib(type=str) + name = attr.ib(type=str) # type: Optional[str] id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) is_owner = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False) system_generated = attr.ib(type=bool, default=False) # List of credentials of a user. - credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False) + credentials = attr.ib( + type=list, default=attr.Factory(list), cmp=False + ) # type: List[Credentials] # Tokens associated with a user. - refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False) + refresh_tokens = attr.ib( + type=dict, default=attr.Factory(dict), cmp=False + ) # type: Dict[str, RefreshToken] @attr.s(slots=True) @@ -32,7 +37,7 @@ class RefreshToken: """RefreshToken for a user to grant new access tokens.""" user = attr.ib(type=User) - client_id = attr.ib(type=str) + client_id = attr.ib(type=str) # type: Optional[str] id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow)) access_token_expiration = attr.ib(type=timedelta, @@ -48,10 +53,14 @@ class Credentials: """Credentials for a user on an auth provider.""" auth_provider_type = attr.ib(type=str) - auth_provider_id = attr.ib(type=str) + auth_provider_id = attr.ib(type=str) # type: Optional[str] # Allow the auth provider to store data to represent their auth. data = attr.ib(type=dict) id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) is_new = attr.ib(type=bool, default=True) + + +UserMeta = NamedTuple("UserMeta", + [('name', Optional[str]), ('is_active', bool)]) diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index ac5b6107b8a..328d83343d7 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -1,16 +1,19 @@ """Auth providers for Home Assistant.""" import importlib import logging +import types +from typing import Any, Dict, List, Optional import voluptuous as vol from voluptuous.humanize import humanize_error -from homeassistant import requirements -from homeassistant.core import callback +from homeassistant import data_entry_flow, requirements +from homeassistant.core import callback, HomeAssistant from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID from homeassistant.util.decorator import Registry -from homeassistant.auth.models import Credentials +from homeassistant.auth.auth_store import AuthStore +from homeassistant.auth.models import Credentials, UserMeta _LOGGER = logging.getLogger(__name__) DATA_REQS = 'auth_prov_reqs_processed' @@ -25,7 +28,80 @@ AUTH_PROVIDER_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) -async def auth_provider_from_config(hass, store, config): +class AuthProvider: + """Provider of user authentication.""" + + DEFAULT_TITLE = 'Unnamed auth provider' + + def __init__(self, hass: HomeAssistant, store: AuthStore, + config: Dict[str, Any]) -> None: + """Initialize an auth provider.""" + self.hass = hass + self.store = store + self.config = config + + @property + def id(self) -> Optional[str]: # pylint: disable=invalid-name + """Return id of the auth provider. + + Optional, can be None. + """ + return self.config.get(CONF_ID) + + @property + def type(self) -> str: + """Return type of the provider.""" + return self.config[CONF_TYPE] # type: ignore + + @property + def name(self) -> str: + """Return the name of the auth provider.""" + return self.config.get(CONF_NAME, self.DEFAULT_TITLE) + + async def async_credentials(self) -> List[Credentials]: + """Return all credentials of this provider.""" + users = await self.store.async_get_users() + return [ + credentials + for user in users + for credentials in user.credentials + if (credentials.auth_provider_type == self.type and + credentials.auth_provider_id == self.id) + ] + + @callback + def async_create_credentials(self, data: Dict[str, str]) -> Credentials: + """Create credentials.""" + return Credentials( + auth_provider_type=self.type, + auth_provider_id=self.id, + data=data, + ) + + # Implement by extending class + + async def async_credential_flow( + self, context: Optional[Dict]) -> data_entry_flow.FlowHandler: + """Return the data flow for logging in with auth provider.""" + raise NotImplementedError + + async def async_get_or_create_credentials( + self, flow_result: Dict[str, str]) -> Credentials: + """Get credentials based on the flow result.""" + raise NotImplementedError + + async def async_user_meta_for_credentials( + self, credentials: Credentials) -> UserMeta: + """Return extra user metadata for credentials. + + Will be used to populate info when creating a new user. + """ + raise NotImplementedError + + +async def auth_provider_from_config( + hass: HomeAssistant, store: AuthStore, + config: Dict[str, Any]) -> Optional[AuthProvider]: """Initialize an auth provider from a config.""" provider_name = config[CONF_TYPE] module = await load_auth_provider_module(hass, provider_name) @@ -34,16 +110,17 @@ async def auth_provider_from_config(hass, store, config): return None try: - config = module.CONFIG_SCHEMA(config) + config = module.CONFIG_SCHEMA(config) # type: ignore except vol.Invalid as err: _LOGGER.error('Invalid configuration for auth provider %s: %s', provider_name, humanize_error(config, err)) return None - return AUTH_PROVIDERS[provider_name](hass, store, config) + return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore -async def load_auth_provider_module(hass, provider): +async def load_auth_provider_module( + hass: HomeAssistant, provider: str) -> Optional[types.ModuleType]: """Load an auth provider.""" try: module = importlib.import_module( @@ -62,82 +139,13 @@ async def load_auth_provider_module(hass, provider): elif provider in processed: return module + # https://github.com/python/mypy/issues/1424 + reqs = module.REQUIREMENTS # type: ignore req_success = await requirements.async_process_requirements( - hass, 'auth provider {}'.format(provider), module.REQUIREMENTS) + hass, 'auth provider {}'.format(provider), reqs) if not req_success: return None processed.add(provider) return module - - -class AuthProvider: - """Provider of user authentication.""" - - DEFAULT_TITLE = 'Unnamed auth provider' - - def __init__(self, hass, store, config): - """Initialize an auth provider.""" - self.hass = hass - self.store = store - self.config = config - - @property - def id(self): # pylint: disable=invalid-name - """Return id of the auth provider. - - Optional, can be None. - """ - return self.config.get(CONF_ID) - - @property - def type(self): - """Return type of the provider.""" - return self.config[CONF_TYPE] - - @property - def name(self): - """Return the name of the auth provider.""" - return self.config.get(CONF_NAME, self.DEFAULT_TITLE) - - async def async_credentials(self): - """Return all credentials of this provider.""" - users = await self.store.async_get_users() - return [ - credentials - for user in users - for credentials in user.credentials - if (credentials.auth_provider_type == self.type and - credentials.auth_provider_id == self.id) - ] - - @callback - def async_create_credentials(self, data): - """Create credentials.""" - return Credentials( - auth_provider_type=self.type, - auth_provider_id=self.id, - data=data, - ) - - # Implement by extending class - - async def async_credential_flow(self, context): - """Return the data flow for logging in with auth provider.""" - raise NotImplementedError - - async def async_get_or_create_credentials(self, flow_result): - """Get credentials based on the flow result.""" - raise NotImplementedError - - async def async_user_meta_for_credentials(self, credentials): - """Return extra user metadata for credentials. - - Will be used to populate info when creating a new user. - - Values to populate: - - name: string - - is_active: boolean - """ - return {} diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index 5a2355264ab..7dbdf97b083 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -3,24 +3,25 @@ import base64 from collections import OrderedDict import hashlib import hmac -from typing import Dict # noqa: F401 pylint: disable=unused-import +from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import import voluptuous as vol from homeassistant import data_entry_flow from homeassistant.const import CONF_ID -from homeassistant.core import callback +from homeassistant.core import callback, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.auth.util import generate_secret from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from ..models import Credentials, UserMeta STORAGE_VERSION = 1 STORAGE_KEY = 'auth_provider.homeassistant' -def _disallow_id(conf): +def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]: """Disallow ID in config.""" if CONF_ID in conf: raise vol.Invalid( @@ -46,13 +47,13 @@ class InvalidUser(HomeAssistantError): class Data: """Hold the user data.""" - def __init__(self, hass): + def __init__(self, hass: HomeAssistant) -> None: """Initialize the user data store.""" self.hass = hass self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) - self._data = None + self._data = None # type: Optional[Dict[str, Any]] - async def async_load(self): + async def async_load(self) -> None: """Load stored data.""" data = await self._store.async_load() @@ -65,9 +66,9 @@ class Data: self._data = data @property - def users(self): + def users(self) -> List[Dict[str, str]]: """Return users.""" - return self._data['users'] + return self._data['users'] # type: ignore def validate_login(self, username: str, password: str) -> None: """Validate a username and password. @@ -79,7 +80,7 @@ class Data: found = None # Compare all users to avoid timing attacks. - for user in self._data['users']: + for user in self.users: if username == user['username']: found = user @@ -94,8 +95,8 @@ class Data: def hash_password(self, password: str, for_storage: bool = False) -> bytes: """Encode a password.""" - hashed = hashlib.pbkdf2_hmac( - 'sha512', password.encode(), self._data['salt'].encode(), 100000) + salt = self._data['salt'].encode() # type: ignore + hashed = hashlib.pbkdf2_hmac('sha512', password.encode(), salt, 100000) if for_storage: hashed = base64.b64encode(hashed) return hashed @@ -137,7 +138,7 @@ class Data: else: raise InvalidUser - async def async_save(self): + async def async_save(self) -> None: """Save data.""" await self._store.async_save(self._data) @@ -150,7 +151,7 @@ class HassAuthProvider(AuthProvider): data = None - async def async_initialize(self): + async def async_initialize(self) -> None: """Initialize the auth provider.""" if self.data is not None: return @@ -158,19 +159,22 @@ class HassAuthProvider(AuthProvider): self.data = Data(self.hass) await self.data.async_load() - async def async_credential_flow(self, context): + async def async_credential_flow( + self, context: Optional[Dict]) -> 'LoginFlow': """Return a flow to login.""" return LoginFlow(self) - async def async_validate_login(self, username: str, password: str): + async def async_validate_login(self, username: str, password: str) -> None: """Helper to validate a username and password.""" if self.data is None: await self.async_initialize() + assert self.data is not None await self.hass.async_add_executor_job( self.data.validate_login, username, password) - async def async_get_or_create_credentials(self, flow_result): + async def async_get_or_create_credentials( + self, flow_result: Dict[str, str]) -> Credentials: """Get credentials based on the flow result.""" username = flow_result['username'] @@ -183,17 +187,17 @@ class HassAuthProvider(AuthProvider): 'username': username }) - async def async_user_meta_for_credentials(self, credentials): + async def async_user_meta_for_credentials( + self, credentials: Credentials) -> UserMeta: """Get extra info for this credential.""" - return { - 'name': credentials.data['username'], - 'is_active': True, - } + return UserMeta(name=credentials.data['username'], is_active=True) - async def async_will_remove_credentials(self, credentials): + async def async_will_remove_credentials( + self, credentials: Credentials) -> None: """When credentials get removed, also remove the auth.""" if self.data is None: await self.async_initialize() + assert self.data is not None try: self.data.async_remove_auth(credentials.data['username']) @@ -206,11 +210,12 @@ class HassAuthProvider(AuthProvider): class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" - def __init__(self, auth_provider): + def __init__(self, auth_provider: HassAuthProvider) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider - async def async_step_init(self, user_input=None): + async def async_step_init( + self, user_input: Dict[str, str] = None) -> Dict[str, Any]: """Handle the step of the form.""" errors = {} diff --git a/homeassistant/auth/providers/insecure_example.py b/homeassistant/auth/providers/insecure_example.py index 96f824140ed..144ca967302 100644 --- a/homeassistant/auth/providers/insecure_example.py +++ b/homeassistant/auth/providers/insecure_example.py @@ -1,6 +1,7 @@ """Example auth provider.""" from collections import OrderedDict import hmac +from typing import Any, Dict, Optional import voluptuous as vol @@ -9,6 +10,7 @@ from homeassistant import data_entry_flow from homeassistant.core import callback from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from ..models import Credentials, UserMeta USER_SCHEMA = vol.Schema({ @@ -31,12 +33,13 @@ class InvalidAuthError(HomeAssistantError): class ExampleAuthProvider(AuthProvider): """Example auth provider based on hardcoded usernames and passwords.""" - async def async_credential_flow(self, context): + async def async_credential_flow( + self, context: Optional[Dict]) -> 'LoginFlow': """Return a flow to login.""" return LoginFlow(self) @callback - def async_validate_login(self, username, password): + def async_validate_login(self, username: str, password: str) -> None: """Helper to validate a username and password.""" user = None @@ -56,7 +59,8 @@ class ExampleAuthProvider(AuthProvider): password.encode('utf-8')): raise InvalidAuthError - async def async_get_or_create_credentials(self, flow_result): + async def async_get_or_create_credentials( + self, flow_result: Dict[str, str]) -> Credentials: """Get credentials based on the flow result.""" username = flow_result['username'] @@ -69,32 +73,32 @@ class ExampleAuthProvider(AuthProvider): 'username': username }) - async def async_user_meta_for_credentials(self, credentials): + async def async_user_meta_for_credentials( + self, credentials: Credentials) -> UserMeta: """Return extra user metadata for credentials. Will be used to populate info when creating a new user. """ username = credentials.data['username'] - info = { - 'is_active': True, - } + name = None for user in self.config['users']: if user['username'] == username: - info['name'] = user.get('name') + name = user.get('name') break - return info + return UserMeta(name=name, is_active=True) class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" - def __init__(self, auth_provider): + def __init__(self, auth_provider: ExampleAuthProvider) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider - async def async_step_init(self, user_input=None): + async def async_step_init( + self, user_input: Dict[str, str] = None) -> Dict[str, Any]: """Handle the step of the form.""" errors = {} @@ -111,7 +115,7 @@ class LoginFlow(data_entry_flow.FlowHandler): data=user_input ) - schema = OrderedDict() + schema = OrderedDict() # type: Dict[str, type] schema['username'] = str schema['password'] = str diff --git a/homeassistant/auth/providers/legacy_api_password.py b/homeassistant/auth/providers/legacy_api_password.py index f2f467e07ec..f276997bf06 100644 --- a/homeassistant/auth/providers/legacy_api_password.py +++ b/homeassistant/auth/providers/legacy_api_password.py @@ -5,14 +5,17 @@ It will be removed when auth system production ready """ from collections import OrderedDict import hmac +from typing import Any, Dict, Optional import voluptuous as vol +from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 from homeassistant.exceptions import HomeAssistantError from homeassistant import data_entry_flow from homeassistant.core import callback from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from ..models import Credentials, UserMeta USER_SCHEMA = vol.Schema({ @@ -36,25 +39,29 @@ class LegacyApiPasswordAuthProvider(AuthProvider): DEFAULT_TITLE = 'Legacy API Password' - async def async_credential_flow(self, context): + async def async_credential_flow( + self, context: Optional[Dict]) -> 'LoginFlow': """Return a flow to login.""" return LoginFlow(self) @callback - def async_validate_login(self, password): + def async_validate_login(self, password: str) -> None: """Helper to validate a username and password.""" - if not hasattr(self.hass, 'http'): + hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP + + if not hass_http: raise ValueError('http component is not loaded') - if self.hass.http.api_password is None: + if hass_http.api_password is None: raise ValueError('http component is not configured using' ' api_password') - if not hmac.compare_digest(self.hass.http.api_password.encode('utf-8'), + if not hmac.compare_digest(hass_http.api_password.encode('utf-8'), password.encode('utf-8')): raise InvalidAuthError - async def async_get_or_create_credentials(self, flow_result): + async def async_get_or_create_credentials( + self, flow_result: Dict[str, str]) -> Credentials: """Return LEGACY_USER always.""" for credential in await self.async_credentials(): if credential.data['username'] == LEGACY_USER: @@ -64,26 +71,25 @@ class LegacyApiPasswordAuthProvider(AuthProvider): 'username': LEGACY_USER }) - async def async_user_meta_for_credentials(self, credentials): + async def async_user_meta_for_credentials( + self, credentials: Credentials) -> UserMeta: """ Set name as LEGACY_USER always. Will be used to populate info when creating a new user. """ - return { - 'name': LEGACY_USER, - 'is_active': True, - } + return UserMeta(name=LEGACY_USER, is_active=True) class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" - def __init__(self, auth_provider): + def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider - async def async_step_init(self, user_input=None): + async def async_step_init( + self, user_input: Dict[str, str] = None) -> Dict[str, Any]: """Handle the step of the form.""" errors = {} @@ -100,7 +106,7 @@ class LoginFlow(data_entry_flow.FlowHandler): data={} ) - schema = OrderedDict() + schema = OrderedDict() # type: Dict[str, type] schema['password'] = str return self.async_show_form( diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index 7a4b0126505..3233fa5537f 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -3,12 +3,16 @@ It shows list of users if access from trusted network. Abort login flow if not access from trusted network. """ +from typing import Any, Dict, Optional, cast + import voluptuous as vol from homeassistant import data_entry_flow +from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS +from ..models import Credentials, UserMeta CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ }, extra=vol.PREVENT_EXTRA) @@ -31,16 +35,20 @@ class TrustedNetworksAuthProvider(AuthProvider): DEFAULT_TITLE = 'Trusted Networks' - async def async_credential_flow(self, context): + async def async_credential_flow( + self, context: Optional[Dict]) -> 'LoginFlow': """Return a flow to login.""" + assert context is not None users = await self.store.async_get_users() available_users = {user.id: user.name for user in users if not user.system_generated and user.is_active} - return LoginFlow(self, context.get('ip_address'), available_users) + return LoginFlow(self, cast(str, context.get('ip_address')), + available_users) - async def async_get_or_create_credentials(self, flow_result): + async def async_get_or_create_credentials( + self, flow_result: Dict[str, str]) -> Credentials: """Get credentials based on the flow result.""" user_id = flow_result['user'] @@ -59,7 +67,8 @@ class TrustedNetworksAuthProvider(AuthProvider): # We only allow login as exist user raise InvalidUserError - async def async_user_meta_for_credentials(self, credentials): + async def async_user_meta_for_credentials( + self, credentials: Credentials) -> UserMeta: """Return extra user metadata for credentials. Trusted network auth provider should never create new user. @@ -67,31 +76,36 @@ class TrustedNetworksAuthProvider(AuthProvider): raise NotImplementedError @callback - def async_validate_access(self, ip_address): + def async_validate_access(self, ip_address: str) -> None: """Make sure the access from trusted networks. Raise InvalidAuthError if not. - Raise InvalidAuthError if trusted_networks is not config + Raise InvalidAuthError if trusted_networks is not configured. """ - if (not hasattr(self.hass, 'http') or - not self.hass.http or not self.hass.http.trusted_networks): + hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP + + if not hass_http or not hass_http.trusted_networks: raise InvalidAuthError('trusted_networks is not configured') if not any(ip_address in trusted_network for trusted_network - in self.hass.http.trusted_networks): + in hass_http.trusted_networks): raise InvalidAuthError('Not in trusted_networks') class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" - def __init__(self, auth_provider, ip_address, available_users): + def __init__(self, auth_provider: TrustedNetworksAuthProvider, + ip_address: str, available_users: Dict[str, Optional[str]]) \ + -> None: """Initialize the login flow.""" self._auth_provider = auth_provider self._available_users = available_users self._ip_address = ip_address - async def async_step_init(self, user_input=None): + async def async_step_init( + self, user_input: Optional[Dict[str, str]] = None) \ + -> Dict[str, Any]: """Handle the step of the form.""" errors = {} try: diff --git a/tox.ini b/tox.ini index d6ef1981bef..e1261457c47 100644 --- a/tox.ini +++ b/tox.ini @@ -58,4 +58,4 @@ whitelist_externals=/bin/bash deps = -r{toxinidir}/requirements_test.txt commands = - /bin/bash -c 'mypy homeassistant/*.py homeassistant/util/' + /bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/'