mirror of
https://github.com/home-assistant/core.git
synced 2025-07-26 22:57:17 +00:00
Add trusted networks auth provider (#15812)
* Add context to login flow * Add trusted networks auth provider * source -> context
This commit is contained in:
parent
50daef9a52
commit
da8f93dca2
125
homeassistant/auth/providers/trusted_networks.py
Normal file
125
homeassistant/auth/providers/trusted_networks.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""Trusted Networks auth provider.
|
||||||
|
|
||||||
|
It shows list of users if access from trusted network.
|
||||||
|
Abort login flow if not access from trusted network.
|
||||||
|
"""
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||||
|
}, extra=vol.PREVENT_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthError(HomeAssistantError):
|
||||||
|
"""Raised when try to access from untrusted networks."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidUserError(HomeAssistantError):
|
||||||
|
"""Raised when try to login as invalid user."""
|
||||||
|
|
||||||
|
|
||||||
|
@AUTH_PROVIDERS.register('trusted_networks')
|
||||||
|
class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
|
"""Trusted Networks auth provider.
|
||||||
|
|
||||||
|
Allow passwordless access from trusted network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_TITLE = 'Trusted Networks'
|
||||||
|
|
||||||
|
async def async_credential_flow(self, context):
|
||||||
|
"""Return a flow to login."""
|
||||||
|
users = await self.store.async_get_users()
|
||||||
|
available_users = {user.id: user.name
|
||||||
|
for user in users
|
||||||
|
if not user.system_generated and user.is_active}
|
||||||
|
|
||||||
|
return LoginFlow(self, context.get('ip_address'), available_users)
|
||||||
|
|
||||||
|
async def async_get_or_create_credentials(self, flow_result):
|
||||||
|
"""Get credentials based on the flow result."""
|
||||||
|
user_id = flow_result['user']
|
||||||
|
|
||||||
|
users = await self.store.async_get_users()
|
||||||
|
for user in users:
|
||||||
|
if (not user.system_generated and
|
||||||
|
user.is_active and
|
||||||
|
user.id == user_id):
|
||||||
|
for credential in await self.async_credentials():
|
||||||
|
if credential.data['user_id'] == user_id:
|
||||||
|
return credential
|
||||||
|
cred = self.async_create_credentials({'user_id': user_id})
|
||||||
|
await self.store.async_link_user(user, cred)
|
||||||
|
return cred
|
||||||
|
|
||||||
|
# We only allow login as exist user
|
||||||
|
raise InvalidUserError
|
||||||
|
|
||||||
|
async def async_user_meta_for_credentials(self, credentials):
|
||||||
|
"""Return extra user metadata for credentials.
|
||||||
|
|
||||||
|
Trusted network auth provider should never create new user.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_validate_access(self, ip_address):
|
||||||
|
"""Make sure the access from trusted networks.
|
||||||
|
|
||||||
|
Raise InvalidAuthError if not.
|
||||||
|
Raise InvalidAuthError if trusted_networks is not config
|
||||||
|
"""
|
||||||
|
if (not hasattr(self.hass, 'http') or
|
||||||
|
not self.hass.http or not self.hass.http.trusted_networks):
|
||||||
|
raise InvalidAuthError('trusted_networks is not configured')
|
||||||
|
|
||||||
|
if not any(ip_address in trusted_network for trusted_network
|
||||||
|
in self.hass.http.trusted_networks):
|
||||||
|
raise InvalidAuthError('Not in trusted_networks')
|
||||||
|
|
||||||
|
|
||||||
|
class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
|
def __init__(self, auth_provider, ip_address, available_users):
|
||||||
|
"""Initialize the login flow."""
|
||||||
|
self._auth_provider = auth_provider
|
||||||
|
self._available_users = available_users
|
||||||
|
self._ip_address = ip_address
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
"""Handle the step of the form."""
|
||||||
|
errors = {}
|
||||||
|
try:
|
||||||
|
self._auth_provider.async_validate_access(self._ip_address)
|
||||||
|
|
||||||
|
except InvalidAuthError:
|
||||||
|
errors['base'] = 'invalid_auth'
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id='init',
|
||||||
|
data_schema=None,
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
user_id = user_input['user']
|
||||||
|
if user_id not in self._available_users:
|
||||||
|
errors['base'] = 'invalid_auth'
|
||||||
|
|
||||||
|
if not errors:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self._auth_provider.name,
|
||||||
|
data=user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = {'user': vol.In(self._available_users)}
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id='init',
|
||||||
|
data_schema=vol.Schema(schema),
|
||||||
|
errors=errors,
|
||||||
|
)
|
@ -63,6 +63,7 @@ import aiohttp.web
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.components.http import KEY_REAL_IP
|
||||||
from homeassistant.components.http.ban import process_wrong_login, \
|
from homeassistant.components.http.ban import process_wrong_login, \
|
||||||
log_invalid_auth
|
log_invalid_auth
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
@ -151,7 +152,8 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||||||
handler = data['handler']
|
handler = data['handler']
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self._flow_mgr.async_init(handler, context={})
|
result = await self._flow_mgr.async_init(
|
||||||
|
handler, context={'ip_address': request[KEY_REAL_IP]})
|
||||||
except data_entry_flow.UnknownHandler:
|
except data_entry_flow.UnknownHandler:
|
||||||
return self.json_message('Invalid handler specified', 404)
|
return self.json_message('Invalid handler specified', 404)
|
||||||
except data_entry_flow.UnknownStep:
|
except data_entry_flow.UnknownStep:
|
||||||
@ -188,6 +190,13 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||||||
return self.json_message('Invalid client id', 400)
|
return self.json_message('Invalid client id', 400)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# do not allow change ip during login flow
|
||||||
|
for flow in self._flow_mgr.async_progress():
|
||||||
|
if (flow['flow_id'] == flow_id and
|
||||||
|
flow['context']['ip_address'] !=
|
||||||
|
request.get(KEY_REAL_IP)):
|
||||||
|
return self.json_message('IP address changed', 400)
|
||||||
|
|
||||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||||
except data_entry_flow.UnknownFlow:
|
except data_entry_flow.UnknownFlow:
|
||||||
return self.json_message('Invalid flow specified', 404)
|
return self.json_message('Invalid flow specified', 404)
|
||||||
|
@ -220,6 +220,7 @@ class HomeAssistantHTTP:
|
|||||||
self.ssl_key = ssl_key
|
self.ssl_key = ssl_key
|
||||||
self.server_host = server_host
|
self.server_host = server_host
|
||||||
self.server_port = server_port
|
self.server_port = server_port
|
||||||
|
self.trusted_networks = trusted_networks
|
||||||
self.is_ban_enabled = is_ban_enabled
|
self.is_ban_enabled = is_ban_enabled
|
||||||
self._handler = None
|
self._handler = None
|
||||||
self.server = None
|
self.server = None
|
||||||
|
@ -344,7 +344,10 @@ class ActiveConnection:
|
|||||||
if request[KEY_AUTHENTICATED]:
|
if request[KEY_AUTHENTICATED]:
|
||||||
authenticated = True
|
authenticated = True
|
||||||
|
|
||||||
else:
|
# always request auth when auth is active
|
||||||
|
# even request passed pre-authentication (trusted networks)
|
||||||
|
# or when using legacy api_password
|
||||||
|
if self.hass.auth.active or not authenticated:
|
||||||
self.debug("Request auth")
|
self.debug("Request auth")
|
||||||
await self.wsock.send_json(auth_required_message())
|
await self.wsock.send_json(auth_required_message())
|
||||||
msg = await wsock.receive_json()
|
msg = await wsock.receive_json()
|
||||||
|
@ -63,18 +63,9 @@ async def test_verify_not_load(hass, provider):
|
|||||||
|
|
||||||
|
|
||||||
async def test_verify_login(hass, provider):
|
async def test_verify_login(hass, provider):
|
||||||
"""Test we raise if http module not load."""
|
"""Test login using legacy api password auth provider."""
|
||||||
hass.http = Mock(api_password='test-password')
|
hass.http = Mock(api_password='test-password')
|
||||||
provider.async_validate_login('test-password')
|
provider.async_validate_login('test-password')
|
||||||
hass.http = Mock(api_password='test-password')
|
hass.http = Mock(api_password='test-password')
|
||||||
with pytest.raises(legacy_api_password.InvalidAuthError):
|
with pytest.raises(legacy_api_password.InvalidAuthError):
|
||||||
provider.async_validate_login('invalid-password')
|
provider.async_validate_login('invalid-password')
|
||||||
|
|
||||||
|
|
||||||
async def test_utf_8_username_password(provider):
|
|
||||||
"""Test that we create a new credential."""
|
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
|
||||||
'username': '🎉',
|
|
||||||
'password': '😎',
|
|
||||||
})
|
|
||||||
assert credentials.is_new is True
|
|
||||||
|
106
tests/auth/providers/test_trusted_networks.py
Normal file
106
tests/auth/providers/test_trusted_networks.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""Test the Trusted Networks auth provider."""
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import auth
|
||||||
|
from homeassistant.auth import auth_store
|
||||||
|
from homeassistant.auth.providers import trusted_networks as tn_auth
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(hass):
|
||||||
|
"""Mock store."""
|
||||||
|
return auth_store.AuthStore(hass)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def provider(hass, store):
|
||||||
|
"""Mock provider."""
|
||||||
|
return tn_auth.TrustedNetworksAuthProvider(hass, store, {
|
||||||
|
'type': 'trusted_networks'
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(hass, store, provider):
|
||||||
|
"""Mock manager."""
|
||||||
|
return auth.AuthManager(hass, store, {
|
||||||
|
(provider.type, provider.id): provider
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_trusted_networks_credentials(manager, provider):
|
||||||
|
"""Test trusted_networks credentials related functions."""
|
||||||
|
owner = await manager.async_create_user("test-owner")
|
||||||
|
tn_owner_cred = await provider.async_get_or_create_credentials({
|
||||||
|
'user': owner.id
|
||||||
|
})
|
||||||
|
assert tn_owner_cred.is_new is False
|
||||||
|
assert any(cred.id == tn_owner_cred.id for cred in owner.credentials)
|
||||||
|
|
||||||
|
user = await manager.async_create_user("test-user")
|
||||||
|
tn_user_cred = await provider.async_get_or_create_credentials({
|
||||||
|
'user': user.id
|
||||||
|
})
|
||||||
|
assert tn_user_cred.id != tn_owner_cred.id
|
||||||
|
assert tn_user_cred.is_new is False
|
||||||
|
assert any(cred.id == tn_user_cred.id for cred in user.credentials)
|
||||||
|
|
||||||
|
with pytest.raises(tn_auth.InvalidUserError):
|
||||||
|
await provider.async_get_or_create_credentials({
|
||||||
|
'user': 'invalid-user'
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_validate_access(provider):
|
||||||
|
"""Test validate access from trusted networks."""
|
||||||
|
with pytest.raises(tn_auth.InvalidAuthError):
|
||||||
|
provider.async_validate_access('192.168.0.1')
|
||||||
|
|
||||||
|
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||||
|
provider.async_validate_access('192.168.0.1')
|
||||||
|
|
||||||
|
with pytest.raises(tn_auth.InvalidAuthError):
|
||||||
|
provider.async_validate_access('127.0.0.1')
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_flow(manager, provider):
|
||||||
|
"""Test login flow."""
|
||||||
|
owner = await manager.async_create_user("test-owner")
|
||||||
|
user = await manager.async_create_user("test-user")
|
||||||
|
|
||||||
|
# trusted network didn't loaded
|
||||||
|
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||||
|
step = await flow.async_step_init()
|
||||||
|
assert step['step_id'] == 'init'
|
||||||
|
assert step['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
|
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||||
|
|
||||||
|
# not from trusted network
|
||||||
|
flow = await provider.async_credential_flow({'ip_address': '127.0.0.1'})
|
||||||
|
step = await flow.async_step_init()
|
||||||
|
assert step['step_id'] == 'init'
|
||||||
|
assert step['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
|
# from trusted network, list users
|
||||||
|
flow = await provider.async_credential_flow({'ip_address': '192.168.0.1'})
|
||||||
|
step = await flow.async_step_init()
|
||||||
|
assert step['step_id'] == 'init'
|
||||||
|
|
||||||
|
schema = step['data_schema']
|
||||||
|
assert schema({'user': owner.id})
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
assert schema({'user': 'invalid-user'})
|
||||||
|
|
||||||
|
# login with invalid user
|
||||||
|
step = await flow.async_step_init({'user': 'invalid-user'})
|
||||||
|
assert step['step_id'] == 'init'
|
||||||
|
assert step['errors']['base'] == 'invalid_auth'
|
||||||
|
|
||||||
|
# login with valid user
|
||||||
|
step = await flow.async_step_init({'user': user.id})
|
||||||
|
assert step['type'] == 'create_entry'
|
||||||
|
assert step['data']['user'] == user.id
|
Loading…
x
Reference in New Issue
Block a user