From 87385cf28e54091de21a22f011aafbc33e2badfd Mon Sep 17 00:00:00 2001 From: Mike Degatano Date: Tue, 7 Nov 2023 13:07:16 -0500 Subject: [PATCH] Fix saving ingress data on supervisor shutdown (#4672) * Fix saving ingress data on supervisor shutdown * Fix ci issues --- supervisor/const.py | 28 ++++++++++++++++++++++++++++ supervisor/ingress.py | 10 ++++++---- supervisor/validate.py | 8 +++++--- tests/test_ingress.py | 40 +++++++++++++++++++++++++++++++++------- 4 files changed, 72 insertions(+), 14 deletions(-) diff --git a/supervisor/const.py b/supervisor/const.py index 983733d1d..09c7fa5cd 100644 --- a/supervisor/const.py +++ b/supervisor/const.py @@ -4,6 +4,7 @@ from enum import StrEnum from ipaddress import ip_network from pathlib import Path from sys import version_info as systemversion +from typing import Self from aiohttp import __version__ as aiohttpversion @@ -159,6 +160,7 @@ ATTR_DISK_LED = "disk_led" ATTR_DISK_LIFE_TIME = "disk_life_time" ATTR_DISK_TOTAL = "disk_total" ATTR_DISK_USED = "disk_used" +ATTR_DISPLAYNAME = "displayname" ATTR_DNS = "dns" ATTR_DOCKER = "docker" ATTR_DOCKER_API = "docker_api" @@ -491,6 +493,23 @@ class IngressSessionDataUser: display_name: str | None = None username: str | None = None + def to_dict(self) -> dict[str, str | None]: + """Get dictionary representation.""" + return { + ATTR_ID: self.id, + ATTR_DISPLAYNAME: self.display_name, + ATTR_USERNAME: self.username, + } + + @classmethod + def from_dict(cls, data: dict[str, str | None]) -> Self: + """Return object from dictionary representation.""" + return cls( + id=data[ATTR_ID], + display_name=data.get(ATTR_DISPLAYNAME), + username=data.get(ATTR_USERNAME), + ) + @dataclass class IngressSessionData: @@ -498,6 +517,15 @@ class IngressSessionData: user: IngressSessionDataUser + def to_dict(self) -> dict[str, dict[str, str | None]]: + """Get dictionary representation.""" + return {ATTR_USER: self.user.to_dict()} + + @classmethod + def from_dict(cls, data: dict[str, dict[str, str | None]]) -> Self: + """Return object from dictionary representation.""" + return cls(user=IngressSessionDataUser.from_dict(data[ATTR_USER])) + STARTING_STATES = [ CoreState.INITIALIZE, diff --git a/supervisor/ingress.py b/supervisor/ingress.py index 362449802..f53a3e2a3 100644 --- a/supervisor/ingress.py +++ b/supervisor/ingress.py @@ -38,7 +38,9 @@ class Ingress(FileConfiguration, CoreSysAttributes): def get_session_data(self, session_id: str) -> IngressSessionData | None: """Return complementary data of current session or None.""" - return self.sessions_data.get(session_id) + if data := self.sessions_data.get(session_id): + return IngressSessionData.from_dict(data) + return None @property def sessions(self) -> dict[str, float]: @@ -46,7 +48,7 @@ class Ingress(FileConfiguration, CoreSysAttributes): return self._data[ATTR_SESSION] @property - def sessions_data(self) -> dict[str, IngressSessionData]: + def sessions_data(self) -> dict[str, dict[str, str | None]]: """Return sessions_data.""" return self._data[ATTR_SESSION_DATA] @@ -100,7 +102,7 @@ class Ingress(FileConfiguration, CoreSysAttributes): # Is valid sessions[session] = valid - sessions_data[session] = self.get_session_data(session) + sessions_data[session] = self.sessions_data.get(session) # Write back self.sessions.clear() @@ -123,7 +125,7 @@ class Ingress(FileConfiguration, CoreSysAttributes): self.sessions[session] = valid.timestamp() if data is not None: - self.sessions_data[session] = data + self.sessions_data[session] = data.to_dict() return session diff --git a/supervisor/validate.py b/supervisor/validate.py index 268cd5bdf..8aa3f2966 100644 --- a/supervisor/validate.py +++ b/supervisor/validate.py @@ -15,10 +15,12 @@ from .const import ( ATTR_DEBUG, ATTR_DEBUG_BLOCK, ATTR_DIAGNOSTICS, + ATTR_DISPLAYNAME, ATTR_DNS, ATTR_FORCE_SECURITY, ATTR_HASSOS, ATTR_HOMEASSISTANT, + ATTR_ID, ATTR_IMAGE, ATTR_LAST_BOOT, ATTR_LOGGING, @@ -186,9 +188,9 @@ SCHEMA_SESSION_DATA = vol.Schema( { vol.Required(ATTR_SESSION_DATA_USER): vol.Schema( { - vol.Required("id"): str, - vol.Required("username"): str, - vol.Required("displayname"): str, + vol.Required(ATTR_ID): str, + vol.Required(ATTR_USERNAME, default=None): vol.Maybe(str), + vol.Required(ATTR_DISPLAYNAME, default=None): vol.Maybe(str), } ) } diff --git a/tests/test_ingress.py b/tests/test_ingress.py index 0870d126e..69b43991e 100644 --- a/tests/test_ingress.py +++ b/tests/test_ingress.py @@ -1,11 +1,16 @@ """Test ingress.""" from datetime import timedelta +from pathlib import Path +from unittest.mock import ANY, patch -from supervisor.const import ATTR_SESSION_DATA_USER_ID +from supervisor.const import IngressSessionData, IngressSessionDataUser +from supervisor.coresys import CoreSys +from supervisor.ingress import Ingress from supervisor.utils.dt import utc_from_timestamp +from supervisor.utils.json import read_json_file -def test_session_handling(coresys): +def test_session_handling(coresys: CoreSys): """Create and test session.""" session = coresys.ingress.create_session() validate = coresys.ingress.sessions[session] @@ -25,19 +30,19 @@ def test_session_handling(coresys): assert session_data is None -def test_session_handling_with_session_data(coresys): +def test_session_handling_with_session_data(coresys: CoreSys): """Create and test session.""" session = coresys.ingress.create_session( - dict([(ATTR_SESSION_DATA_USER_ID, "some-id")]) + IngressSessionData(IngressSessionDataUser("some-id")) ) assert session session_data = coresys.ingress.get_session_data(session) - assert session_data[ATTR_SESSION_DATA_USER_ID] == "some-id" + assert session_data.user.id == "some-id" -async def test_save_on_unload(coresys): +async def test_save_on_unload(coresys: CoreSys): """Test called save on unload.""" coresys.ingress.create_session() await coresys.ingress.unload() @@ -45,7 +50,7 @@ async def test_save_on_unload(coresys): assert coresys.ingress.save_data.called -def test_dynamic_ports(coresys): +def test_dynamic_ports(coresys: CoreSys): """Test dyanmic port handling.""" port_test1 = coresys.ingress.get_dynamic_port("test1") @@ -62,3 +67,24 @@ def test_dynamic_ports(coresys): assert port_test2 < 65500 assert port_test1 > 62000 assert port_test1 < 65500 + + +async def test_ingress_save_data(coresys: CoreSys, tmp_supervisor_data: Path): + """Test saving ingress data to file.""" + config_file = tmp_supervisor_data / "ingress.json" + with patch("supervisor.ingress.FILE_HASSIO_INGRESS", new=config_file): + ingress = Ingress(coresys) + session = ingress.create_session( + IngressSessionData(IngressSessionDataUser("123", "Test", "test")) + ) + ingress.save_data() + + assert config_file.exists() + data = read_json_file(config_file) + assert data == { + "session": {session: ANY}, + "session_data": { + session: {"user": {"id": "123", "displayname": "Test", "username": "test"}} + }, + "ports": {}, + }