Use IndieAuth for client ID (#15369)

* Use IndieAuth for client ID

* Lint

* Lint & Fix tests

* Allow local IP addresses

* Update comment
This commit is contained in:
Paulus Schoutsen 2018-07-09 18:24:46 +02:00 committed by GitHub
parent f7d7d825b0
commit 0d4841cbea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 329 additions and 347 deletions

View File

@ -186,16 +186,6 @@ class Credentials:
is_new = attr.ib(type=bool, default=True)
@attr.s(slots=True)
class Client:
"""Client that interacts with Home Assistant on behalf of a user."""
name = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
secret = attr.ib(type=str, default=attr.Factory(generate_secret))
redirect_uris = attr.ib(type=list, default=attr.Factory(list))
async def load_auth_provider_module(hass, provider):
"""Load an auth provider."""
try:
@ -356,20 +346,20 @@ class AuthManager:
"""Remove a user."""
await self._store.async_remove_user(user)
async def async_create_refresh_token(self, user, client=None):
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new refresh token for a user."""
if not user.is_active:
raise ValueError('User is not active')
if user.system_generated and client is not None:
if user.system_generated and client_id is not None:
raise ValueError(
'System generated users cannot have refresh tokens connected '
'to a client.')
if not user.system_generated and client is None:
if not user.system_generated and client_id is None:
raise ValueError('Client is required to generate a refresh token.')
return await self._store.async_create_refresh_token(user, client)
return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
@ -396,26 +386,6 @@ class AuthManager:
return tkn
async def async_create_client(self, name, *, redirect_uris=None,
no_secret=False):
"""Create a new client."""
return await self._store.async_create_client(
name, redirect_uris, no_secret)
async def async_get_or_create_client(self, name, *, redirect_uris=None,
no_secret=False):
"""Find a client, if not exists, create a new one."""
for client in await self._store.async_get_clients():
if client.name == name:
return client
return await self._store.async_create_client(
name, redirect_uris, no_secret)
async def async_get_client(self, client_id):
"""Get a client."""
return await self._store.async_get_client(client_id)
async def _async_create_login_flow(self, handler, *, source, data):
"""Create a login flow."""
auth_provider = self._providers[handler]
@ -456,7 +426,6 @@ class AuthStore:
"""Initialize the auth store."""
self.hass = hass
self._users = None
self._clients = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def async_get_users(self):
@ -515,9 +484,8 @@ class AuthStore:
self._users.pop(user.id)
await self.async_save()
async def async_create_refresh_token(self, user, client=None):
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user."""
client_id = client.id if client is not None else None
refresh_token = RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
@ -535,38 +503,6 @@ class AuthStore:
return None
async def async_create_client(self, name, redirect_uris, no_secret):
"""Create a new client."""
if self._clients is None:
await self.async_load()
kwargs = {
'name': name,
'redirect_uris': redirect_uris
}
if no_secret:
kwargs['secret'] = None
client = Client(**kwargs)
self._clients[client.id] = client
await self.async_save()
return client
async def async_get_clients(self):
"""Return all clients."""
if self._clients is None:
await self.async_load()
return list(self._clients.values())
async def async_get_client(self, client_id):
"""Get a client."""
if self._clients is None:
await self.async_load()
return self._clients.get(client_id)
async def async_load(self):
"""Load the users."""
data = await self._store.async_load()
@ -578,7 +514,6 @@ class AuthStore:
if data is None:
self._users = {}
self._clients = {}
return
users = {
@ -618,12 +553,7 @@ class AuthStore:
)
refresh_token.access_tokens.append(token)
clients = {
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
}
self._users = users
self._clients = clients
async def async_save(self):
"""Save users."""
@ -676,19 +606,8 @@ class AuthStore:
for access_token in refresh_token.access_tokens
]
clients = [
{
'id': client.id,
'name': client.name,
'secret': client.secret,
'redirect_uris': client.redirect_uris,
}
for client in self._clients.values()
]
data = {
'users': users,
'clients': clients,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,

View File

@ -115,7 +115,8 @@ from homeassistant.helpers.data_entry_flow import (
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
from .client import verify_client
from . import indieauth
DOMAIN = 'auth'
DEPENDENCIES = ['http']
@ -143,8 +144,7 @@ class AuthProvidersView(HomeAssistantView):
name = 'api:auth:providers'
requires_auth = False
@verify_client
async def get(self, request, client):
async def get(self, request):
"""Get available auth providers."""
return self.json([{
'name': provider.name,
@ -164,16 +164,16 @@ class LoginFlowIndexView(FlowManagerIndexView):
"""Do not allow index of flows in progress."""
return aiohttp.web.Response(status=405)
# pylint: disable=arguments-differ
@verify_client
@RequestDataValidator(vol.Schema({
vol.Required('client_id'): str,
vol.Required('handler'): vol.Any(str, list),
vol.Required('redirect_uri'): str,
}))
async def post(self, request, client, data):
async def post(self, request, data):
"""Create a new login flow."""
if data['redirect_uri'] not in client.redirect_uris:
return self.json_message('invalid redirect uri', )
if not indieauth.verify_redirect_uri(data['client_id'],
data['redirect_uri']):
return self.json_message('invalid client id or redirect uri', 400)
# pylint: disable=no-value-for-parameter
return await super().post(request)
@ -191,16 +191,20 @@ class LoginFlowResourceView(FlowManagerResourceView):
super().__init__(flow_mgr)
self._store_credentials = store_credentials
# pylint: disable=arguments-differ
async def get(self, request):
async def get(self, request, flow_id):
"""Do not allow getting status of a flow in progress."""
return self.json_message('Invalid flow specified', 404)
# pylint: disable=arguments-differ
@verify_client
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(self, request, client, flow_id, data):
@RequestDataValidator(vol.Schema({
'client_id': str
}, extra=vol.ALLOW_EXTRA))
async def post(self, request, flow_id, data):
"""Handle progressing a login flow request."""
client_id = data.pop('client_id')
if not indieauth.verify_client_id(client_id):
return self.json_message('Invalid client id', 400)
try:
result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow:
@ -212,7 +216,7 @@ class LoginFlowResourceView(FlowManagerResourceView):
return self.json(self._prepare_result_json(result))
result.pop('data')
result['result'] = self._store_credentials(client.id, result['result'])
result['result'] = self._store_credentials(client_id, result['result'])
return self.json(result)
@ -228,24 +232,31 @@ class GrantTokenView(HomeAssistantView):
"""Initialize the grant token view."""
self._retrieve_credentials = retrieve_credentials
@verify_client
async def post(self, request, client):
async def post(self, request):
"""Grant a token."""
hass = request.app['hass']
data = await request.post()
client_id = data.get('client_id')
if client_id is None or not indieauth.verify_client_id(client_id):
return self.json({
'error': 'invalid_request',
}, status_code=400)
grant_type = data.get('grant_type')
if grant_type == 'authorization_code':
return await self._async_handle_auth_code(hass, client, data)
return await self._async_handle_auth_code(hass, client_id, data)
elif grant_type == 'refresh_token':
return await self._async_handle_refresh_token(hass, client, data)
return await self._async_handle_refresh_token(
hass, client_id, data)
return self.json({
'error': 'unsupported_grant_type',
}, status_code=400)
async def _async_handle_auth_code(self, hass, client, data):
async def _async_handle_auth_code(self, hass, client_id, data):
"""Handle authorization code request."""
code = data.get('code')
@ -254,7 +265,7 @@ class GrantTokenView(HomeAssistantView):
'error': 'invalid_request',
}, status_code=400)
credentials = self._retrieve_credentials(client.id, code)
credentials = self._retrieve_credentials(client_id, code)
if credentials is None:
return self.json({
@ -263,7 +274,7 @@ class GrantTokenView(HomeAssistantView):
user = await hass.auth.async_get_or_create_user(credentials)
refresh_token = await hass.auth.async_create_refresh_token(user,
client)
client_id)
access_token = hass.auth.async_create_access_token(refresh_token)
return self.json({
@ -274,7 +285,7 @@ class GrantTokenView(HomeAssistantView):
int(refresh_token.access_token_expiration.total_seconds()),
})
async def _async_handle_refresh_token(self, hass, client, data):
async def _async_handle_refresh_token(self, hass, client_id, data):
"""Handle authorization code request."""
token = data.get('refresh_token')
@ -285,7 +296,7 @@ class GrantTokenView(HomeAssistantView):
refresh_token = await hass.auth.async_get_refresh_token(token)
if refresh_token is None or refresh_token.client_id != client.id:
if refresh_token is None or refresh_token.client_id != client_id:
return self.json({
'error': 'invalid_grant',
}, status_code=400)

View File

@ -1,79 +0,0 @@
"""Helpers to resolve client ID/secret."""
import base64
from functools import wraps
import hmac
import aiohttp.hdrs
def verify_client(method):
"""Decorator to verify client id/secret on requests."""
@wraps(method)
async def wrapper(view, request, *args, **kwargs):
"""Verify client id/secret before doing request."""
client = await _verify_client(request)
if client is None:
return view.json({
'error': 'invalid_client',
}, status_code=401)
return await method(
view, request, *args, **kwargs, client=client)
return wrapper
async def _verify_client(request):
"""Method to verify the client id/secret in consistent time.
By using a consistent time for looking up client id and comparing the
secret, we prevent attacks by malicious actors trying different client ids
and are able to derive from the time it takes to process the request if
they guessed the client id correctly.
"""
if aiohttp.hdrs.AUTHORIZATION not in request.headers:
return None
auth_type, auth_value = \
request.headers.get(aiohttp.hdrs.AUTHORIZATION).split(' ', 1)
if auth_type != 'Basic':
return None
decoded = base64.b64decode(auth_value).decode('utf-8')
try:
client_id, client_secret = decoded.split(':', 1)
except ValueError:
# If no ':' in decoded
client_id, client_secret = decoded, None
return await async_secure_get_client(
request.app['hass'], client_id, client_secret)
async def async_secure_get_client(hass, client_id, client_secret):
"""Get a client id/secret in consistent time."""
client = await hass.auth.async_get_client(client_id)
if client is None:
if client_secret is not None:
# Still do a compare so we run same time as if a client was found.
hmac.compare_digest(client_secret.encode('utf-8'),
client_secret.encode('utf-8'))
return None
if client.secret is None:
return client
elif client_secret is None:
# Still do a compare so we run same time as if a secret was passed.
hmac.compare_digest(client.secret.encode('utf-8'),
client.secret.encode('utf-8'))
return None
elif hmac.compare_digest(client_secret.encode('utf-8'),
client.secret.encode('utf-8')):
return client
return None

View File

@ -0,0 +1,130 @@
"""Helpers to resolve client ID/secret."""
from ipaddress import ip_address, ip_network
from urllib.parse import urlparse
# IP addresses of loopback interfaces
ALLOWED_IPS = (
ip_address('127.0.0.1'),
ip_address('::1'),
)
# RFC1918 - Address allocation for Private Internets
ALLOWED_NETWORKS = (
ip_network('10.0.0.0/8'),
ip_network('172.16.0.0/12'),
ip_network('192.168.0.0/16'),
)
def verify_redirect_uri(client_id, redirect_uri):
"""Verify that the client and redirect uri match."""
try:
client_id_parts = _parse_client_id(client_id)
except ValueError:
return False
redirect_parts = _parse_url(redirect_uri)
# IndieAuth 4.2.2 allows for redirect_uri to be on different domain
# but needs to be specified in link tag when fetching `client_id`.
# This is not implemented.
# Verify redirect url and client url have same scheme and domain.
return (
client_id_parts.scheme == redirect_parts.scheme and
client_id_parts.netloc == redirect_parts.netloc
)
def verify_client_id(client_id):
"""Verify that the client id is valid."""
try:
_parse_client_id(client_id)
return True
except ValueError:
return False
def _parse_url(url):
"""Parse a url in parts and canonicalize according to IndieAuth."""
parts = urlparse(url)
# Canonicalize a url according to IndieAuth 3.2.
# SHOULD convert the hostname to lowercase
parts = parts._replace(netloc=parts.netloc.lower())
# If a URL with no path component is ever encountered,
# it MUST be treated as if it had the path /.
if parts.path == '':
parts = parts._replace(path='/')
return parts
def _parse_client_id(client_id):
"""Test if client id is a valid URL according to IndieAuth section 3.2.
https://indieauth.spec.indieweb.org/#client-identifier
"""
parts = _parse_url(client_id)
# Client identifier URLs
# MUST have either an https or http scheme
if parts.scheme not in ('http', 'https'):
raise ValueError()
# MUST contain a path component
# Handled by url canonicalization.
# MUST NOT contain single-dot or double-dot path segments
if any(segment in ('.', '..') for segment in parts.path.split('/')):
raise ValueError(
'Client ID cannot contain single-dot or double-dot path segments')
# MUST NOT contain a fragment component
if parts.fragment != '':
raise ValueError('Client ID cannot contain a fragment')
# MUST NOT contain a username or password component
if parts.username is not None:
raise ValueError('Client ID cannot contain username')
if parts.password is not None:
raise ValueError('Client ID cannot contain password')
# MAY contain a port
try:
# parts raises ValueError when port cannot be parsed as int
parts.port
except ValueError:
raise ValueError('Client ID contains invalid port')
# Additionally, hostnames
# MUST be domain names or a loopback interface and
# MUST NOT be IPv4 or IPv6 addresses except for IPv4 127.0.0.1
# or IPv6 [::1]
# We are not goint to follow the spec here. We are going to allow
# any internal network IP to be used inside a client id.
address = None
try:
netloc = parts.netloc
# Strip the [, ] from ipv6 addresses before parsing
if netloc[0] == '[' and netloc[-1] == ']':
netloc = netloc[1:-1]
address = ip_address(netloc)
except ValueError:
# Not an ip address
pass
if (address is None or
address in ALLOWED_IPS or
any(address in network for network in ALLOWED_NETWORKS)):
return parts
raise ValueError('Hostname should be a domain name or local IP address')

View File

@ -200,15 +200,6 @@ def add_manifest_json_key(key, val):
async def async_setup(hass, config):
"""Set up the serving of the frontend."""
if hass.auth.active:
client = await hass.auth.async_get_or_create_client(
'Home Assistant Frontend',
redirect_uris=['/'],
no_secret=True,
)
else:
client = None
hass.components.websocket_api.async_register_command(
WS_TYPE_GET_PANELS, websocket_get_panels, SCHEMA_GET_PANELS)
hass.components.websocket_api.async_register_command(
@ -255,7 +246,7 @@ async def async_setup(hass, config):
if os.path.isdir(local):
hass.http.register_static_path("/local", local, not is_dev)
index_view = IndexView(repo_path, js_version, client)
index_view = IndexView(repo_path, js_version, hass.auth.active)
hass.http.register_view(index_view)
@callback
@ -350,11 +341,11 @@ class IndexView(HomeAssistantView):
requires_auth = False
extra_urls = ['/states', '/states/{extra}']
def __init__(self, repo_path, js_option, client):
def __init__(self, repo_path, js_option, auth_active):
"""Initialize the frontend view."""
self.repo_path = repo_path
self.js_option = js_option
self.client = client
self.auth_active = auth_active
self._template_cache = {}
def get_template(self, latest):
@ -399,11 +390,9 @@ class IndexView(HomeAssistantView):
no_auth=no_auth,
theme_color=MANIFEST_JSON['theme_color'],
extra_urls=hass.data[extra_key],
client_id=self.auth_active
)
if self.client is not None:
template_params['client_id'] = self.client.id
return web.Response(text=template.render(**template_params),
content_type='text/html')

View File

@ -31,6 +31,8 @@ from homeassistant.util.async_ import (
_TEST_INSTANCE_PORT = SERVER_PORT
_LOGGER = logging.getLogger(__name__)
INSTANCES = []
CLIENT_ID = 'https://example.com/app'
CLIENT_REDIRECT_URI = 'https://example.com/app/callback'
def threadsafe_callback_factory(func):
@ -330,8 +332,6 @@ class MockUser(auth.User):
def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded."""
store = auth_mgr._store
if store._clients is None:
store._clients = {}
if store._users is None:
store._users = {}

View File

@ -1,6 +1,4 @@
"""Tests for the auth component."""
from aiohttp.helpers import BasicAuth
from homeassistant import auth
from homeassistant.setup import async_setup_component
@ -16,10 +14,6 @@ BASE_CONFIG = [{
'name': 'Test Name'
}]
}]
CLIENT_ID = 'test-id'
CLIENT_SECRET = 'test-secret'
CLIENT_AUTH = BasicAuth(CLIENT_ID, CLIENT_SECRET)
CLIENT_REDIRECT_URI = 'http://example.com/callback'
async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
@ -32,9 +26,6 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
'api_password': 'bla'
}
})
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
redirect_uris=[CLIENT_REDIRECT_URI])
hass.auth._store._clients[client.id] = client
if setup_api:
await async_setup_component(hass, 'api', {})
return await aiohttp_client(hass.http.app)

View File

@ -1,70 +0,0 @@
"""Tests for the client validator."""
from aiohttp.helpers import BasicAuth
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.components.auth.client import verify_client
from homeassistant.components.http.view import HomeAssistantView
from . import async_setup_auth
@pytest.fixture
def mock_view(hass):
"""Register a view that verifies client id/secret."""
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
clients = []
class ClientView(HomeAssistantView):
url = '/'
name = 'bla'
@verify_client
async def get(self, request, client):
"""Handle GET request."""
clients.append(client)
hass.http.register_view(ClientView)
return clients
async def test_verify_client(hass, aiohttp_client, mock_view):
"""Test that verify client can extract client auth from a request."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth(client.id, client.secret))
assert resp.status == 200
assert mock_view[0] is client
async def test_verify_client_no_auth_header(hass, aiohttp_client, mock_view):
"""Test that verify client will decline unknown client id."""
http_client = await async_setup_auth(hass, aiohttp_client)
resp = await http_client.get('/')
assert resp.status == 401
assert mock_view == []
async def test_verify_client_invalid_client_id(hass, aiohttp_client,
mock_view):
"""Test that verify client will decline unknown client id."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth('invalid', client.secret))
assert resp.status == 401
assert mock_view == []
async def test_verify_client_invalid_client_secret(hass, aiohttp_client,
mock_view):
"""Test that verify client will decline incorrect client secret."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth(client.id, 'invalid'))
assert resp.status == 401
assert mock_view == []

View File

@ -0,0 +1,110 @@
"""Tests for the client validator."""
from homeassistant.components.auth import indieauth
import pytest
def test_client_id_scheme():
"""Test we enforce valid scheme."""
assert indieauth._parse_client_id('http://ex.com/')
assert indieauth._parse_client_id('https://ex.com/')
with pytest.raises(ValueError):
indieauth._parse_client_id('ftp://ex.com')
def test_client_id_path():
"""Test we enforce valid path."""
assert indieauth._parse_client_id('http://ex.com').path == '/'
assert indieauth._parse_client_id('http://ex.com/hello').path == '/hello'
assert indieauth._parse_client_id(
'http://ex.com/hello/.world').path == '/hello/.world'
assert indieauth._parse_client_id(
'http://ex.com/hello./.world').path == '/hello./.world'
with pytest.raises(ValueError):
indieauth._parse_client_id('http://ex.com/.')
with pytest.raises(ValueError):
indieauth._parse_client_id('http://ex.com/hello/./yo')
with pytest.raises(ValueError):
indieauth._parse_client_id('http://ex.com/hello/../yo')
def test_client_id_fragment():
"""Test we enforce valid fragment."""
with pytest.raises(ValueError):
indieauth._parse_client_id('http://ex.com/#yoo')
def test_client_id_user_pass():
"""Test we enforce valid username/password."""
with pytest.raises(ValueError):
indieauth._parse_client_id('http://user@ex.com/')
with pytest.raises(ValueError):
indieauth._parse_client_id('http://user:pass@ex.com/')
def test_client_id_hostname():
"""Test we enforce valid hostname."""
assert indieauth._parse_client_id('http://www.home-assistant.io/')
assert indieauth._parse_client_id('http://[::1]')
assert indieauth._parse_client_id('http://127.0.0.1')
assert indieauth._parse_client_id('http://10.0.0.0')
assert indieauth._parse_client_id('http://10.255.255.255')
assert indieauth._parse_client_id('http://172.16.0.0')
assert indieauth._parse_client_id('http://172.31.255.255')
assert indieauth._parse_client_id('http://192.168.0.0')
assert indieauth._parse_client_id('http://192.168.255.255')
with pytest.raises(ValueError):
assert indieauth._parse_client_id('http://255.255.255.255/')
with pytest.raises(ValueError):
assert indieauth._parse_client_id('http://11.0.0.0/')
with pytest.raises(ValueError):
assert indieauth._parse_client_id('http://172.32.0.0/')
with pytest.raises(ValueError):
assert indieauth._parse_client_id('http://192.167.0.0/')
def test_parse_url_lowercase_host():
"""Test we update empty paths."""
assert indieauth._parse_url('http://ex.com/hello').path == '/hello'
assert indieauth._parse_url('http://EX.COM/hello').hostname == 'ex.com'
parts = indieauth._parse_url('http://EX.COM:123/HELLO')
assert parts.netloc == 'ex.com:123'
assert parts.path == '/HELLO'
def test_parse_url_path():
"""Test we update empty paths."""
assert indieauth._parse_url('http://ex.com').path == '/'
def test_verify_redirect_uri():
"""Test that we verify redirect uri correctly."""
assert indieauth.verify_redirect_uri(
'http://ex.com',
'http://ex.com/callback'
)
# Different domain
assert not indieauth.verify_redirect_uri(
'http://ex.com',
'http://different.com/callback'
)
# Different scheme
assert not indieauth.verify_redirect_uri(
'http://ex.com',
'https://ex.com/callback'
)
# Different subdomain
assert not indieauth.verify_redirect_uri(
'https://sub1.ex.com',
'https://sub2.ex.com/callback'
)

View File

@ -1,22 +1,26 @@
"""Integration tests for the auth component."""
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
from . import async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
"""Test logging in with new user and refreshing tokens."""
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
resp = await client.post('/auth/login_flow', json={
'client_id': CLIENT_ID,
'handler': ['insecure_example', None],
'redirect_uri': CLIENT_REDIRECT_URI,
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'client_id': CLIENT_ID,
'username': 'test-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
@ -24,9 +28,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
# Exchange code for tokens
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'authorization_code',
'code': code
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
tokens = await resp.json()
@ -35,9 +40,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
# Use refresh token to get more tokens.
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'refresh_token',
'refresh_token': tokens['refresh_token']
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
tokens = await resp.json()

View File

@ -1,5 +1,7 @@
"""Tests for the link user flow."""
from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID, CLIENT_REDIRECT_URI
from . import async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
async def async_get_code(hass, aiohttp_client):
@ -25,17 +27,19 @@ async def async_get_code(hass, aiohttp_client):
client = await async_setup_auth(hass, aiohttp_client, config)
resp = await client.post('/auth/login_flow', json={
'client_id': CLIENT_ID,
'handler': ['insecure_example', None],
'redirect_uri': CLIENT_REDIRECT_URI,
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'client_id': CLIENT_ID,
'username': 'test-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
@ -43,9 +47,10 @@ async def async_get_code(hass, aiohttp_client):
# Exchange code for tokens
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'authorization_code',
'code': code
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
tokens = await resp.json()
@ -57,17 +62,19 @@ async def async_get_code(hass, aiohttp_client):
# Now authenticate with the 2nd flow
resp = await client.post('/auth/login_flow', json={
'client_id': CLIENT_ID,
'handler': ['insecure_example', '2nd auth'],
'redirect_uri': CLIENT_REDIRECT_URI,
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'client_id': CLIENT_ID,
'username': '2nd-user',
'password': '2nd-pass',
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()

View File

@ -1,13 +1,13 @@
"""Tests for the login flow."""
from aiohttp.helpers import BasicAuth
from . import async_setup_auth
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
async def test_fetch_auth_providers(hass, aiohttp_client):
"""Test fetching auth providers."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.get('/auth/providers', auth=CLIENT_AUTH)
resp = await client.get('/auth/providers')
assert await resp.json() == [{
'name': 'Example',
'type': 'insecure_example',
@ -15,14 +15,6 @@ async def test_fetch_auth_providers(hass, aiohttp_client):
}]
async def test_fetch_auth_providers_require_valid_client(hass, aiohttp_client):
"""Test fetching auth providers."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.get('/auth/providers',
auth=BasicAuth('invalid', 'bla'))
assert resp.status == 401
async def test_cannot_get_flows_in_progress(hass, aiohttp_client):
"""Test we cannot get flows in progress."""
client = await async_setup_auth(hass, aiohttp_client, [])
@ -34,18 +26,20 @@ async def test_invalid_username_password(hass, aiohttp_client):
"""Test we cannot get flows in progress."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.post('/auth/login_flow', json={
'client_id': CLIENT_ID,
'handler': ['insecure_example', None],
'redirect_uri': CLIENT_REDIRECT_URI
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
# Incorrect username
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'client_id': CLIENT_ID,
'username': 'wrong-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()
@ -56,9 +50,10 @@ async def test_invalid_username_password(hass, aiohttp_client):
# Incorrect password
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'client_id': CLIENT_ID,
'username': 'test-user',
'password': 'wrong-pass',
}, auth=CLIENT_AUTH)
})
assert resp.status == 200
step = await resp.json()

View File

@ -3,7 +3,7 @@ import pytest
from homeassistant.setup import async_setup_component
from tests.common import MockUser
from tests.common import MockUser, CLIENT_ID
@pytest.fixture
@ -28,11 +28,6 @@ def hass_ws_client(aiohttp_client):
def hass_access_token(hass):
"""Return an access token to access Home Assistant."""
user = MockUser().add_to_hass(hass)
client = hass.loop.run_until_complete(hass.auth.async_create_client(
'Access Token Fixture',
redirect_uris=['/'],
no_secret=True,
))
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, client))
hass.auth.async_create_refresh_token(user, CLIENT_ID))
yield hass.auth.async_create_access_token(refresh_token)

View File

@ -6,7 +6,8 @@ import pytest
from homeassistant import auth, data_entry_flow
from homeassistant.util import dt as dt_util
from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
from tests.common import (
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
@pytest.fixture
@ -181,10 +182,7 @@ async def test_saving_loading(hass, hass_storage):
})
user = await manager.async_get_or_create_user(step['result'])
client = await manager.async_create_client(
'test', redirect_uris=['https://example.com'])
refresh_token = await manager.async_create_refresh_token(user, client)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
manager.async_create_access_token(refresh_token)
@ -195,10 +193,6 @@ async def test_saving_loading(hass, hass_storage):
assert len(users) == 1
assert users[0] == user
clients = await store2.async_get_clients()
assert len(clients) == 1
assert clients[0] == client
def test_access_token_expired():
"""Test that the expired property on access tokens work."""
@ -225,11 +219,10 @@ def test_access_token_expired():
async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser().add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, client)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
assert refresh_token.user.id is user.id
assert refresh_token.client_id is client.id
assert refresh_token.client_id == CLIENT_ID
access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token
@ -242,19 +235,6 @@ async def test_cannot_retrieve_expired_access_token(hass):
assert manager.async_get_access_token(access_token.token) is None
async def test_get_or_create_client(hass):
"""Test that get_or_create_client works."""
manager = await auth.auth_manager_from_config(hass, [])
client1 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client1.name is 'Test Client'
client2 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client2.id is client1.id
async def test_generating_system_user(hass):
"""Test that we can add a system user."""
manager = await auth.auth_manager_from_config(hass, [])
@ -274,10 +254,9 @@ async def test_refresh_token_requires_client_for_user(hass):
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user)
client = await manager.async_get_or_create_client('Test client')
token = await manager.async_create_refresh_token(user, client)
token = await manager.async_create_refresh_token(user, CLIENT_ID)
assert token is not None
assert token.client_id == client.id
assert token.client_id == CLIENT_ID
async def test_refresh_token_not_requires_client_for_system_user(hass):
@ -285,10 +264,9 @@ async def test_refresh_token_not_requires_client_for_system_user(hass):
manager = await auth.auth_manager_from_config(hass, [])
user = await manager.async_create_system_user('Hass.io')
assert user.system_generated is True
client = await manager.async_get_or_create_client('Test client')
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, client)
await manager.async_create_refresh_token(user, CLIENT_ID)
token = await manager.async_create_refresh_token(user)
assert token is not None