mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Add trusted_users in trusted networks auth provider (#22478)
This commit is contained in:
parent
26726af689
commit
6ba2891604
@ -18,8 +18,26 @@ from ..models import Credentials, UserMeta
|
||||
IPAddress = Union[IPv4Address, IPv6Address]
|
||||
IPNetwork = Union[IPv4Network, IPv6Network]
|
||||
|
||||
CONF_TRUSTED_NETWORKS = 'trusted_networks'
|
||||
CONF_TRUSTED_USERS = 'trusted_users'
|
||||
CONF_GROUP = 'group'
|
||||
CONF_ALLOW_BYPASS_LOGIN = 'allow_bypass_login'
|
||||
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
vol.Required('trusted_networks'): vol.All(cv.ensure_list, [ip_network])
|
||||
vol.Required(CONF_TRUSTED_NETWORKS): vol.All(
|
||||
cv.ensure_list, [ip_network]
|
||||
),
|
||||
vol.Optional(CONF_TRUSTED_USERS, default={}): vol.Schema(
|
||||
# we only validate the format of user_id or group_id
|
||||
{ip_network: vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.Or(
|
||||
cv.uuid4_hex,
|
||||
vol.Schema({vol.Required(CONF_GROUP): cv.uuid4_hex}),
|
||||
)],
|
||||
)}
|
||||
),
|
||||
vol.Optional(CONF_ALLOW_BYPASS_LOGIN, default=False): cv.boolean,
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
|
||||
@ -43,7 +61,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||
@property
|
||||
def trusted_networks(self) -> List[IPNetwork]:
|
||||
"""Return trusted networks."""
|
||||
return cast(List[IPNetwork], self.config['trusted_networks'])
|
||||
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
|
||||
|
||||
@property
|
||||
def trusted_users(self) -> Dict[IPNetwork, Any]:
|
||||
"""Return trusted users per network."""
|
||||
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
|
||||
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
@ -53,13 +76,34 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
ip_addr = cast(IPAddress, context.get('ip_address'))
|
||||
users = await self.store.async_get_users()
|
||||
available_users = {user.id: user.name
|
||||
for user in users
|
||||
if not user.system_generated and user.is_active}
|
||||
available_users = [user for user in users
|
||||
if not user.system_generated and user.is_active]
|
||||
for ip_net, user_or_group_list in self.trusted_users.items():
|
||||
if ip_addr in ip_net:
|
||||
user_list = [user_id for user_id in user_or_group_list
|
||||
if isinstance(user_id, str)]
|
||||
group_list = [group[CONF_GROUP] for group in user_or_group_list
|
||||
if isinstance(group, dict)]
|
||||
flattened_group_list = [group for sublist in group_list
|
||||
for group in sublist]
|
||||
available_users = [
|
||||
user for user in available_users
|
||||
if (user.id in user_list or
|
||||
any([group.id in flattened_group_list
|
||||
for group in user.groups]))
|
||||
]
|
||||
break
|
||||
|
||||
return TrustedNetworksLoginFlow(
|
||||
self, cast(IPAddress, context.get('ip_address')), available_users)
|
||||
self,
|
||||
ip_addr,
|
||||
{
|
||||
user.id: user.name for user in available_users
|
||||
},
|
||||
self.config[CONF_ALLOW_BYPASS_LOGIN],
|
||||
)
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
@ -109,11 +153,13 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||
|
||||
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
||||
ip_addr: IPAddress,
|
||||
available_users: Dict[str, Optional[str]]) -> None:
|
||||
available_users: Dict[str, Optional[str]],
|
||||
allow_bypass_login: bool) -> None:
|
||||
"""Initialize the login flow."""
|
||||
super().__init__(auth_provider)
|
||||
self._available_users = available_users
|
||||
self._ip_address = ip_addr
|
||||
self._allow_bypass_login = allow_bypass_login
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
@ -131,6 +177,11 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||
if user_input is not None:
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
if self._allow_bypass_login and len(self._available_users) == 1:
|
||||
return await self.async_finish({
|
||||
'user': next(iter(self._available_users.keys()))
|
||||
})
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema({'user': vol.In(self._available_users)}),
|
||||
|
@ -81,7 +81,8 @@ from . import indieauth
|
||||
async def async_setup(hass, store_result):
|
||||
"""Component to allow users to login."""
|
||||
hass.http.register_view(AuthProvidersView)
|
||||
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
|
||||
hass.http.register_view(
|
||||
LoginFlowIndexView(hass.auth.login_flow, store_result))
|
||||
hass.http.register_view(
|
||||
LoginFlowResourceView(hass.auth.login_flow, store_result))
|
||||
|
||||
@ -142,9 +143,10 @@ class LoginFlowIndexView(HomeAssistantView):
|
||||
name = 'api:auth:login_flow'
|
||||
requires_auth = False
|
||||
|
||||
def __init__(self, flow_mgr):
|
||||
def __init__(self, flow_mgr, store_result):
|
||||
"""Initialize the flow manager index view."""
|
||||
self._flow_mgr = flow_mgr
|
||||
self._store_result = store_result
|
||||
|
||||
async def get(self, request):
|
||||
"""Do not allow index of flows in progress."""
|
||||
@ -179,6 +181,12 @@ class LoginFlowIndexView(HomeAssistantView):
|
||||
except data_entry_flow.UnknownStep:
|
||||
return self.json_message('Handler does not support init', 400)
|
||||
|
||||
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
result.pop('data')
|
||||
result['result'] = self._store_result(
|
||||
data['client_id'], result['result'])
|
||||
return self.json(result)
|
||||
|
||||
return self.json(_prepare_result_json(result))
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ from datetime import (timedelta, datetime as datetime_sys,
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT
|
||||
from typing import Any, Union, TypeVar, Callable, Sequence, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
import voluptuous as vol
|
||||
from pkg_resources import parse_version
|
||||
@ -532,6 +533,20 @@ def x10_address(value):
|
||||
return str(value).lower()
|
||||
|
||||
|
||||
def uuid4_hex(value):
|
||||
"""Validate a v4 UUID in hex format."""
|
||||
try:
|
||||
result = UUID(value, version=4)
|
||||
except (ValueError, AttributeError, TypeError) as error:
|
||||
raise vol.Invalid('Invalid Version4 UUID', error_message=str(error))
|
||||
|
||||
if result.hex != value.lower():
|
||||
# UUID() will create a uuid4 if input is invalid
|
||||
raise vol.Invalid('Invalid Version4 UUID')
|
||||
|
||||
return result.hex
|
||||
|
||||
|
||||
def ensure_list_csv(value: Any) -> Sequence:
|
||||
"""Ensure that input is a list or make one from comma-separated string."""
|
||||
if isinstance(value, str):
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Test the Trusted Networks auth provider."""
|
||||
from ipaddress import ip_address
|
||||
from ipaddress import ip_address, ip_network
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
@ -25,8 +25,47 @@ def provider(hass, store):
|
||||
'192.168.0.1',
|
||||
'192.168.128.0/24',
|
||||
'::1',
|
||||
'fd00::/8'
|
||||
]
|
||||
'fd00::/8',
|
||||
],
|
||||
})
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_with_user(hass, store):
|
||||
"""Mock provider with trusted users config."""
|
||||
return tn_auth.TrustedNetworksAuthProvider(
|
||||
hass, store, tn_auth.CONFIG_SCHEMA({
|
||||
'type': 'trusted_networks',
|
||||
'trusted_networks': [
|
||||
'192.168.0.1',
|
||||
'192.168.128.0/24',
|
||||
'::1',
|
||||
'fd00::/8',
|
||||
],
|
||||
# user_id will be injected in test
|
||||
'trusted_users': {
|
||||
'192.168.0.1': [],
|
||||
'192.168.128.0/24': [],
|
||||
'fd00::/8': [],
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_bypass_login(hass, store):
|
||||
"""Mock provider with allow_bypass_login config."""
|
||||
return tn_auth.TrustedNetworksAuthProvider(
|
||||
hass, store, tn_auth.CONFIG_SCHEMA({
|
||||
'type': 'trusted_networks',
|
||||
'trusted_networks': [
|
||||
'192.168.0.1',
|
||||
'192.168.128.0/24',
|
||||
'::1',
|
||||
'fd00::/8',
|
||||
],
|
||||
'allow_bypass_login': True,
|
||||
})
|
||||
)
|
||||
|
||||
@ -39,6 +78,23 @@ def manager(hass, store, provider):
|
||||
}, {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager_with_user(hass, store, provider_with_user):
|
||||
"""Mock manager with trusted user."""
|
||||
return auth.AuthManager(hass, store, {
|
||||
(provider_with_user.type, provider_with_user.id): provider_with_user
|
||||
}, {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager_bypass_login(hass, store, provider_bypass_login):
|
||||
"""Mock manager with allow bypass login."""
|
||||
return auth.AuthManager(hass, store, {
|
||||
(provider_bypass_login.type, provider_bypass_login.id):
|
||||
provider_bypass_login
|
||||
}, {})
|
||||
|
||||
|
||||
async def test_trusted_networks_credentials(manager, provider):
|
||||
"""Test trusted_networks credentials related functions."""
|
||||
owner = await manager.async_create_user("test-owner")
|
||||
@ -104,3 +160,157 @@ async def test_login_flow(manager, provider):
|
||||
step = await flow.async_step_init({'user': user.id})
|
||||
assert step['type'] == 'create_entry'
|
||||
assert step['data']['user'] == user.id
|
||||
|
||||
|
||||
async def test_trusted_users_login(manager_with_user, provider_with_user):
|
||||
"""Test available user list changed per different IP."""
|
||||
owner = await manager_with_user.async_create_user("test-owner")
|
||||
sys_user = await manager_with_user.async_create_system_user(
|
||||
"test-sys-user") # system user will not be available to select
|
||||
user = await manager_with_user.async_create_user("test-user")
|
||||
|
||||
# change the trusted users config
|
||||
config = provider_with_user.config['trusted_users']
|
||||
assert ip_network('192.168.0.1') in config
|
||||
config[ip_network('192.168.0.1')] = [owner.id]
|
||||
assert ip_network('192.168.128.0/24') in config
|
||||
config[ip_network('192.168.128.0/24')] = [sys_user.id, user.id]
|
||||
|
||||
# not from trusted network
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('127.0.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['type'] == 'abort'
|
||||
assert step['reason'] == 'not_whitelisted'
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# only owner listed
|
||||
assert schema({'user': owner.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': user.id})
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.128.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# only user listed
|
||||
assert schema({'user': user.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': owner.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': sys_user.id})
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('::1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# both owner and user listed
|
||||
assert schema({'user': owner.id})
|
||||
assert schema({'user': user.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': sys_user.id})
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('fd00::1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# no user listed
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': owner.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': user.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': sys_user.id})
|
||||
|
||||
|
||||
async def test_trusted_group_login(manager_with_user, provider_with_user):
|
||||
"""Test config trusted_user with group_id."""
|
||||
owner = await manager_with_user.async_create_user("test-owner")
|
||||
# create a user in user group
|
||||
user = await manager_with_user.async_create_user("test-user")
|
||||
await manager_with_user.async_update_user(
|
||||
user, group_ids=[auth.const.GROUP_ID_USER])
|
||||
|
||||
# change the trusted users config
|
||||
config = provider_with_user.config['trusted_users']
|
||||
assert ip_network('192.168.0.1') in config
|
||||
config[ip_network('192.168.0.1')] = [{'group': [auth.const.GROUP_ID_USER]}]
|
||||
assert ip_network('192.168.128.0/24') in config
|
||||
config[ip_network('192.168.128.0/24')] = [
|
||||
owner.id, {'group': [auth.const.GROUP_ID_USER]}]
|
||||
|
||||
# not from trusted network
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('127.0.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['type'] == 'abort'
|
||||
assert step['reason'] == 'not_whitelisted'
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# only user listed
|
||||
print(user.id)
|
||||
assert schema({'user': user.id})
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert schema({'user': owner.id})
|
||||
|
||||
# from trusted network, list users intersect trusted_users
|
||||
flow = await provider_with_user.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.128.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['step_id'] == 'init'
|
||||
|
||||
schema = step['data_schema']
|
||||
# both owner and user listed
|
||||
assert schema({'user': owner.id})
|
||||
assert schema({'user': user.id})
|
||||
|
||||
|
||||
async def test_bypass_login_flow(manager_bypass_login, provider_bypass_login):
|
||||
"""Test login flow can be bypass if only one user available."""
|
||||
owner = await manager_bypass_login.async_create_user("test-owner")
|
||||
|
||||
# not from trusted network
|
||||
flow = await provider_bypass_login.async_login_flow(
|
||||
{'ip_address': ip_address('127.0.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['type'] == 'abort'
|
||||
assert step['reason'] == 'not_whitelisted'
|
||||
|
||||
# from trusted network, only one available user, bypass the login flow
|
||||
flow = await provider_bypass_login.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
assert step['type'] == 'create_entry'
|
||||
assert step['data']['user'] == owner.id
|
||||
|
||||
user = await manager_bypass_login.async_create_user("test-user")
|
||||
|
||||
# from trusted network, two available user, show up login form
|
||||
flow = await provider_bypass_login.async_login_flow(
|
||||
{'ip_address': ip_address('192.168.0.1')})
|
||||
step = await flow.async_step_init()
|
||||
schema = step['data_schema']
|
||||
# both owner and user listed
|
||||
assert schema({'user': owner.id})
|
||||
assert schema({'user': user.id})
|
||||
|
@ -4,6 +4,7 @@ import enum
|
||||
import os
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT
|
||||
from unittest.mock import Mock, patch
|
||||
import uuid
|
||||
|
||||
import homeassistant
|
||||
import pytest
|
||||
@ -963,3 +964,24 @@ def test_entity_id_allow_old_validation(caplog):
|
||||
assert "Found invalid entity_id {}".format(value) in caplog.text
|
||||
|
||||
assert len(cv.INVALID_ENTITY_IDS_FOUND) == 2
|
||||
|
||||
|
||||
def test_uuid4_hex(caplog):
|
||||
"""Test uuid validation."""
|
||||
schema = vol.Schema(cv.uuid4_hex)
|
||||
|
||||
for value in ['Not a hex string', '0', 0]:
|
||||
with pytest.raises(vol.Invalid):
|
||||
schema(value)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
# the 13th char should be 4
|
||||
schema('a03d31b22eee1acc9b90eec40be6ed23')
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
# the 17th char should be 8-a
|
||||
schema('a03d31b22eee4acc7b90eec40be6ed23')
|
||||
|
||||
hex = uuid.uuid4().hex
|
||||
assert schema(hex) == hex
|
||||
assert schema(hex.upper()) == hex
|
||||
|
Loading…
x
Reference in New Issue
Block a user