Allows the supervisor to send a session's user to addon with header X-Remote-User (#4152)

* Working draft for x-remote-user

* Renames prop to remote_user

* Allows to set in addon description whether it requests the username

* Fixes addon-options schema

* Sends user ID instead of username to addons

* Adds tests

* Removes configurability of remote-user forwarding

* Update const.py

* Also adds username header

* Fetches full user info object from homeassistant

* Cleaner validation and dataclasses

* Fixes linting

* Fixes linting

* Tries to fix test

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Updates tests

* Resolves PR comments

* Linting

* Fixes tests

* Update const.py

* Removes header keys if not required

* Moves ignoring user ID headers if no session_data is given

* simplify

* fix lint with new job

---------

Co-authored-by: Pascal Vizeli <pvizeli@syshack.ch>
Co-authored-by: Pascal Vizeli <pascal.vizeli@syshack.ch>
This commit is contained in:
Florian Bachmann 2023-08-22 10:11:13 +02:00 committed by GitHub
parent 204fcdf479
commit acc0e5c989
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 224 additions and 12 deletions

View File

@ -21,11 +21,18 @@ from ..const import (
ATTR_ICON,
ATTR_PANELS,
ATTR_SESSION,
ATTR_SESSION_DATA_USER_ID,
ATTR_TITLE,
HEADER_REMOTE_USER_DISPLAY_NAME,
HEADER_REMOTE_USER_ID,
HEADER_REMOTE_USER_NAME,
HEADER_TOKEN,
HEADER_TOKEN_OLD,
IngressSessionData,
IngressSessionDataUser,
)
from ..coresys import CoreSysAttributes
from ..exceptions import HomeAssistantAPIError
from .const import COOKIE_INGRESS
from .utils import api_process, api_validate, require_home_assistant
@ -33,10 +40,23 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
VALIDATE_SESSION_DATA = vol.Schema({ATTR_SESSION: str})
"""Expected optional payload of create session request"""
SCHEMA_INGRESS_CREATE_SESSION_DATA = vol.Schema(
{
vol.Optional(ATTR_SESSION_DATA_USER_ID): str,
}
)
class APIIngress(CoreSysAttributes):
"""Ingress view to handle add-on webui routing."""
_list_of_users: list[IngressSessionDataUser]
def __init__(self) -> None:
"""Initialize APIIngress."""
self._list_of_users = []
def _extract_addon(self, request: web.Request) -> Addon:
"""Return addon, throw an exception it it doesn't exist."""
token = request.match_info.get("token")
@ -71,7 +91,19 @@ class APIIngress(CoreSysAttributes):
@require_home_assistant
async def create_session(self, request: web.Request) -> dict[str, Any]:
"""Create a new session."""
session = self.sys_ingress.create_session()
schema_ingress_config_session_data = await api_validate(
SCHEMA_INGRESS_CREATE_SESSION_DATA, request
)
data: IngressSessionData | None = None
if ATTR_SESSION_DATA_USER_ID in schema_ingress_config_session_data:
user = await self._find_user_by_id(
schema_ingress_config_session_data[ATTR_SESSION_DATA_USER_ID]
)
if user:
data = IngressSessionData(user)
session = self.sys_ingress.create_session(data)
return {ATTR_SESSION: session}
@api_process
@ -99,13 +131,14 @@ class APIIngress(CoreSysAttributes):
# Process requests
addon = self._extract_addon(request)
path = request.match_info.get("path")
session_data = self.sys_ingress.get_session_data(session)
try:
# Websocket
if _is_websocket(request):
return await self._handle_websocket(request, addon, path)
return await self._handle_websocket(request, addon, path, session_data)
# Request
return await self._handle_request(request, addon, path)
return await self._handle_request(request, addon, path, session_data)
except aiohttp.ClientError as err:
_LOGGER.error("Ingress error: %s", err)
@ -113,7 +146,11 @@ class APIIngress(CoreSysAttributes):
raise HTTPBadGateway()
async def _handle_websocket(
self, request: web.Request, addon: Addon, path: str
self,
request: web.Request,
addon: Addon,
path: str,
session_data: IngressSessionData | None,
) -> web.WebSocketResponse:
"""Ingress route for websocket."""
if hdrs.SEC_WEBSOCKET_PROTOCOL in request.headers:
@ -131,7 +168,7 @@ class APIIngress(CoreSysAttributes):
# Preparing
url = self._create_url(addon, path)
source_header = _init_header(request, addon)
source_header = _init_header(request, addon, session_data)
# Support GET query
if request.query_string:
@ -157,11 +194,15 @@ class APIIngress(CoreSysAttributes):
return ws_server
async def _handle_request(
self, request: web.Request, addon: Addon, path: str
self,
request: web.Request,
addon: Addon,
path: str,
session_data: IngressSessionData | None,
) -> web.Response | web.StreamResponse:
"""Ingress route for request."""
url = self._create_url(addon, path)
source_header = _init_header(request, addon)
source_header = _init_header(request, addon, session_data)
# Passing the raw stream breaks requests for some webservers
# since we just need it for POST requests really, for all other methods
@ -217,11 +258,33 @@ class APIIngress(CoreSysAttributes):
return response
async def _find_user_by_id(self, user_id: str) -> IngressSessionDataUser | None:
"""Find user object by the user's ID."""
try:
list_of_users = await self.sys_homeassistant.get_users()
except (HomeAssistantAPIError, TypeError) as err:
_LOGGER.error(
"%s error occurred while requesting list of users: %s", type(err), err
)
return None
def _init_header(request: web.Request, addon: str) -> CIMultiDict | dict[str, str]:
if list_of_users is not None:
self._list_of_users = list_of_users
return next((user for user in self._list_of_users if user.id == user_id), None)
def _init_header(
request: web.Request, addon: Addon, session_data: IngressSessionData | None
) -> CIMultiDict | dict[str, str]:
"""Create initial header."""
headers = {}
if session_data is not None:
headers[HEADER_REMOTE_USER_ID] = session_data.user.id
headers[HEADER_REMOTE_USER_NAME] = session_data.user.username
headers[HEADER_REMOTE_USER_DISPLAY_NAME] = session_data.user.display_name
# filter flags
for name, value in request.headers.items():
if name in (
@ -234,6 +297,9 @@ def _init_header(request: web.Request, addon: str) -> CIMultiDict | dict[str, st
hdrs.SEC_WEBSOCKET_KEY,
istr(HEADER_TOKEN),
istr(HEADER_TOKEN_OLD),
istr(HEADER_REMOTE_USER_ID),
istr(HEADER_REMOTE_USER_NAME),
istr(HEADER_REMOTE_USER_DISPLAY_NAME),
):
continue
headers[name] = value

View File

@ -1,4 +1,5 @@
"""Constants file for Supervisor."""
from dataclasses import dataclass
from enum import Enum
from ipaddress import ip_network
from pathlib import Path
@ -69,6 +70,9 @@ JSON_RESULT = "result"
RESULT_ERROR = "error"
RESULT_OK = "ok"
HEADER_REMOTE_USER_ID = "X-Remote-User-Id"
HEADER_REMOTE_USER_NAME = "X-Remote-User-Name"
HEADER_REMOTE_USER_DISPLAY_NAME = "X-Remote-User-Display-Name"
HEADER_TOKEN_OLD = "X-Hassio-Key"
HEADER_TOKEN = "X-Supervisor-Token"
@ -271,6 +275,9 @@ ATTR_SERVERS = "servers"
ATTR_SERVICE = "service"
ATTR_SERVICES = "services"
ATTR_SESSION = "session"
ATTR_SESSION_DATA = "session_data"
ATTR_SESSION_DATA_USER = "user"
ATTR_SESSION_DATA_USER_ID = "user_id"
ATTR_SIGNAL = "signal"
ATTR_SIZE = "size"
ATTR_SLUG = "slug"
@ -464,6 +471,22 @@ class CpuArch(str, Enum):
AMD64 = "amd64"
@dataclass
class IngressSessionDataUser:
"""Format of an IngressSessionDataUser object."""
id: str
display_name: str
username: str
@dataclass
class IngressSessionData:
"""Format of an IngressSessionData object."""
user: IngressSessionDataUser
STARTING_STATES = [
CoreState.INITIALIZE,
CoreState.STARTUP,

View File

@ -1,5 +1,6 @@
"""Home Assistant control object."""
import asyncio
from datetime import timedelta
from ipaddress import IPv4Address
import logging
from pathlib import Path, PurePath
@ -28,6 +29,7 @@ from ..const import (
ATTR_WATCHDOG,
FILE_HASSIO_HOMEASSISTANT,
BusEvent,
IngressSessionDataUser,
)
from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import (
@ -38,7 +40,7 @@ from ..exceptions import (
)
from ..hardware.const import PolicyGroup
from ..hardware.data import Device
from ..jobs.decorator import Job
from ..jobs.decorator import Job, JobExecutionLimit
from ..utils import remove_folder
from ..utils.common import FileConfiguration
from ..utils.json import read_json_file, write_json_file
@ -432,3 +434,21 @@ class HomeAssistant(FileConfiguration, CoreSysAttributes):
ATTR_WATCHDOG,
):
self._data[attr] = data[attr]
@Job(
name="home_assistant_get_users",
limit=JobExecutionLimit.THROTTLE_WAIT,
throttle_period=timedelta(minutes=5),
)
async def get_users(self) -> list[IngressSessionDataUser]:
"""Get list of all configured users."""
list_of_users = await self.sys_homeassistant.websocket.async_send_command(
{ATTR_TYPE: "config/auth/list"}
)
return [
IngressSessionDataUser(
id=data["id"], username=data["username"], display_name=data["name"]
)
for data in list_of_users
]

View File

@ -5,7 +5,13 @@ import random
import secrets
from .addons.addon import Addon
from .const import ATTR_PORTS, ATTR_SESSION, FILE_HASSIO_INGRESS
from .const import (
ATTR_PORTS,
ATTR_SESSION,
ATTR_SESSION_DATA,
FILE_HASSIO_INGRESS,
IngressSessionData,
)
from .coresys import CoreSys, CoreSysAttributes
from .utils import check_port
from .utils.common import FileConfiguration
@ -30,11 +36,20 @@ class Ingress(FileConfiguration, CoreSysAttributes):
return None
return self.sys_addons.get(self.tokens[token], local_only=True)
def get_session_data(self, session_id: str) -> IngressSessionData | None:
"""Return complementary data of current session or None."""
return self.sessions_data.get(session_id)
@property
def sessions(self) -> dict[str, float]:
"""Return sessions."""
return self._data[ATTR_SESSION]
@property
def sessions_data(self) -> dict[str, IngressSessionData]:
"""Return sessions_data."""
return self._data[ATTR_SESSION_DATA]
@property
def ports(self) -> dict[str, int]:
"""Return list of dynamic ports."""
@ -71,6 +86,7 @@ class Ingress(FileConfiguration, CoreSysAttributes):
now = utcnow()
sessions = {}
sessions_data: dict[str, IngressSessionData] = {}
for session, valid in self.sessions.items():
# check if timestamp valid, to avoid crash on malformed timestamp
try:
@ -84,10 +100,13 @@ class Ingress(FileConfiguration, CoreSysAttributes):
# Is valid
sessions[session] = valid
sessions_data[session] = self.get_session_data(session)
# Write back
self.sessions.clear()
self.sessions.update(sessions)
self.sessions_data.clear()
self.sessions_data.update(sessions_data)
def _update_token_list(self) -> None:
"""Regenerate token <-> Add-on map."""
@ -97,12 +116,15 @@ class Ingress(FileConfiguration, CoreSysAttributes):
for addon in self.addons:
self.tokens[addon.ingress_token] = addon.slug
def create_session(self) -> str:
def create_session(self, data: IngressSessionData | None = None) -> str:
"""Create new session."""
session = secrets.token_hex(64)
valid = utcnow() + timedelta(minutes=15)
self.sessions[session] = valid.timestamp()
if data is not None:
self.sessions_data[session] = data
return session
def validate_session(self, session: str) -> bool:

View File

@ -30,6 +30,8 @@ from .const import (
ATTR_PWNED,
ATTR_REGISTRIES,
ATTR_SESSION,
ATTR_SESSION_DATA,
ATTR_SESSION_DATA_USER,
ATTR_SUPERVISOR,
ATTR_TIMEZONE,
ATTR_USERNAME,
@ -178,18 +180,33 @@ SCHEMA_DOCKER_CONFIG = vol.Schema(
SCHEMA_AUTH_CONFIG = vol.Schema({sha256: sha256})
SCHEMA_SESSION_DATA = vol.Schema(
{
token: vol.Schema(
{
vol.Required(ATTR_SESSION_DATA_USER): vol.Schema(
{
vol.Required("id"): str,
vol.Required("username"): str,
vol.Required("displayname"): str,
}
)
}
)
}
)
SCHEMA_INGRESS_CONFIG = vol.Schema(
{
vol.Required(ATTR_SESSION, default=dict): vol.Schema(
{token: vol.Coerce(float)}
),
vol.Required(ATTR_SESSION_DATA, default=dict): SCHEMA_SESSION_DATA,
vol.Required(ATTR_PORTS, default=dict): vol.Schema({str: network_port}),
},
extra=vol.REMOVE_EXTRA,
)
# pylint: disable=no-value-for-parameter
SCHEMA_SECURITY_CONFIG = vol.Schema(
{

View File

@ -1,4 +1,5 @@
"""Test ingress API."""
# pylint: disable=protected-access
from unittest.mock import patch
import pytest
@ -37,3 +38,50 @@ async def test_validate_session(api_client, coresys):
assert await resp.json() == {"result": "ok", "data": {}}
assert coresys.ingress.sessions[session] > valid_time
@pytest.mark.asyncio
async def test_validate_session_with_user_id(api_client, coresys):
"""Test validating ingress session with user ID passed."""
with patch("aiohttp.web_request.BaseRequest.__getitem__", return_value=None):
resp = await api_client.post(
"/ingress/validate_session",
json={"session": "non-existing"},
)
assert resp.status == 401
with patch(
"aiohttp.web_request.BaseRequest.__getitem__",
return_value=coresys.homeassistant,
):
client = coresys.homeassistant.websocket._client
client.async_send_command.return_value = [
{"id": "some-id", "name": "Some Name", "username": "sn"}
]
resp = await api_client.post("/ingress/session", json={"user_id": "some-id"})
result = await resp.json()
client.async_send_command.assert_called_with({"type": "config/auth/list"})
assert "session" in result["data"]
session = result["data"]["session"]
assert session in coresys.ingress.sessions
valid_time = coresys.ingress.sessions[session]
resp = await api_client.post(
"/ingress/validate_session",
json={"session": session},
)
assert resp.status == 200
assert await resp.json() == {"result": "ok", "data": {}}
assert coresys.ingress.sessions[session] > valid_time
assert session in coresys.ingress.sessions_data
assert coresys.ingress.get_session_data(session).user.id == "some-id"
assert coresys.ingress.get_session_data(session).user.username == "sn"
assert (
coresys.ingress.get_session_data(session).user.display_name == "Some Name"
)

View File

@ -1,6 +1,7 @@
"""Test ingress."""
from datetime import timedelta
from supervisor.const import ATTR_SESSION_DATA_USER_ID
from supervisor.utils.dt import utc_from_timestamp
@ -20,6 +21,21 @@ def test_session_handling(coresys):
assert not coresys.ingress.validate_session(session)
assert not coresys.ingress.validate_session("invalid session")
session_data = coresys.ingress.get_session_data(session)
assert session_data is None
def test_session_handling_with_session_data(coresys):
"""Create and test session."""
session = coresys.ingress.create_session(
dict([(ATTR_SESSION_DATA_USER_ID, "some-id")])
)
assert session
session_data = coresys.ingress.get_session_data(session)
assert session_data[ATTR_SESSION_DATA_USER_ID] == "some-id"
async def test_save_on_unload(coresys):
"""Test called save on unload."""