mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Add trusted networks auth provider (#15812)
* Add context to login flow * Add trusted networks auth provider * source -> context
This commit is contained in:
parent
50daef9a52
commit
da8f93dca2
125
homeassistant/auth/providers/trusted_networks.py
Normal file
125
homeassistant/auth/providers/trusted_networks.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""Trusted Networks auth provider.
|
||||
|
||||
It shows list of users if access from trusted network.
|
||||
Abort login flow if not access from trusted network.
|
||||
"""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
"""Raised when try to access from untrusted networks."""
|
||||
|
||||
|
||||
class InvalidUserError(HomeAssistantError):
|
||||
"""Raised when try to login as invalid user."""
|
||||
|
||||
|
||||
@AUTH_PROVIDERS.register('trusted_networks')
|
||||
class TrustedNetworksAuthProvider(AuthProvider):
|
||||
"""Trusted Networks auth provider.
|
||||
|
||||
Allow passwordless access from trusted network.
|
||||
"""
|
||||
|
||||
DEFAULT_TITLE = 'Trusted Networks'
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return a flow to login."""
|
||||
users = await self.store.async_get_users()
|
||||
available_users = {user.id: user.name
|
||||
for user in users
|
||||
if not user.system_generated and user.is_active}
|
||||
|
||||
return LoginFlow(self, context.get('ip_address'), available_users)
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
user_id = flow_result['user']
|
||||
|
||||
users = await self.store.async_get_users()
|
||||
for user in users:
|
||||
if (not user.system_generated and
|
||||
user.is_active and
|
||||
user.id == user_id):
|
||||
for credential in await self.async_credentials():
|
||||
if credential.data['user_id'] == user_id:
|
||||
return credential
|
||||
cred = self.async_create_credentials({'user_id': user_id})
|
||||
await self.store.async_link_user(user, cred)
|
||||
return cred
|
||||
|
||||
# We only allow login as exist user
|
||||
raise InvalidUserError
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Trusted network auth provider should never create new user.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@callback
|
||||
def async_validate_access(self, ip_address):
|
||||
"""Make sure the access from trusted networks.
|
||||
|
||||
Raise InvalidAuthError if not.
|
||||
Raise InvalidAuthError if trusted_networks is not config
|
||||
"""
|
||||
if (not hasattr(self.hass, 'http') or
|
||||
not self.hass.http or not self.hass.http.trusted_networks):
|
||||
raise InvalidAuthError('trusted_networks is not configured')
|
||||
|
||||
if not any(ip_address in trusted_network for trusted_network
|
||||
in self.hass.http.trusted_networks):
|
||||
raise InvalidAuthError('Not in trusted_networks')
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider, ip_address, available_users):
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self._available_users = available_users
|
||||
self._ip_address = ip_address
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
try:
|
||||
self._auth_provider.async_validate_access(self._ip_address)
|
||||
|
||||
except InvalidAuthError:
|
||||
errors['base'] = 'invalid_auth'
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=None,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
if user_input is not None:
|
||||
user_id = user_input['user']
|
||||
if user_id not in self._available_users:
|
||||
errors['base'] = 'invalid_auth'
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=self._auth_provider.name,
|
||||
data=user_input
|
||||
)
|
||||
|
||||
schema = {'user': vol.In(self._available_users)}
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema(schema),
|
||||
errors=errors,
|
||||
)
|
@ -63,6 +63,7 @@ import aiohttp.web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components.http import KEY_REAL_IP
|
||||
from homeassistant.components.http.ban import process_wrong_login, \
|
||||
log_invalid_auth
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
@ -151,7 +152,8 @@ class LoginFlowIndexView(HomeAssistantView):
|
||||
handler = data['handler']
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(handler, context={})
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler, context={'ip_address': request[KEY_REAL_IP]})
|
||||
except data_entry_flow.UnknownHandler:
|
||||
return self.json_message('Invalid handler specified', 404)
|
||||
except data_entry_flow.UnknownStep:
|
||||
@ -188,6 +190,13 @@ class LoginFlowResourceView(HomeAssistantView):
|
||||
return self.json_message('Invalid client id', 400)
|
||||
|
||||
try:
|
||||
# do not allow change ip during login flow
|
||||
for flow in self._flow_mgr.async_progress():
|
||||
if (flow['flow_id'] == flow_id and
|
||||
flow['context']['ip_address'] !=
|
||||
request.get(KEY_REAL_IP)):
|
||||
return self.json_message('IP address changed', 400)
|
||||
|
||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||
except data_entry_flow.UnknownFlow:
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
|
@ -220,6 +220,7 @@ class HomeAssistantHTTP:
|
||||
self.ssl_key = ssl_key
|
||||
self.server_host = server_host
|
||||
self.server_port = server_port
|
||||
self.trusted_networks = trusted_networks
|
||||
self.is_ban_enabled = is_ban_enabled
|
||||
self._handler = None
|
||||
self.server = None
|
||||
|
@ -344,7 +344,10 @@ class ActiveConnection:
|
||||
if request[KEY_AUTHENTICATED]:
|
||||
authenticated = True
|
||||
|
||||
else:
|
||||
# always request auth when auth is active
|
||||
# even request passed pre-authentication (trusted networks)
|
||||
# or when using legacy api_password
|
||||
if self.hass.auth.active or not authenticated:
|
||||
self.debug("Request auth")
|
||||
await self.wsock.send_json(auth_required_message())
|
||||
msg = await wsock.receive_json()
|
||||
|
@ -63,18 +63,9 @@ async def test_verify_not_load(hass, provider):
|
||||
|
||||
|
||||
async def test_verify_login(hass, provider):
|
||||
"""Test we raise if http module not load."""
|
||||
"""Test login using legacy api password auth provider."""
|
||||
hass.http = Mock(api_password='test-password')
|
||||
provider.async_validate_login('test-password')
|
||||
hass.http = Mock(api_password='test-password')
|
||||
with pytest.raises(legacy_api_password.InvalidAuthError):
|
||||
provider.async_validate_login('invalid-password')
|
||||
|
||||
|
||||
async def test_utf_8_username_password(provider):
|
||||
"""Test that we create a new credential."""
|
||||
credentials = await provider.async_get_or_create_credentials({
|
||||
'username': '🎉',
|
||||
'password': '😎',
|
||||
})
|
||||
assert credentials.is_new is True
|
||||
|
106
tests/auth/providers/test_trusted_networks.py
Normal file
106
tests/auth/providers/test_trusted_networks.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Test the Trusted Networks auth provider."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.auth import auth_store
|
||||
from homeassistant.auth.providers import trusted_networks as tn_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
"""Mock store."""
|
||||
return auth_store.AuthStore(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(hass, store):
|
||||
"""Mock provider."""
|
||||
return tn_auth.TrustedNetworksAuthProvider(hass, store, {
|
||||
'type': 'trusted_networks'
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(hass, store, provider):
|
||||
"""Mock manager."""
|
||||
return auth.AuthManager(hass, store, {
|
||||
(provider.type, provider.id): provider
|
||||
})
|
||||
|
||||
|
||||
async def test_trusted_networks_credentials(manager, provider):
|
||||
"""Test trusted_networks credentials related functions."""
|
||||
owner = await manager.async_create_user("test-owner")
|
||||
tn_owner_cred = await provider.async_get_or_create_credentials({
|
||||
'user': owner.id
|
||||
})
|
||||
assert tn_owner_cred.is_new is False
|
||||
assert any(cred.id == tn_owner_cred.id for cred in owner.credentials)
|
||||
|
||||
user = await manager.async_create_user("test-user")
|
||||
tn_user_cred = await provider.async_get_or_create_credentials({
|
||||
'user': user.id
|
||||
})
|
||||
assert tn_user_cred.id != tn_owner_cred.id
|
||||
assert tn_user_cred.is_new is False
|
||||
assert any(cred.id == tn_user_cred.id for cred in user.credentials)
|
||||
|
||||
with pytest.raises(tn_auth.InvalidUserError):
|
||||
await provider.async_get_or_create_credentials({
|
||||
'user': 'invalid-user'
|
||||
})
|
||||
|
||||
|
||||
async def test_validate_access(provider):
|
||||
"""Test validate access from trusted networks."""
|
||||
with pytest.raises(tn_auth.InvalidAuthError):
|
||||
provider.async_validate_access('192.168.0.1')
|
||||
|
||||
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||
provider.async_validate_access('192.168.0.1')
|
||||
|
||||
with pytest.raises(tn_auth.InvalidAuthError):
|
||||
provider.async_validate_access('127.0.0.1')
|
||||
|
||||
|
||||
async def test_login_flow(manager, provider):
|
||||
"""Test login flow."""
|
||||
owner = await manager.async_create_user("test-owner")
|
||||
user = await manager.async_create_user("test-user")
|
||||
|
||||
# trusted network didn't loaded
|
||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
assert step['errors']['base'] == 'invalid_auth'
|
||||
|
||||
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||
|
||||
# not from trusted network
|
||||
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
assert step['errors']['base'] == 'invalid_auth'
|
||||
|
||||
# from trusted network, list users
|
||||
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
assert schema({'user': owner.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': 'invalid-user'})
|
||||
|
||||
# login with invalid user
|
||||
step = await flow.async_step_init({'user': 'invalid-user'})
|
||||
assert step['step_id'] == 'init'
|
||||
assert step['errors']['base'] == 'invalid_auth'
|
||||
|
||||
# login with valid user
|
||||
step = await flow.async_step_init({'user': user.id})
|
||||
assert step['type'] == 'create_entry'
|
||||
assert step['data']['user'] == user.id
|
Loading…
x
Reference in New Issue
Block a user