Allows the supervisor to send a session's user to addon with header X-Remote-User (#4152)

* Working draft for x-remote-user

* Renames prop to remote_user

* Allows to set in addon description whether it requests the username

* Fixes addon-options schema

* Sends user ID instead of username to addons

* Adds tests

* Removes configurability of remote-user forwarding

* Update const.py

* Also adds username header

* Fetches full user info object from homeassistant

* Cleaner validation and dataclasses

* Fixes linting

* Fixes linting

* Tries to fix test

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Resolves PR comments

* Linting

* Fixes tests

* Update const.py

* Removes header keys if not required

* Moves ignoring user ID headers if no session_data is given

* simplify

* fix lint with new job

---------

Co-authored-by: Pascal Vizeli <pvizeli@syshack.ch>
Co-authored-by: Pascal Vizeli <pascal.vizeli@syshack.ch>
This commit is contained in:
Florian Bachmann 2023-08-22 10:11:13 +02:00 committed by GitHub
parent 204fcdf479
commit acc0e5c989
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 224 additions and 12 deletions

View File

@ -21,11 +21,18 @@ from ..const import (
ATTR_ICON, ATTR_ICON,
ATTR_PANELS, ATTR_PANELS,
ATTR_SESSION, ATTR_SESSION,
ATTR_SESSION_DATA_USER_ID,
ATTR_TITLE, ATTR_TITLE,
HEADER_REMOTE_USER_DISPLAY_NAME,
HEADER_REMOTE_USER_ID,
HEADER_REMOTE_USER_NAME,
HEADER_TOKEN, HEADER_TOKEN,
HEADER_TOKEN_OLD, HEADER_TOKEN_OLD,
IngressSessionData,
IngressSessionDataUser,
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import HomeAssistantAPIError
from .const import COOKIE_INGRESS from .const import COOKIE_INGRESS
from .utils import api_process, api_validate, require_home_assistant from .utils import api_process, api_validate, require_home_assistant
@ -33,10 +40,23 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
VALIDATE_SESSION_DATA = vol.Schema({ATTR_SESSION: str}) VALIDATE_SESSION_DATA = vol.Schema({ATTR_SESSION: str})
"""Expected optional payload of create session request"""
SCHEMA_INGRESS_CREATE_SESSION_DATA = vol.Schema(
{
vol.Optional(ATTR_SESSION_DATA_USER_ID): str,
}
)
class APIIngress(CoreSysAttributes): class APIIngress(CoreSysAttributes):
"""Ingress view to handle add-on webui routing.""" """Ingress view to handle add-on webui routing."""
_list_of_users: list[IngressSessionDataUser]
def __init__(self) -> None:
"""Initialize APIIngress."""
self._list_of_users = []
def _extract_addon(self, request: web.Request) -> Addon: def _extract_addon(self, request: web.Request) -> Addon:
"""Return addon, throw an exception it it doesn't exist.""" """Return addon, throw an exception it it doesn't exist."""
token = request.match_info.get("token") token = request.match_info.get("token")
@ -71,7 +91,19 @@ class APIIngress(CoreSysAttributes):
@require_home_assistant @require_home_assistant
async def create_session(self, request: web.Request) -> dict[str, Any]: async def create_session(self, request: web.Request) -> dict[str, Any]:
"""Create a new session.""" """Create a new session."""
session = self.sys_ingress.create_session() schema_ingress_config_session_data = await api_validate(
SCHEMA_INGRESS_CREATE_SESSION_DATA, request
)
data: IngressSessionData | None = None
if ATTR_SESSION_DATA_USER_ID in schema_ingress_config_session_data:
user = await self._find_user_by_id(
schema_ingress_config_session_data[ATTR_SESSION_DATA_USER_ID]
)
if user:
data = IngressSessionData(user)
session = self.sys_ingress.create_session(data)
return {ATTR_SESSION: session} return {ATTR_SESSION: session}
@api_process @api_process
@ -99,13 +131,14 @@ class APIIngress(CoreSysAttributes):
# Process requests # Process requests
addon = self._extract_addon(request) addon = self._extract_addon(request)
path = request.match_info.get("path") path = request.match_info.get("path")
session_data = self.sys_ingress.get_session_data(session)
try: try:
# Websocket # Websocket
if _is_websocket(request): if _is_websocket(request):
return await self._handle_websocket(request, addon, path) return await self._handle_websocket(request, addon, path, session_data)
# Request # Request
return await self._handle_request(request, addon, path) return await self._handle_request(request, addon, path, session_data)
except aiohttp.ClientError as err: except aiohttp.ClientError as err:
_LOGGER.error("Ingress error: %s", err) _LOGGER.error("Ingress error: %s", err)
@ -113,7 +146,11 @@ class APIIngress(CoreSysAttributes):
raise HTTPBadGateway() raise HTTPBadGateway()
async def _handle_websocket( async def _handle_websocket(
self, request: web.Request, addon: Addon, path: str self,
request: web.Request,
addon: Addon,
path: str,
session_data: IngressSessionData | None,
) -> web.WebSocketResponse: ) -> web.WebSocketResponse:
"""Ingress route for websocket.""" """Ingress route for websocket."""
if hdrs.SEC_WEBSOCKET_PROTOCOL in request.headers: if hdrs.SEC_WEBSOCKET_PROTOCOL in request.headers:
@ -131,7 +168,7 @@ class APIIngress(CoreSysAttributes):
# Preparing # Preparing
url = self._create_url(addon, path) url = self._create_url(addon, path)
source_header = _init_header(request, addon) source_header = _init_header(request, addon, session_data)
# Support GET query # Support GET query
if request.query_string: if request.query_string:
@ -157,11 +194,15 @@ class APIIngress(CoreSysAttributes):
return ws_server return ws_server
async def _handle_request( async def _handle_request(
self, request: web.Request, addon: Addon, path: str self,
request: web.Request,
addon: Addon,
path: str,
session_data: IngressSessionData | None,
) -> web.Response | web.StreamResponse: ) -> web.Response | web.StreamResponse:
"""Ingress route for request.""" """Ingress route for request."""
url = self._create_url(addon, path) url = self._create_url(addon, path)
source_header = _init_header(request, addon) source_header = _init_header(request, addon, session_data)
# Passing the raw stream breaks requests for some webservers # Passing the raw stream breaks requests for some webservers
# since we just need it for POST requests really, for all other methods # since we just need it for POST requests really, for all other methods
@ -217,11 +258,33 @@ class APIIngress(CoreSysAttributes):
return response return response
async def _find_user_by_id(self, user_id: str) -> IngressSessionDataUser | None:
"""Find user object by the user's ID."""
try:
list_of_users = await self.sys_homeassistant.get_users()
except (HomeAssistantAPIError, TypeError) as err:
_LOGGER.error(
"%s error occurred while requesting list of users: %s", type(err), err
)
return None
def _init_header(request: web.Request, addon: str) -> CIMultiDict | dict[str, str]: if list_of_users is not None:
self._list_of_users = list_of_users
return next((user for user in self._list_of_users if user.id == user_id), None)
def _init_header(
request: web.Request, addon: Addon, session_data: IngressSessionData | None
) -> CIMultiDict | dict[str, str]:
"""Create initial header.""" """Create initial header."""
headers = {} headers = {}
if session_data is not None:
headers[HEADER_REMOTE_USER_ID] = session_data.user.id
headers[HEADER_REMOTE_USER_NAME] = session_data.user.username
headers[HEADER_REMOTE_USER_DISPLAY_NAME] = session_data.user.display_name
# filter flags # filter flags
for name, value in request.headers.items(): for name, value in request.headers.items():
if name in ( if name in (
@ -234,6 +297,9 @@ def _init_header(request: web.Request, addon: str) -> CIMultiDict | dict[str, st
hdrs.SEC_WEBSOCKET_KEY, hdrs.SEC_WEBSOCKET_KEY,
istr(HEADER_TOKEN), istr(HEADER_TOKEN),
istr(HEADER_TOKEN_OLD), istr(HEADER_TOKEN_OLD),
istr(HEADER_REMOTE_USER_ID),
istr(HEADER_REMOTE_USER_NAME),
istr(HEADER_REMOTE_USER_DISPLAY_NAME),
): ):
continue continue
headers[name] = value headers[name] = value

View File

@ -1,4 +1,5 @@
"""Constants file for Supervisor.""" """Constants file for Supervisor."""
from dataclasses import dataclass
from enum import Enum from enum import Enum
from ipaddress import ip_network from ipaddress import ip_network
from pathlib import Path from pathlib import Path
@ -69,6 +70,9 @@ JSON_RESULT = "result"
RESULT_ERROR = "error" RESULT_ERROR = "error"
RESULT_OK = "ok" RESULT_OK = "ok"
HEADER_REMOTE_USER_ID = "X-Remote-User-Id"
HEADER_REMOTE_USER_NAME = "X-Remote-User-Name"
HEADER_REMOTE_USER_DISPLAY_NAME = "X-Remote-User-Display-Name"
HEADER_TOKEN_OLD = "X-Hassio-Key" HEADER_TOKEN_OLD = "X-Hassio-Key"
HEADER_TOKEN = "X-Supervisor-Token" HEADER_TOKEN = "X-Supervisor-Token"
@ -271,6 +275,9 @@ ATTR_SERVERS = "servers"
ATTR_SERVICE = "service" ATTR_SERVICE = "service"
ATTR_SERVICES = "services" ATTR_SERVICES = "services"
ATTR_SESSION = "session" ATTR_SESSION = "session"
ATTR_SESSION_DATA = "session_data"
ATTR_SESSION_DATA_USER = "user"
ATTR_SESSION_DATA_USER_ID = "user_id"
ATTR_SIGNAL = "signal" ATTR_SIGNAL = "signal"
ATTR_SIZE = "size" ATTR_SIZE = "size"
ATTR_SLUG = "slug" ATTR_SLUG = "slug"
@ -464,6 +471,22 @@ class CpuArch(str, Enum):
AMD64 = "amd64" AMD64 = "amd64"
@dataclass
class IngressSessionDataUser:
"""Format of an IngressSessionDataUser object."""
id: str
display_name: str
username: str
@dataclass
class IngressSessionData:
"""Format of an IngressSessionData object."""
user: IngressSessionDataUser
STARTING_STATES = [ STARTING_STATES = [
CoreState.INITIALIZE, CoreState.INITIALIZE,
CoreState.STARTUP, CoreState.STARTUP,

View File

@ -1,5 +1,6 @@
"""Home Assistant control object.""" """Home Assistant control object."""
import asyncio import asyncio
from datetime import timedelta
from ipaddress import IPv4Address from ipaddress import IPv4Address
import logging import logging
from pathlib import Path, PurePath from pathlib import Path, PurePath
@ -28,6 +29,7 @@ from ..const import (
ATTR_WATCHDOG, ATTR_WATCHDOG,
FILE_HASSIO_HOMEASSISTANT, FILE_HASSIO_HOMEASSISTANT,
BusEvent, BusEvent,
IngressSessionDataUser,
) )
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import ( from ..exceptions import (
@ -38,7 +40,7 @@ from ..exceptions import (
) )
from ..hardware.const import PolicyGroup from ..hardware.const import PolicyGroup
from ..hardware.data import Device from ..hardware.data import Device
from ..jobs.decorator import Job from ..jobs.decorator import Job, JobExecutionLimit
from ..utils import remove_folder from ..utils import remove_folder
from ..utils.common import FileConfiguration from ..utils.common import FileConfiguration
from ..utils.json import read_json_file, write_json_file from ..utils.json import read_json_file, write_json_file
@ -432,3 +434,21 @@ class HomeAssistant(FileConfiguration, CoreSysAttributes):
ATTR_WATCHDOG, ATTR_WATCHDOG,
): ):
self._data[attr] = data[attr] self._data[attr] = data[attr]
@Job(
name="home_assistant_get_users",
limit=JobExecutionLimit.THROTTLE_WAIT,
throttle_period=timedelta(minutes=5),
)
async def get_users(self) -> list[IngressSessionDataUser]:
"""Get list of all configured users."""
list_of_users = await self.sys_homeassistant.websocket.async_send_command(
{ATTR_TYPE: "config/auth/list"}
)
return [
IngressSessionDataUser(
id=data["id"], username=data["username"], display_name=data["name"]
)
for data in list_of_users
]

View File

@ -5,7 +5,13 @@ import random
import secrets import secrets
from .addons.addon import Addon from .addons.addon import Addon
from .const import ATTR_PORTS, ATTR_SESSION, FILE_HASSIO_INGRESS from .const import (
ATTR_PORTS,
ATTR_SESSION,
ATTR_SESSION_DATA,
FILE_HASSIO_INGRESS,
IngressSessionData,
)
from .coresys import CoreSys, CoreSysAttributes from .coresys import CoreSys, CoreSysAttributes
from .utils import check_port from .utils import check_port
from .utils.common import FileConfiguration from .utils.common import FileConfiguration
@ -30,11 +36,20 @@ class Ingress(FileConfiguration, CoreSysAttributes):
return None return None
return self.sys_addons.get(self.tokens[token], local_only=True) return self.sys_addons.get(self.tokens[token], local_only=True)
def get_session_data(self, session_id: str) -> IngressSessionData | None:
"""Return complementary data of current session or None."""
return self.sessions_data.get(session_id)
@property @property
def sessions(self) -> dict[str, float]: def sessions(self) -> dict[str, float]:
"""Return sessions.""" """Return sessions."""
return self._data[ATTR_SESSION] return self._data[ATTR_SESSION]
@property
def sessions_data(self) -> dict[str, IngressSessionData]:
"""Return sessions_data."""
return self._data[ATTR_SESSION_DATA]
@property @property
def ports(self) -> dict[str, int]: def ports(self) -> dict[str, int]:
"""Return list of dynamic ports.""" """Return list of dynamic ports."""
@ -71,6 +86,7 @@ class Ingress(FileConfiguration, CoreSysAttributes):
now = utcnow() now = utcnow()
sessions = {} sessions = {}
sessions_data: dict[str, IngressSessionData] = {}
for session, valid in self.sessions.items(): for session, valid in self.sessions.items():
# check if timestamp valid, to avoid crash on malformed timestamp # check if timestamp valid, to avoid crash on malformed timestamp
try: try:
@ -84,10 +100,13 @@ class Ingress(FileConfiguration, CoreSysAttributes):
# Is valid # Is valid
sessions[session] = valid sessions[session] = valid
sessions_data[session] = self.get_session_data(session)
# Write back # Write back
self.sessions.clear() self.sessions.clear()
self.sessions.update(sessions) self.sessions.update(sessions)
self.sessions_data.clear()
self.sessions_data.update(sessions_data)
def _update_token_list(self) -> None: def _update_token_list(self) -> None:
"""Regenerate token <-> Add-on map.""" """Regenerate token <-> Add-on map."""
@ -97,12 +116,15 @@ class Ingress(FileConfiguration, CoreSysAttributes):
for addon in self.addons: for addon in self.addons:
self.tokens[addon.ingress_token] = addon.slug self.tokens[addon.ingress_token] = addon.slug
def create_session(self) -> str: def create_session(self, data: IngressSessionData | None = None) -> str:
"""Create new session.""" """Create new session."""
session = secrets.token_hex(64) session = secrets.token_hex(64)
valid = utcnow() + timedelta(minutes=15) valid = utcnow() + timedelta(minutes=15)
self.sessions[session] = valid.timestamp() self.sessions[session] = valid.timestamp()
if data is not None:
self.sessions_data[session] = data
return session return session
def validate_session(self, session: str) -> bool: def validate_session(self, session: str) -> bool:

View File

@ -30,6 +30,8 @@ from .const import (
ATTR_PWNED, ATTR_PWNED,
ATTR_REGISTRIES, ATTR_REGISTRIES,
ATTR_SESSION, ATTR_SESSION,
ATTR_SESSION_DATA,
ATTR_SESSION_DATA_USER,
ATTR_SUPERVISOR, ATTR_SUPERVISOR,
ATTR_TIMEZONE, ATTR_TIMEZONE,
ATTR_USERNAME, ATTR_USERNAME,
@ -178,18 +180,33 @@ SCHEMA_DOCKER_CONFIG = vol.Schema(
SCHEMA_AUTH_CONFIG = vol.Schema({sha256: sha256}) SCHEMA_AUTH_CONFIG = vol.Schema({sha256: sha256})
SCHEMA_SESSION_DATA = vol.Schema(
{
token: vol.Schema(
{
vol.Required(ATTR_SESSION_DATA_USER): vol.Schema(
{
vol.Required("id"): str,
vol.Required("username"): str,
vol.Required("displayname"): str,
}
)
}
)
}
)
SCHEMA_INGRESS_CONFIG = vol.Schema( SCHEMA_INGRESS_CONFIG = vol.Schema(
{ {
vol.Required(ATTR_SESSION, default=dict): vol.Schema( vol.Required(ATTR_SESSION, default=dict): vol.Schema(
{token: vol.Coerce(float)} {token: vol.Coerce(float)}
), ),
vol.Required(ATTR_SESSION_DATA, default=dict): SCHEMA_SESSION_DATA,
vol.Required(ATTR_PORTS, default=dict): vol.Schema({str: network_port}), vol.Required(ATTR_PORTS, default=dict): vol.Schema({str: network_port}),
}, },
extra=vol.REMOVE_EXTRA, extra=vol.REMOVE_EXTRA,
) )
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
SCHEMA_SECURITY_CONFIG = vol.Schema( SCHEMA_SECURITY_CONFIG = vol.Schema(
{ {

View File

@ -1,4 +1,5 @@
"""Test ingress API.""" """Test ingress API."""
# pylint: disable=protected-access
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -37,3 +38,50 @@ async def test_validate_session(api_client, coresys):
assert await resp.json() == {"result": "ok", "data": {}} assert await resp.json() == {"result": "ok", "data": {}}
assert coresys.ingress.sessions[session] > valid_time assert coresys.ingress.sessions[session] > valid_time
@pytest.mark.asyncio
async def test_validate_session_with_user_id(api_client, coresys):
"""Test validating ingress session with user ID passed."""
with patch("aiohttp.web_request.BaseRequest.__getitem__", return_value=None):
resp = await api_client.post(
"/ingress/validate_session",
json={"session": "non-existing"},
)
assert resp.status == 401
with patch(
"aiohttp.web_request.BaseRequest.__getitem__",
return_value=coresys.homeassistant,
):
client = coresys.homeassistant.websocket._client
client.async_send_command.return_value = [
{"id": "some-id", "name": "Some Name", "username": "sn"}
]
resp = await api_client.post("/ingress/session", json={"user_id": "some-id"})
result = await resp.json()
client.async_send_command.assert_called_with({"type": "config/auth/list"})
assert "session" in result["data"]
session = result["data"]["session"]
assert session in coresys.ingress.sessions
valid_time = coresys.ingress.sessions[session]
resp = await api_client.post(
"/ingress/validate_session",
json={"session": session},
)
assert resp.status == 200
assert await resp.json() == {"result": "ok", "data": {}}
assert coresys.ingress.sessions[session] > valid_time
assert session in coresys.ingress.sessions_data
assert coresys.ingress.get_session_data(session).user.id == "some-id"
assert coresys.ingress.get_session_data(session).user.username == "sn"
assert (
coresys.ingress.get_session_data(session).user.display_name == "Some Name"
)

View File

@ -1,6 +1,7 @@
"""Test ingress.""" """Test ingress."""
from datetime import timedelta from datetime import timedelta
from supervisor.const import ATTR_SESSION_DATA_USER_ID
from supervisor.utils.dt import utc_from_timestamp from supervisor.utils.dt import utc_from_timestamp
@ -20,6 +21,21 @@ def test_session_handling(coresys):
assert not coresys.ingress.validate_session(session) assert not coresys.ingress.validate_session(session)
assert not coresys.ingress.validate_session("invalid session") assert not coresys.ingress.validate_session("invalid session")
session_data = coresys.ingress.get_session_data(session)
assert session_data is None
def test_session_handling_with_session_data(coresys):
"""Create and test session."""
session = coresys.ingress.create_session(
dict([(ATTR_SESSION_DATA_USER_ID, "some-id")])
)
assert session
session_data = coresys.ingress.get_session_data(session)
assert session_data[ATTR_SESSION_DATA_USER_ID] == "some-id"
async def test_save_on_unload(coresys): async def test_save_on_unload(coresys):
"""Test called save on unload.""" """Test called save on unload."""