Add trusted_users in trusted networks auth provider (#22478)

This commit is contained in:
Jason Hu 2019-03-27 21:53:11 -07:00 committed by Paulus Schoutsen
parent 26726af689
commit 6ba2891604
5 changed files with 318 additions and 12 deletions

View File

@ -18,8 +18,26 @@ from ..models import Credentials, UserMeta
IPAddress = Union[IPv4Address, IPv6Address] IPAddress = Union[IPv4Address, IPv6Address]
IPNetwork = Union[IPv4Network, IPv6Network] IPNetwork = Union[IPv4Network, IPv6Network]
CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_TRUSTED_USERS = 'trusted_users'
CONF_GROUP = 'group'
CONF_ALLOW_BYPASS_LOGIN = 'allow_bypass_login'
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
vol.Required('trusted_networks'): vol.All(cv.ensure_list, [ip_network]) vol.Required(CONF_TRUSTED_NETWORKS): vol.All(
cv.ensure_list, [ip_network]
),
vol.Optional(CONF_TRUSTED_USERS, default={}): vol.Schema(
# we only validate the format of user_id or group_id
{ip_network: vol.All(
cv.ensure_list,
[vol.Or(
cv.uuid4_hex,
vol.Schema({vol.Required(CONF_GROUP): cv.uuid4_hex}),
)],
)}
),
vol.Optional(CONF_ALLOW_BYPASS_LOGIN, default=False): cv.boolean,
}, extra=vol.PREVENT_EXTRA) }, extra=vol.PREVENT_EXTRA)
@ -43,7 +61,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
@property @property
def trusted_networks(self) -> List[IPNetwork]: def trusted_networks(self) -> List[IPNetwork]:
"""Return trusted networks.""" """Return trusted networks."""
return cast(List[IPNetwork], self.config['trusted_networks']) return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
@property
def trusted_users(self) -> Dict[IPNetwork, Any]:
"""Return trusted users per network."""
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
@property @property
def support_mfa(self) -> bool: def support_mfa(self) -> bool:
@ -53,13 +76,34 @@ class TrustedNetworksAuthProvider(AuthProvider):
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
ip_addr = cast(IPAddress, context.get('ip_address'))
users = await self.store.async_get_users() users = await self.store.async_get_users()
available_users = {user.id: user.name available_users = [user for user in users
for user in users if not user.system_generated and user.is_active]
if not user.system_generated and user.is_active} for ip_net, user_or_group_list in self.trusted_users.items():
if ip_addr in ip_net:
user_list = [user_id for user_id in user_or_group_list
if isinstance(user_id, str)]
group_list = [group[CONF_GROUP] for group in user_or_group_list
if isinstance(group, dict)]
flattened_group_list = [group for sublist in group_list
for group in sublist]
available_users = [
user for user in available_users
if (user.id in user_list or
any([group.id in flattened_group_list
for group in user.groups]))
]
break
return TrustedNetworksLoginFlow( return TrustedNetworksLoginFlow(
self, cast(IPAddress, context.get('ip_address')), available_users) self,
ip_addr,
{
user.id: user.name for user in available_users
},
self.config[CONF_ALLOW_BYPASS_LOGIN],
)
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]) -> Credentials:
@ -109,11 +153,13 @@ class TrustedNetworksLoginFlow(LoginFlow):
def __init__(self, auth_provider: TrustedNetworksAuthProvider, def __init__(self, auth_provider: TrustedNetworksAuthProvider,
ip_addr: IPAddress, ip_addr: IPAddress,
available_users: Dict[str, Optional[str]]) -> None: available_users: Dict[str, Optional[str]],
allow_bypass_login: bool) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
super().__init__(auth_provider) super().__init__(auth_provider)
self._available_users = available_users self._available_users = available_users
self._ip_address = ip_addr self._ip_address = ip_addr
self._allow_bypass_login = allow_bypass_login
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None) \
@ -131,6 +177,11 @@ class TrustedNetworksLoginFlow(LoginFlow):
if user_input is not None: if user_input is not None:
return await self.async_finish(user_input) return await self.async_finish(user_input)
if self._allow_bypass_login and len(self._available_users) == 1:
return await self.async_finish({
'user': next(iter(self._available_users.keys()))
})
return self.async_show_form( return self.async_show_form(
step_id='init', step_id='init',
data_schema=vol.Schema({'user': vol.In(self._available_users)}), data_schema=vol.Schema({'user': vol.In(self._available_users)}),

View File

@ -81,7 +81,8 @@ from . import indieauth
async def async_setup(hass, store_result): async def async_setup(hass, store_result):
"""Component to allow users to login.""" """Component to allow users to login."""
hass.http.register_view(AuthProvidersView) hass.http.register_view(AuthProvidersView)
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow)) hass.http.register_view(
LoginFlowIndexView(hass.auth.login_flow, store_result))
hass.http.register_view( hass.http.register_view(
LoginFlowResourceView(hass.auth.login_flow, store_result)) LoginFlowResourceView(hass.auth.login_flow, store_result))
@ -142,9 +143,10 @@ class LoginFlowIndexView(HomeAssistantView):
name = 'api:auth:login_flow' name = 'api:auth:login_flow'
requires_auth = False requires_auth = False
def __init__(self, flow_mgr): def __init__(self, flow_mgr, store_result):
"""Initialize the flow manager index view.""" """Initialize the flow manager index view."""
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr
self._store_result = store_result
async def get(self, request): async def get(self, request):
"""Do not allow index of flows in progress.""" """Do not allow index of flows in progress."""
@ -179,6 +181,12 @@ class LoginFlowIndexView(HomeAssistantView):
except data_entry_flow.UnknownStep: except data_entry_flow.UnknownStep:
return self.json_message('Handler does not support init', 400) return self.json_message('Handler does not support init', 400)
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result.pop('data')
result['result'] = self._store_result(
data['client_id'], result['result'])
return self.json(result)
return self.json(_prepare_result_json(result)) return self.json(_prepare_result_json(result))

View File

@ -8,6 +8,7 @@ from datetime import (timedelta, datetime as datetime_sys,
from socket import _GLOBAL_DEFAULT_TIMEOUT from socket import _GLOBAL_DEFAULT_TIMEOUT
from typing import Any, Union, TypeVar, Callable, Sequence, Dict, Optional from typing import Any, Union, TypeVar, Callable, Sequence, Dict, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID
import voluptuous as vol import voluptuous as vol
from pkg_resources import parse_version from pkg_resources import parse_version
@ -532,6 +533,20 @@ def x10_address(value):
return str(value).lower() return str(value).lower()
def uuid4_hex(value):
"""Validate a v4 UUID in hex format."""
try:
result = UUID(value, version=4)
except (ValueError, AttributeError, TypeError) as error:
raise vol.Invalid('Invalid Version4 UUID', error_message=str(error))
if result.hex != value.lower():
# UUID() will create a uuid4 if input is invalid
raise vol.Invalid('Invalid Version4 UUID')
return result.hex
def ensure_list_csv(value: Any) -> Sequence: def ensure_list_csv(value: Any) -> Sequence:
"""Ensure that input is a list or make one from comma-separated string.""" """Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str): if isinstance(value, str):

View File

@ -1,5 +1,5 @@
"""Test the Trusted Networks auth provider.""" """Test the Trusted Networks auth provider."""
from ipaddress import ip_address from ipaddress import ip_address, ip_network
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -25,8 +25,47 @@ def provider(hass, store):
'192.168.0.1', '192.168.0.1',
'192.168.128.0/24', '192.168.128.0/24',
'::1', '::1',
'fd00::/8' 'fd00::/8',
] ],
})
)
@pytest.fixture
def provider_with_user(hass, store):
"""Mock provider with trusted users config."""
return tn_auth.TrustedNetworksAuthProvider(
hass, store, tn_auth.CONFIG_SCHEMA({
'type': 'trusted_networks',
'trusted_networks': [
'192.168.0.1',
'192.168.128.0/24',
'::1',
'fd00::/8',
],
# user_id will be injected in test
'trusted_users': {
'192.168.0.1': [],
'192.168.128.0/24': [],
'fd00::/8': [],
},
})
)
@pytest.fixture
def provider_bypass_login(hass, store):
"""Mock provider with allow_bypass_login config."""
return tn_auth.TrustedNetworksAuthProvider(
hass, store, tn_auth.CONFIG_SCHEMA({
'type': 'trusted_networks',
'trusted_networks': [
'192.168.0.1',
'192.168.128.0/24',
'::1',
'fd00::/8',
],
'allow_bypass_login': True,
}) })
) )
@ -39,6 +78,23 @@ def manager(hass, store, provider):
}, {}) }, {})
@pytest.fixture
def manager_with_user(hass, store, provider_with_user):
"""Mock manager with trusted user."""
return auth.AuthManager(hass, store, {
(provider_with_user.type, provider_with_user.id): provider_with_user
}, {})
@pytest.fixture
def manager_bypass_login(hass, store, provider_bypass_login):
"""Mock manager with allow bypass login."""
return auth.AuthManager(hass, store, {
(provider_bypass_login.type, provider_bypass_login.id):
provider_bypass_login
}, {})
async def test_trusted_networks_credentials(manager, provider): async def test_trusted_networks_credentials(manager, provider):
"""Test trusted_networks credentials related functions.""" """Test trusted_networks credentials related functions."""
owner = await manager.async_create_user("test-owner") owner = await manager.async_create_user("test-owner")
@ -104,3 +160,157 @@ async def test_login_flow(manager, provider):
step = await flow.async_step_init({'user': user.id}) step = await flow.async_step_init({'user': user.id})
assert step['type'] == 'create_entry' assert step['type'] == 'create_entry'
assert step['data']['user'] == user.id assert step['data']['user'] == user.id
async def test_trusted_users_login(manager_with_user, provider_with_user):
"""Test available user list changed per different IP."""
owner = await manager_with_user.async_create_user("test-owner")
sys_user = await manager_with_user.async_create_system_user(
"test-sys-user") # system user will not be available to select
user = await manager_with_user.async_create_user("test-user")
# change the trusted users config
config = provider_with_user.config['trusted_users']
assert ip_network('192.168.0.1') in config
config[ip_network('192.168.0.1')] = [owner.id]
assert ip_network('192.168.128.0/24') in config
config[ip_network('192.168.128.0/24')] = [sys_user.id, user.id]
# not from trusted network
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('127.0.0.1')})
step = await flow.async_step_init()
assert step['type'] == 'abort'
assert step['reason'] == 'not_whitelisted'
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('192.168.0.1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# only owner listed
assert schema({'user': owner.id})
with pytest.raises(vol.Invalid):
assert schema({'user': user.id})
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('192.168.128.1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# only user listed
assert schema({'user': user.id})
with pytest.raises(vol.Invalid):
assert schema({'user': owner.id})
with pytest.raises(vol.Invalid):
assert schema({'user': sys_user.id})
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('::1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# both owner and user listed
assert schema({'user': owner.id})
assert schema({'user': user.id})
with pytest.raises(vol.Invalid):
assert schema({'user': sys_user.id})
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('fd00::1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# no user listed
with pytest.raises(vol.Invalid):
assert schema({'user': owner.id})
with pytest.raises(vol.Invalid):
assert schema({'user': user.id})
with pytest.raises(vol.Invalid):
assert schema({'user': sys_user.id})
async def test_trusted_group_login(manager_with_user, provider_with_user):
"""Test config trusted_user with group_id."""
owner = await manager_with_user.async_create_user("test-owner")
# create a user in user group
user = await manager_with_user.async_create_user("test-user")
await manager_with_user.async_update_user(
user, group_ids=[auth.const.GROUP_ID_USER])
# change the trusted users config
config = provider_with_user.config['trusted_users']
assert ip_network('192.168.0.1') in config
config[ip_network('192.168.0.1')] = [{'group': [auth.const.GROUP_ID_USER]}]
assert ip_network('192.168.128.0/24') in config
config[ip_network('192.168.128.0/24')] = [
owner.id, {'group': [auth.const.GROUP_ID_USER]}]
# not from trusted network
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('127.0.0.1')})
step = await flow.async_step_init()
assert step['type'] == 'abort'
assert step['reason'] == 'not_whitelisted'
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('192.168.0.1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# only user listed
print(user.id)
assert schema({'user': user.id})
with pytest.raises(vol.Invalid):
assert schema({'user': owner.id})
# from trusted network, list users intersect trusted_users
flow = await provider_with_user.async_login_flow(
{'ip_address': ip_address('192.168.128.1')})
step = await flow.async_step_init()
assert step['step_id'] == 'init'
schema = step['data_schema']
# both owner and user listed
assert schema({'user': owner.id})
assert schema({'user': user.id})
async def test_bypass_login_flow(manager_bypass_login, provider_bypass_login):
"""Test login flow can be bypass if only one user available."""
owner = await manager_bypass_login.async_create_user("test-owner")
# not from trusted network
flow = await provider_bypass_login.async_login_flow(
{'ip_address': ip_address('127.0.0.1')})
step = await flow.async_step_init()
assert step['type'] == 'abort'
assert step['reason'] == 'not_whitelisted'
# from trusted network, only one available user, bypass the login flow
flow = await provider_bypass_login.async_login_flow(
{'ip_address': ip_address('192.168.0.1')})
step = await flow.async_step_init()
assert step['type'] == 'create_entry'
assert step['data']['user'] == owner.id
user = await manager_bypass_login.async_create_user("test-user")
# from trusted network, two available user, show up login form
flow = await provider_bypass_login.async_login_flow(
{'ip_address': ip_address('192.168.0.1')})
step = await flow.async_step_init()
schema = step['data_schema']
# both owner and user listed
assert schema({'user': owner.id})
assert schema({'user': user.id})

View File

@ -4,6 +4,7 @@ import enum
import os import os
from socket import _GLOBAL_DEFAULT_TIMEOUT from socket import _GLOBAL_DEFAULT_TIMEOUT
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import uuid
import homeassistant import homeassistant
import pytest import pytest
@ -963,3 +964,24 @@ def test_entity_id_allow_old_validation(caplog):
assert "Found invalid entity_id {}".format(value) in caplog.text assert "Found invalid entity_id {}".format(value) in caplog.text
assert len(cv.INVALID_ENTITY_IDS_FOUND) == 2 assert len(cv.INVALID_ENTITY_IDS_FOUND) == 2
def test_uuid4_hex(caplog):
"""Test uuid validation."""
schema = vol.Schema(cv.uuid4_hex)
for value in ['Not a hex string', '0', 0]:
with pytest.raises(vol.Invalid):
schema(value)
with pytest.raises(vol.Invalid):
# the 13th char should be 4
schema('a03d31b22eee1acc9b90eec40be6ed23')
with pytest.raises(vol.Invalid):
# the 17th char should be 8-a
schema('a03d31b22eee4acc7b90eec40be6ed23')
hex = uuid.uuid4().hex
assert schema(hex) == hex
assert schema(hex.upper()) == hex