mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Use IndieAuth for client ID (#15369)
* Use IndieAuth for client ID * Lint * Lint & Fix tests * Allow local IP addresses * Update comment
This commit is contained in:
parent
f7d7d825b0
commit
0d4841cbea
@ -186,16 +186,6 @@ class Credentials:
|
||||
is_new = attr.ib(type=bool, default=True)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Client:
|
||||
"""Client that interacts with Home Assistant on behalf of a user."""
|
||||
|
||||
name = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
secret = attr.ib(type=str, default=attr.Factory(generate_secret))
|
||||
redirect_uris = attr.ib(type=list, default=attr.Factory(list))
|
||||
|
||||
|
||||
async def load_auth_provider_module(hass, provider):
|
||||
"""Load an auth provider."""
|
||||
try:
|
||||
@ -356,20 +346,20 @@ class AuthManager:
|
||||
"""Remove a user."""
|
||||
await self._store.async_remove_user(user)
|
||||
|
||||
async def async_create_refresh_token(self, user, client=None):
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
raise ValueError('User is not active')
|
||||
|
||||
if user.system_generated and client is not None:
|
||||
if user.system_generated and client_id is not None:
|
||||
raise ValueError(
|
||||
'System generated users cannot have refresh tokens connected '
|
||||
'to a client.')
|
||||
|
||||
if not user.system_generated and client is None:
|
||||
if not user.system_generated and client_id is None:
|
||||
raise ValueError('Client is required to generate a refresh token.')
|
||||
|
||||
return await self._store.async_create_refresh_token(user, client)
|
||||
return await self._store.async_create_refresh_token(user, client_id)
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
@ -396,26 +386,6 @@ class AuthManager:
|
||||
|
||||
return tkn
|
||||
|
||||
async def async_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
"""Create a new client."""
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_or_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
"""Find a client, if not exists, create a new one."""
|
||||
for client in await self._store.async_get_clients():
|
||||
if client.name == name:
|
||||
return client
|
||||
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
return await self._store.async_get_client(client_id)
|
||||
|
||||
async def _async_create_login_flow(self, handler, *, source, data):
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
@ -456,7 +426,6 @@ class AuthStore:
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users = None
|
||||
self._clients = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def async_get_users(self):
|
||||
@ -515,9 +484,8 @@ class AuthStore:
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client=None):
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new token for a user."""
|
||||
client_id = client.id if client is not None else None
|
||||
refresh_token = RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
await self.async_save()
|
||||
@ -535,38 +503,6 @@ class AuthStore:
|
||||
|
||||
return None
|
||||
|
||||
async def async_create_client(self, name, redirect_uris, no_secret):
|
||||
"""Create a new client."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
kwargs = {
|
||||
'name': name,
|
||||
'redirect_uris': redirect_uris
|
||||
}
|
||||
|
||||
if no_secret:
|
||||
kwargs['secret'] = None
|
||||
|
||||
client = Client(**kwargs)
|
||||
self._clients[client.id] = client
|
||||
await self.async_save()
|
||||
return client
|
||||
|
||||
async def async_get_clients(self):
|
||||
"""Return all clients."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return list(self._clients.values())
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._clients.get(client_id)
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
@ -578,7 +514,6 @@ class AuthStore:
|
||||
|
||||
if data is None:
|
||||
self._users = {}
|
||||
self._clients = {}
|
||||
return
|
||||
|
||||
users = {
|
||||
@ -618,12 +553,7 @@ class AuthStore:
|
||||
)
|
||||
refresh_token.access_tokens.append(token)
|
||||
|
||||
clients = {
|
||||
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
|
||||
}
|
||||
|
||||
self._users = users
|
||||
self._clients = clients
|
||||
|
||||
async def async_save(self):
|
||||
"""Save users."""
|
||||
@ -676,19 +606,8 @@ class AuthStore:
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
|
||||
clients = [
|
||||
{
|
||||
'id': client.id,
|
||||
'name': client.name,
|
||||
'secret': client.secret,
|
||||
'redirect_uris': client.redirect_uris,
|
||||
}
|
||||
for client in self._clients.values()
|
||||
]
|
||||
|
||||
data = {
|
||||
'users': users,
|
||||
'clients': clients,
|
||||
'credentials': credentials,
|
||||
'access_tokens': access_tokens,
|
||||
'refresh_tokens': refresh_tokens,
|
||||
|
@ -115,7 +115,8 @@ from homeassistant.helpers.data_entry_flow import (
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
|
||||
from .client import verify_client
|
||||
from . import indieauth
|
||||
|
||||
|
||||
DOMAIN = 'auth'
|
||||
DEPENDENCIES = ['http']
|
||||
@ -143,8 +144,7 @@ class AuthProvidersView(HomeAssistantView):
|
||||
name = 'api:auth:providers'
|
||||
requires_auth = False
|
||||
|
||||
@verify_client
|
||||
async def get(self, request, client):
|
||||
async def get(self, request):
|
||||
"""Get available auth providers."""
|
||||
return self.json([{
|
||||
'name': provider.name,
|
||||
@ -164,16 +164,16 @@ class LoginFlowIndexView(FlowManagerIndexView):
|
||||
"""Do not allow index of flows in progress."""
|
||||
return aiohttp.web.Response(status=405)
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@verify_client
|
||||
@RequestDataValidator(vol.Schema({
|
||||
vol.Required('client_id'): str,
|
||||
vol.Required('handler'): vol.Any(str, list),
|
||||
vol.Required('redirect_uri'): str,
|
||||
}))
|
||||
async def post(self, request, client, data):
|
||||
async def post(self, request, data):
|
||||
"""Create a new login flow."""
|
||||
if data['redirect_uri'] not in client.redirect_uris:
|
||||
return self.json_message('invalid redirect uri', )
|
||||
if not indieauth.verify_redirect_uri(data['client_id'],
|
||||
data['redirect_uri']):
|
||||
return self.json_message('invalid client id or redirect uri', 400)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request)
|
||||
@ -191,16 +191,20 @@ class LoginFlowResourceView(FlowManagerResourceView):
|
||||
super().__init__(flow_mgr)
|
||||
self._store_credentials = store_credentials
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
async def get(self, request):
|
||||
async def get(self, request, flow_id):
|
||||
"""Do not allow getting status of a flow in progress."""
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@verify_client
|
||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||
async def post(self, request, client, flow_id, data):
|
||||
@RequestDataValidator(vol.Schema({
|
||||
'client_id': str
|
||||
}, extra=vol.ALLOW_EXTRA))
|
||||
async def post(self, request, flow_id, data):
|
||||
"""Handle progressing a login flow request."""
|
||||
client_id = data.pop('client_id')
|
||||
|
||||
if not indieauth.verify_client_id(client_id):
|
||||
return self.json_message('Invalid client id', 400)
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||
except data_entry_flow.UnknownFlow:
|
||||
@ -212,7 +216,7 @@ class LoginFlowResourceView(FlowManagerResourceView):
|
||||
return self.json(self._prepare_result_json(result))
|
||||
|
||||
result.pop('data')
|
||||
result['result'] = self._store_credentials(client.id, result['result'])
|
||||
result['result'] = self._store_credentials(client_id, result['result'])
|
||||
|
||||
return self.json(result)
|
||||
|
||||
@ -228,24 +232,31 @@ class GrantTokenView(HomeAssistantView):
|
||||
"""Initialize the grant token view."""
|
||||
self._retrieve_credentials = retrieve_credentials
|
||||
|
||||
@verify_client
|
||||
async def post(self, request, client):
|
||||
async def post(self, request):
|
||||
"""Grant a token."""
|
||||
hass = request.app['hass']
|
||||
data = await request.post()
|
||||
|
||||
client_id = data.get('client_id')
|
||||
if client_id is None or not indieauth.verify_client_id(client_id):
|
||||
return self.json({
|
||||
'error': 'invalid_request',
|
||||
}, status_code=400)
|
||||
|
||||
grant_type = data.get('grant_type')
|
||||
|
||||
if grant_type == 'authorization_code':
|
||||
return await self._async_handle_auth_code(hass, client, data)
|
||||
return await self._async_handle_auth_code(hass, client_id, data)
|
||||
|
||||
elif grant_type == 'refresh_token':
|
||||
return await self._async_handle_refresh_token(hass, client, data)
|
||||
return await self._async_handle_refresh_token(
|
||||
hass, client_id, data)
|
||||
|
||||
return self.json({
|
||||
'error': 'unsupported_grant_type',
|
||||
}, status_code=400)
|
||||
|
||||
async def _async_handle_auth_code(self, hass, client, data):
|
||||
async def _async_handle_auth_code(self, hass, client_id, data):
|
||||
"""Handle authorization code request."""
|
||||
code = data.get('code')
|
||||
|
||||
@ -254,7 +265,7 @@ class GrantTokenView(HomeAssistantView):
|
||||
'error': 'invalid_request',
|
||||
}, status_code=400)
|
||||
|
||||
credentials = self._retrieve_credentials(client.id, code)
|
||||
credentials = self._retrieve_credentials(client_id, code)
|
||||
|
||||
if credentials is None:
|
||||
return self.json({
|
||||
@ -263,7 +274,7 @@ class GrantTokenView(HomeAssistantView):
|
||||
|
||||
user = await hass.auth.async_get_or_create_user(credentials)
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user,
|
||||
client)
|
||||
client_id)
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
return self.json({
|
||||
@ -274,7 +285,7 @@ class GrantTokenView(HomeAssistantView):
|
||||
int(refresh_token.access_token_expiration.total_seconds()),
|
||||
})
|
||||
|
||||
async def _async_handle_refresh_token(self, hass, client, data):
|
||||
async def _async_handle_refresh_token(self, hass, client_id, data):
|
||||
"""Handle authorization code request."""
|
||||
token = data.get('refresh_token')
|
||||
|
||||
@ -285,7 +296,7 @@ class GrantTokenView(HomeAssistantView):
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(token)
|
||||
|
||||
if refresh_token is None or refresh_token.client_id != client.id:
|
||||
if refresh_token is None or refresh_token.client_id != client_id:
|
||||
return self.json({
|
||||
'error': 'invalid_grant',
|
||||
}, status_code=400)
|
||||
|
@ -1,79 +0,0 @@
|
||||
"""Helpers to resolve client ID/secret."""
|
||||
import base64
|
||||
from functools import wraps
|
||||
import hmac
|
||||
|
||||
import aiohttp.hdrs
|
||||
|
||||
|
||||
def verify_client(method):
|
||||
"""Decorator to verify client id/secret on requests."""
|
||||
@wraps(method)
|
||||
async def wrapper(view, request, *args, **kwargs):
|
||||
"""Verify client id/secret before doing request."""
|
||||
client = await _verify_client(request)
|
||||
|
||||
if client is None:
|
||||
return view.json({
|
||||
'error': 'invalid_client',
|
||||
}, status_code=401)
|
||||
|
||||
return await method(
|
||||
view, request, *args, **kwargs, client=client)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
async def _verify_client(request):
|
||||
"""Method to verify the client id/secret in consistent time.
|
||||
|
||||
By using a consistent time for looking up client id and comparing the
|
||||
secret, we prevent attacks by malicious actors trying different client ids
|
||||
and are able to derive from the time it takes to process the request if
|
||||
they guessed the client id correctly.
|
||||
"""
|
||||
if aiohttp.hdrs.AUTHORIZATION not in request.headers:
|
||||
return None
|
||||
|
||||
auth_type, auth_value = \
|
||||
request.headers.get(aiohttp.hdrs.AUTHORIZATION).split(' ', 1)
|
||||
|
||||
if auth_type != 'Basic':
|
||||
return None
|
||||
|
||||
decoded = base64.b64decode(auth_value).decode('utf-8')
|
||||
try:
|
||||
client_id, client_secret = decoded.split(':', 1)
|
||||
except ValueError:
|
||||
# If no ':' in decoded
|
||||
client_id, client_secret = decoded, None
|
||||
|
||||
return await async_secure_get_client(
|
||||
request.app['hass'], client_id, client_secret)
|
||||
|
||||
|
||||
async def async_secure_get_client(hass, client_id, client_secret):
|
||||
"""Get a client id/secret in consistent time."""
|
||||
client = await hass.auth.async_get_client(client_id)
|
||||
|
||||
if client is None:
|
||||
if client_secret is not None:
|
||||
# Still do a compare so we run same time as if a client was found.
|
||||
hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client_secret.encode('utf-8'))
|
||||
return None
|
||||
|
||||
if client.secret is None:
|
||||
return client
|
||||
|
||||
elif client_secret is None:
|
||||
# Still do a compare so we run same time as if a secret was passed.
|
||||
hmac.compare_digest(client.secret.encode('utf-8'),
|
||||
client.secret.encode('utf-8'))
|
||||
return None
|
||||
|
||||
elif hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client.secret.encode('utf-8')):
|
||||
return client
|
||||
|
||||
return None
|
130
homeassistant/components/auth/indieauth.py
Normal file
130
homeassistant/components/auth/indieauth.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""Helpers to resolve client ID/secret."""
|
||||
from ipaddress import ip_address, ip_network
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# IP addresses of loopback interfaces
|
||||
ALLOWED_IPS = (
|
||||
ip_address('127.0.0.1'),
|
||||
ip_address('::1'),
|
||||
)
|
||||
|
||||
# RFC1918 - Address allocation for Private Internets
|
||||
ALLOWED_NETWORKS = (
|
||||
ip_network('10.0.0.0/8'),
|
||||
ip_network('172.16.0.0/12'),
|
||||
ip_network('192.168.0.0/16'),
|
||||
)
|
||||
|
||||
|
||||
def verify_redirect_uri(client_id, redirect_uri):
|
||||
"""Verify that the client and redirect uri match."""
|
||||
try:
|
||||
client_id_parts = _parse_client_id(client_id)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
redirect_parts = _parse_url(redirect_uri)
|
||||
|
||||
# IndieAuth 4.2.2 allows for redirect_uri to be on different domain
|
||||
# but needs to be specified in link tag when fetching `client_id`.
|
||||
# This is not implemented.
|
||||
|
||||
# Verify redirect url and client url have same scheme and domain.
|
||||
return (
|
||||
client_id_parts.scheme == redirect_parts.scheme and
|
||||
client_id_parts.netloc == redirect_parts.netloc
|
||||
)
|
||||
|
||||
|
||||
def verify_client_id(client_id):
|
||||
"""Verify that the client id is valid."""
|
||||
try:
|
||||
_parse_client_id(client_id)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _parse_url(url):
|
||||
"""Parse a url in parts and canonicalize according to IndieAuth."""
|
||||
parts = urlparse(url)
|
||||
|
||||
# Canonicalize a url according to IndieAuth 3.2.
|
||||
|
||||
# SHOULD convert the hostname to lowercase
|
||||
parts = parts._replace(netloc=parts.netloc.lower())
|
||||
|
||||
# If a URL with no path component is ever encountered,
|
||||
# it MUST be treated as if it had the path /.
|
||||
if parts.path == '':
|
||||
parts = parts._replace(path='/')
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _parse_client_id(client_id):
|
||||
"""Test if client id is a valid URL according to IndieAuth section 3.2.
|
||||
|
||||
https://indieauth.spec.indieweb.org/#client-identifier
|
||||
"""
|
||||
parts = _parse_url(client_id)
|
||||
|
||||
# Client identifier URLs
|
||||
# MUST have either an https or http scheme
|
||||
if parts.scheme not in ('http', 'https'):
|
||||
raise ValueError()
|
||||
|
||||
# MUST contain a path component
|
||||
# Handled by url canonicalization.
|
||||
|
||||
# MUST NOT contain single-dot or double-dot path segments
|
||||
if any(segment in ('.', '..') for segment in parts.path.split('/')):
|
||||
raise ValueError(
|
||||
'Client ID cannot contain single-dot or double-dot path segments')
|
||||
|
||||
# MUST NOT contain a fragment component
|
||||
if parts.fragment != '':
|
||||
raise ValueError('Client ID cannot contain a fragment')
|
||||
|
||||
# MUST NOT contain a username or password component
|
||||
if parts.username is not None:
|
||||
raise ValueError('Client ID cannot contain username')
|
||||
|
||||
if parts.password is not None:
|
||||
raise ValueError('Client ID cannot contain password')
|
||||
|
||||
# MAY contain a port
|
||||
try:
|
||||
# parts raises ValueError when port cannot be parsed as int
|
||||
parts.port
|
||||
except ValueError:
|
||||
raise ValueError('Client ID contains invalid port')
|
||||
|
||||
# Additionally, hostnames
|
||||
# MUST be domain names or a loopback interface and
|
||||
# MUST NOT be IPv4 or IPv6 addresses except for IPv4 127.0.0.1
|
||||
# or IPv6 [::1]
|
||||
|
||||
# We are not goint to follow the spec here. We are going to allow
|
||||
# any internal network IP to be used inside a client id.
|
||||
|
||||
address = None
|
||||
|
||||
try:
|
||||
netloc = parts.netloc
|
||||
|
||||
# Strip the [, ] from ipv6 addresses before parsing
|
||||
if netloc[0] == '[' and netloc[-1] == ']':
|
||||
netloc = netloc[1:-1]
|
||||
|
||||
address = ip_address(netloc)
|
||||
except ValueError:
|
||||
# Not an ip address
|
||||
pass
|
||||
|
||||
if (address is None or
|
||||
address in ALLOWED_IPS or
|
||||
any(address in network for network in ALLOWED_NETWORKS)):
|
||||
return parts
|
||||
|
||||
raise ValueError('Hostname should be a domain name or local IP address')
|
@ -200,15 +200,6 @@ def add_manifest_json_key(key, val):
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the serving of the frontend."""
|
||||
if hass.auth.active:
|
||||
client = await hass.auth.async_get_or_create_client(
|
||||
'Home Assistant Frontend',
|
||||
redirect_uris=['/'],
|
||||
no_secret=True,
|
||||
)
|
||||
else:
|
||||
client = None
|
||||
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_GET_PANELS, websocket_get_panels, SCHEMA_GET_PANELS)
|
||||
hass.components.websocket_api.async_register_command(
|
||||
@ -255,7 +246,7 @@ async def async_setup(hass, config):
|
||||
if os.path.isdir(local):
|
||||
hass.http.register_static_path("/local", local, not is_dev)
|
||||
|
||||
index_view = IndexView(repo_path, js_version, client)
|
||||
index_view = IndexView(repo_path, js_version, hass.auth.active)
|
||||
hass.http.register_view(index_view)
|
||||
|
||||
@callback
|
||||
@ -350,11 +341,11 @@ class IndexView(HomeAssistantView):
|
||||
requires_auth = False
|
||||
extra_urls = ['/states', '/states/{extra}']
|
||||
|
||||
def __init__(self, repo_path, js_option, client):
|
||||
def __init__(self, repo_path, js_option, auth_active):
|
||||
"""Initialize the frontend view."""
|
||||
self.repo_path = repo_path
|
||||
self.js_option = js_option
|
||||
self.client = client
|
||||
self.auth_active = auth_active
|
||||
self._template_cache = {}
|
||||
|
||||
def get_template(self, latest):
|
||||
@ -399,11 +390,9 @@ class IndexView(HomeAssistantView):
|
||||
no_auth=no_auth,
|
||||
theme_color=MANIFEST_JSON['theme_color'],
|
||||
extra_urls=hass.data[extra_key],
|
||||
client_id=self.auth_active
|
||||
)
|
||||
|
||||
if self.client is not None:
|
||||
template_params['client_id'] = self.client.id
|
||||
|
||||
return web.Response(text=template.render(**template_params),
|
||||
content_type='text/html')
|
||||
|
||||
|
@ -31,6 +31,8 @@ from homeassistant.util.async_ import (
|
||||
_TEST_INSTANCE_PORT = SERVER_PORT
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
INSTANCES = []
|
||||
CLIENT_ID = 'https://example.com/app'
|
||||
CLIENT_REDIRECT_URI = 'https://example.com/app/callback'
|
||||
|
||||
|
||||
def threadsafe_callback_factory(func):
|
||||
@ -330,8 +332,6 @@ class MockUser(auth.User):
|
||||
def ensure_auth_manager_loaded(auth_mgr):
|
||||
"""Ensure an auth manager is considered loaded."""
|
||||
store = auth_mgr._store
|
||||
if store._clients is None:
|
||||
store._clients = {}
|
||||
if store._users is None:
|
||||
store._users = {}
|
||||
|
||||
|
@ -1,6 +1,4 @@
|
||||
"""Tests for the auth component."""
|
||||
from aiohttp.helpers import BasicAuth
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
@ -16,10 +14,6 @@ BASE_CONFIG = [{
|
||||
'name': 'Test Name'
|
||||
}]
|
||||
}]
|
||||
CLIENT_ID = 'test-id'
|
||||
CLIENT_SECRET = 'test-secret'
|
||||
CLIENT_AUTH = BasicAuth(CLIENT_ID, CLIENT_SECRET)
|
||||
CLIENT_REDIRECT_URI = 'http://example.com/callback'
|
||||
|
||||
|
||||
async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
||||
@ -32,9 +26,6 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
||||
'api_password': 'bla'
|
||||
}
|
||||
})
|
||||
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
|
||||
redirect_uris=[CLIENT_REDIRECT_URI])
|
||||
hass.auth._store._clients[client.id] = client
|
||||
if setup_api:
|
||||
await async_setup_component(hass, 'api', {})
|
||||
return await aiohttp_client(hass.http.app)
|
||||
|
@ -1,70 +0,0 @@
|
||||
"""Tests for the client validator."""
|
||||
from aiohttp.helpers import BasicAuth
|
||||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.auth.client import verify_client
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
|
||||
from . import async_setup_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_view(hass):
|
||||
"""Register a view that verifies client id/secret."""
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||
|
||||
clients = []
|
||||
|
||||
class ClientView(HomeAssistantView):
|
||||
url = '/'
|
||||
name = 'bla'
|
||||
|
||||
@verify_client
|
||||
async def get(self, request, client):
|
||||
"""Handle GET request."""
|
||||
clients.append(client)
|
||||
|
||||
hass.http.register_view(ClientView)
|
||||
return clients
|
||||
|
||||
|
||||
async def test_verify_client(hass, aiohttp_client, mock_view):
|
||||
"""Test that verify client can extract client auth from a request."""
|
||||
http_client = await async_setup_auth(hass, aiohttp_client)
|
||||
client = await hass.auth.async_create_client('Hello')
|
||||
|
||||
resp = await http_client.get('/', auth=BasicAuth(client.id, client.secret))
|
||||
assert resp.status == 200
|
||||
assert mock_view[0] is client
|
||||
|
||||
|
||||
async def test_verify_client_no_auth_header(hass, aiohttp_client, mock_view):
|
||||
"""Test that verify client will decline unknown client id."""
|
||||
http_client = await async_setup_auth(hass, aiohttp_client)
|
||||
|
||||
resp = await http_client.get('/')
|
||||
assert resp.status == 401
|
||||
assert mock_view == []
|
||||
|
||||
|
||||
async def test_verify_client_invalid_client_id(hass, aiohttp_client,
|
||||
mock_view):
|
||||
"""Test that verify client will decline unknown client id."""
|
||||
http_client = await async_setup_auth(hass, aiohttp_client)
|
||||
client = await hass.auth.async_create_client('Hello')
|
||||
|
||||
resp = await http_client.get('/', auth=BasicAuth('invalid', client.secret))
|
||||
assert resp.status == 401
|
||||
assert mock_view == []
|
||||
|
||||
|
||||
async def test_verify_client_invalid_client_secret(hass, aiohttp_client,
|
||||
mock_view):
|
||||
"""Test that verify client will decline incorrect client secret."""
|
||||
http_client = await async_setup_auth(hass, aiohttp_client)
|
||||
client = await hass.auth.async_create_client('Hello')
|
||||
|
||||
resp = await http_client.get('/', auth=BasicAuth(client.id, 'invalid'))
|
||||
assert resp.status == 401
|
||||
assert mock_view == []
|
110
tests/components/auth/test_indieauth.py
Normal file
110
tests/components/auth/test_indieauth.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Tests for the client validator."""
|
||||
from homeassistant.components.auth import indieauth
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_client_id_scheme():
|
||||
"""Test we enforce valid scheme."""
|
||||
assert indieauth._parse_client_id('http://ex.com/')
|
||||
assert indieauth._parse_client_id('https://ex.com/')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('ftp://ex.com')
|
||||
|
||||
|
||||
def test_client_id_path():
|
||||
"""Test we enforce valid path."""
|
||||
assert indieauth._parse_client_id('http://ex.com').path == '/'
|
||||
assert indieauth._parse_client_id('http://ex.com/hello').path == '/hello'
|
||||
assert indieauth._parse_client_id(
|
||||
'http://ex.com/hello/.world').path == '/hello/.world'
|
||||
assert indieauth._parse_client_id(
|
||||
'http://ex.com/hello./.world').path == '/hello./.world'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://ex.com/.')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://ex.com/hello/./yo')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://ex.com/hello/../yo')
|
||||
|
||||
|
||||
def test_client_id_fragment():
|
||||
"""Test we enforce valid fragment."""
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://ex.com/#yoo')
|
||||
|
||||
|
||||
def test_client_id_user_pass():
|
||||
"""Test we enforce valid username/password."""
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://user@ex.com/')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
indieauth._parse_client_id('http://user:pass@ex.com/')
|
||||
|
||||
|
||||
def test_client_id_hostname():
|
||||
"""Test we enforce valid hostname."""
|
||||
assert indieauth._parse_client_id('http://www.home-assistant.io/')
|
||||
assert indieauth._parse_client_id('http://[::1]')
|
||||
assert indieauth._parse_client_id('http://127.0.0.1')
|
||||
assert indieauth._parse_client_id('http://10.0.0.0')
|
||||
assert indieauth._parse_client_id('http://10.255.255.255')
|
||||
assert indieauth._parse_client_id('http://172.16.0.0')
|
||||
assert indieauth._parse_client_id('http://172.31.255.255')
|
||||
assert indieauth._parse_client_id('http://192.168.0.0')
|
||||
assert indieauth._parse_client_id('http://192.168.255.255')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert indieauth._parse_client_id('http://255.255.255.255/')
|
||||
with pytest.raises(ValueError):
|
||||
assert indieauth._parse_client_id('http://11.0.0.0/')
|
||||
with pytest.raises(ValueError):
|
||||
assert indieauth._parse_client_id('http://172.32.0.0/')
|
||||
with pytest.raises(ValueError):
|
||||
assert indieauth._parse_client_id('http://192.167.0.0/')
|
||||
|
||||
|
||||
def test_parse_url_lowercase_host():
|
||||
"""Test we update empty paths."""
|
||||
assert indieauth._parse_url('http://ex.com/hello').path == '/hello'
|
||||
assert indieauth._parse_url('http://EX.COM/hello').hostname == 'ex.com'
|
||||
|
||||
parts = indieauth._parse_url('http://EX.COM:123/HELLO')
|
||||
assert parts.netloc == 'ex.com:123'
|
||||
assert parts.path == '/HELLO'
|
||||
|
||||
|
||||
def test_parse_url_path():
|
||||
"""Test we update empty paths."""
|
||||
assert indieauth._parse_url('http://ex.com').path == '/'
|
||||
|
||||
|
||||
def test_verify_redirect_uri():
|
||||
"""Test that we verify redirect uri correctly."""
|
||||
assert indieauth.verify_redirect_uri(
|
||||
'http://ex.com',
|
||||
'http://ex.com/callback'
|
||||
)
|
||||
|
||||
# Different domain
|
||||
assert not indieauth.verify_redirect_uri(
|
||||
'http://ex.com',
|
||||
'http://different.com/callback'
|
||||
)
|
||||
|
||||
# Different scheme
|
||||
assert not indieauth.verify_redirect_uri(
|
||||
'http://ex.com',
|
||||
'https://ex.com/callback'
|
||||
)
|
||||
|
||||
# Different subdomain
|
||||
assert not indieauth.verify_redirect_uri(
|
||||
'https://sub1.ex.com',
|
||||
'https://sub2.ex.com/callback'
|
||||
)
|
@ -1,22 +1,26 @@
|
||||
"""Integration tests for the auth component."""
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
|
||||
from . import async_setup_auth
|
||||
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
|
||||
"""Test logging in with new user and refreshing tokens."""
|
||||
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'client_id': CLIENT_ID,
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
||||
resp = await client.post(
|
||||
'/auth/login_flow/{}'.format(step['flow_id']), json={
|
||||
'client_id': CLIENT_ID,
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
@ -24,9 +28,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
|
||||
|
||||
# Exchange code for tokens
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': CLIENT_ID,
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
@ -35,9 +40,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
|
||||
|
||||
# Use refresh token to get more tokens.
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': CLIENT_ID,
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': tokens['refresh_token']
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Tests for the link user flow."""
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
from . import async_setup_auth
|
||||
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def async_get_code(hass, aiohttp_client):
|
||||
@ -25,17 +27,19 @@ async def async_get_code(hass, aiohttp_client):
|
||||
client = await async_setup_auth(hass, aiohttp_client, config)
|
||||
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'client_id': CLIENT_ID,
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
||||
resp = await client.post(
|
||||
'/auth/login_flow/{}'.format(step['flow_id']), json={
|
||||
'client_id': CLIENT_ID,
|
||||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
@ -43,9 +47,10 @@ async def async_get_code(hass, aiohttp_client):
|
||||
|
||||
# Exchange code for tokens
|
||||
resp = await client.post('/auth/token', data={
|
||||
'client_id': CLIENT_ID,
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
@ -57,17 +62,19 @@ async def async_get_code(hass, aiohttp_client):
|
||||
|
||||
# Now authenticate with the 2nd flow
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'client_id': CLIENT_ID,
|
||||
'handler': ['insecure_example', '2nd auth'],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
||||
resp = await client.post(
|
||||
'/auth/login_flow/{}'.format(step['flow_id']), json={
|
||||
'client_id': CLIENT_ID,
|
||||
'username': '2nd-user',
|
||||
'password': '2nd-pass',
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
@ -1,13 +1,13 @@
|
||||
"""Tests for the login flow."""
|
||||
from aiohttp.helpers import BasicAuth
|
||||
from . import async_setup_auth
|
||||
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def test_fetch_auth_providers(hass, aiohttp_client):
|
||||
"""Test fetching auth providers."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
resp = await client.get('/auth/providers', auth=CLIENT_AUTH)
|
||||
resp = await client.get('/auth/providers')
|
||||
assert await resp.json() == [{
|
||||
'name': 'Example',
|
||||
'type': 'insecure_example',
|
||||
@ -15,14 +15,6 @@ async def test_fetch_auth_providers(hass, aiohttp_client):
|
||||
}]
|
||||
|
||||
|
||||
async def test_fetch_auth_providers_require_valid_client(hass, aiohttp_client):
|
||||
"""Test fetching auth providers."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
resp = await client.get('/auth/providers',
|
||||
auth=BasicAuth('invalid', 'bla'))
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_cannot_get_flows_in_progress(hass, aiohttp_client):
|
||||
"""Test we cannot get flows in progress."""
|
||||
client = await async_setup_auth(hass, aiohttp_client, [])
|
||||
@ -34,18 +26,20 @@ async def test_invalid_username_password(hass, aiohttp_client):
|
||||
"""Test we cannot get flows in progress."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'client_id': CLIENT_ID,
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
||||
# Incorrect username
|
||||
resp = await client.post(
|
||||
'/auth/login_flow/{}'.format(step['flow_id']), json={
|
||||
'client_id': CLIENT_ID,
|
||||
'username': 'wrong-user',
|
||||
'password': 'test-pass',
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
@ -56,9 +50,10 @@ async def test_invalid_username_password(hass, aiohttp_client):
|
||||
# Incorrect password
|
||||
resp = await client.post(
|
||||
'/auth/login_flow/{}'.format(step['flow_id']), json={
|
||||
'client_id': CLIENT_ID,
|
||||
'username': 'test-user',
|
||||
'password': 'wrong-pass',
|
||||
}, auth=CLIENT_AUTH)
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.common import MockUser, CLIENT_ID
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -28,11 +28,6 @@ def hass_ws_client(aiohttp_client):
|
||||
def hass_access_token(hass):
|
||||
"""Return an access token to access Home Assistant."""
|
||||
user = MockUser().add_to_hass(hass)
|
||||
client = hass.loop.run_until_complete(hass.auth.async_create_client(
|
||||
'Access Token Fixture',
|
||||
redirect_uris=['/'],
|
||||
no_secret=True,
|
||||
))
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(user, client))
|
||||
hass.auth.async_create_refresh_token(user, CLIENT_ID))
|
||||
yield hass.auth.async_create_access_token(refresh_token)
|
||||
|
@ -6,7 +6,8 @@ import pytest
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant.util import dt as dt_util
|
||||
from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
|
||||
from tests.common import (
|
||||
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -181,10 +182,7 @@ async def test_saving_loading(hass, hass_storage):
|
||||
})
|
||||
user = await manager.async_get_or_create_user(step['result'])
|
||||
|
||||
client = await manager.async_create_client(
|
||||
'test', redirect_uris=['https://example.com'])
|
||||
|
||||
refresh_token = await manager.async_create_refresh_token(user, client)
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
manager.async_create_access_token(refresh_token)
|
||||
|
||||
@ -195,10 +193,6 @@ async def test_saving_loading(hass, hass_storage):
|
||||
assert len(users) == 1
|
||||
assert users[0] == user
|
||||
|
||||
clients = await store2.async_get_clients()
|
||||
assert len(clients) == 1
|
||||
assert clients[0] == client
|
||||
|
||||
|
||||
def test_access_token_expired():
|
||||
"""Test that the expired property on access tokens work."""
|
||||
@ -225,11 +219,10 @@ def test_access_token_expired():
|
||||
async def test_cannot_retrieve_expired_access_token(hass):
|
||||
"""Test that we cannot retrieve expired access tokens."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
client = await manager.async_create_client('test')
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(user, client)
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
assert refresh_token.user.id is user.id
|
||||
assert refresh_token.client_id is client.id
|
||||
assert refresh_token.client_id == CLIENT_ID
|
||||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
assert manager.async_get_access_token(access_token.token) is access_token
|
||||
@ -242,19 +235,6 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
|
||||
|
||||
async def test_get_or_create_client(hass):
|
||||
"""Test that get_or_create_client works."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
|
||||
client1 = await manager.async_get_or_create_client(
|
||||
'Test Client', redirect_uris=['https://test.com/1'])
|
||||
assert client1.name is 'Test Client'
|
||||
|
||||
client2 = await manager.async_get_or_create_client(
|
||||
'Test Client', redirect_uris=['https://test.com/1'])
|
||||
assert client2.id is client1.id
|
||||
|
||||
|
||||
async def test_generating_system_user(hass):
|
||||
"""Test that we can add a system user."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
@ -274,10 +254,9 @@ async def test_refresh_token_requires_client_for_user(hass):
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_refresh_token(user)
|
||||
|
||||
client = await manager.async_get_or_create_client('Test client')
|
||||
token = await manager.async_create_refresh_token(user, client)
|
||||
token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
assert token is not None
|
||||
assert token.client_id == client.id
|
||||
assert token.client_id == CLIENT_ID
|
||||
|
||||
|
||||
async def test_refresh_token_not_requires_client_for_system_user(hass):
|
||||
@ -285,10 +264,9 @@ async def test_refresh_token_not_requires_client_for_system_user(hass):
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
user = await manager.async_create_system_user('Hass.io')
|
||||
assert user.system_generated is True
|
||||
client = await manager.async_get_or_create_client('Test client')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_refresh_token(user, client)
|
||||
await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
token = await manager.async_create_refresh_token(user)
|
||||
assert token is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user