Enable mypy part 1 (addons and api) (#5759)

* Fix mypy issues in addons

* Fix mypy issues in api

* fix docstring

* Brackets instead of get with default
This commit is contained in:
Mike Degatano 2025-03-25 15:06:35 -04:00 committed by GitHub
parent 543d6efec4
commit 0636e49fe2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
62 changed files with 500 additions and 382 deletions

View File

@ -18,7 +18,7 @@ from tempfile import TemporaryDirectory
from typing import Any, Final from typing import Any, Final
import aiohttp import aiohttp
from awesomeversion import AwesomeVersionCompareException from awesomeversion import AwesomeVersion, AwesomeVersionCompareException
from deepmerge import Merger from deepmerge import Merger
from securetar import AddFileError, atomic_contents_add, secure_path from securetar import AddFileError, atomic_contents_add, secure_path
import voluptuous as vol import voluptuous as vol
@ -285,28 +285,28 @@ class Addon(AddonModel):
@property @property
def with_icon(self) -> bool: def with_icon(self) -> bool:
"""Return True if an icon exists.""" """Return True if an icon exists."""
if self.is_detached: if self.is_detached or not self.addon_store:
return super().with_icon return super().with_icon
return self.addon_store.with_icon return self.addon_store.with_icon
@property @property
def with_logo(self) -> bool: def with_logo(self) -> bool:
"""Return True if a logo exists.""" """Return True if a logo exists."""
if self.is_detached: if self.is_detached or not self.addon_store:
return super().with_logo return super().with_logo
return self.addon_store.with_logo return self.addon_store.with_logo
@property @property
def with_changelog(self) -> bool: def with_changelog(self) -> bool:
"""Return True if a changelog exists.""" """Return True if a changelog exists."""
if self.is_detached: if self.is_detached or not self.addon_store:
return super().with_changelog return super().with_changelog
return self.addon_store.with_changelog return self.addon_store.with_changelog
@property @property
def with_documentation(self) -> bool: def with_documentation(self) -> bool:
"""Return True if a documentation exists.""" """Return True if a documentation exists."""
if self.is_detached: if self.is_detached or not self.addon_store:
return super().with_documentation return super().with_documentation
return self.addon_store.with_documentation return self.addon_store.with_documentation
@ -316,7 +316,7 @@ class Addon(AddonModel):
return self._available(self.data_store) return self._available(self.data_store)
@property @property
def version(self) -> str | None: def version(self) -> AwesomeVersion:
"""Return installed version.""" """Return installed version."""
return self.persist[ATTR_VERSION] return self.persist[ATTR_VERSION]
@ -464,7 +464,7 @@ class Addon(AddonModel):
return None return None
@property @property
def latest_version(self) -> str: def latest_version(self) -> AwesomeVersion:
"""Return version of add-on.""" """Return version of add-on."""
return self.data_store[ATTR_VERSION] return self.data_store[ATTR_VERSION]
@ -518,9 +518,8 @@ class Addon(AddonModel):
def webui(self) -> str | None: def webui(self) -> str | None:
"""Return URL to webui or None.""" """Return URL to webui or None."""
url = super().webui url = super().webui
if not url: if not url or not (webui := RE_WEBUI.match(url)):
return None return None
webui = RE_WEBUI.match(url)
# extract arguments # extract arguments
t_port = webui.group("t_port") t_port = webui.group("t_port")
@ -675,10 +674,9 @@ class Addon(AddonModel):
async def watchdog_application(self) -> bool: async def watchdog_application(self) -> bool:
"""Return True if application is running.""" """Return True if application is running."""
url = super().watchdog url = self.watchdog_url
if not url: if not url or not (application := RE_WATCHDOG.match(url)):
return True return True
application = RE_WATCHDOG.match(url)
# extract arguments # extract arguments
t_port = int(application.group("t_port")) t_port = int(application.group("t_port"))
@ -687,8 +685,10 @@ class Addon(AddonModel):
s_suffix = application.group("s_suffix") or "" s_suffix = application.group("s_suffix") or ""
# search host port for this docker port # search host port for this docker port
if self.host_network: if self.host_network and self.ports:
port = self.ports.get(f"{t_port}/tcp", t_port) port = self.ports.get(f"{t_port}/tcp")
if port is None:
port = t_port
else: else:
port = t_port port = t_port
@ -777,6 +777,9 @@ class Addon(AddonModel):
) )
async def install(self) -> None: async def install(self) -> None:
"""Install and setup this addon.""" """Install and setup this addon."""
if not self.addon_store:
raise AddonsError("Missing from store, cannot install!")
await self.sys_addons.data.install(self.addon_store) await self.sys_addons.data.install(self.addon_store)
await self.load() await self.load()
@ -880,6 +883,9 @@ class Addon(AddonModel):
Returns a Task that completes when addon has state 'started' (see start) Returns a Task that completes when addon has state 'started' (see start)
if it was running. Else nothing is returned. if it was running. Else nothing is returned.
""" """
if not self.addon_store:
raise AddonsError("Missing from store, cannot update!")
old_image = self.image old_image = self.image
# Cache data to prevent races with other updates to global # Cache data to prevent races with other updates to global
store = self.addon_store.clone() store = self.addon_store.clone()
@ -936,7 +942,9 @@ class Addon(AddonModel):
except DockerError as err: except DockerError as err:
raise AddonsError() from err raise AddonsError() from err
await self.sys_addons.data.update(self.addon_store) if self.addon_store:
await self.sys_addons.data.update(self.addon_store)
await self._check_ingress_port() await self._check_ingress_port()
_LOGGER.info("Add-on '%s' successfully rebuilt", self.slug) _LOGGER.info("Add-on '%s' successfully rebuilt", self.slug)
@ -965,7 +973,9 @@ class Addon(AddonModel):
await self.sys_run_in_executor(write_pulse_config) await self.sys_run_in_executor(write_pulse_config)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error( _LOGGER.error(
"Add-on %s can't write pulse/client.config: %s", self.slug, err "Add-on %s can't write pulse/client.config: %s", self.slug, err
) )
@ -1324,7 +1334,7 @@ class Addon(AddonModel):
arcname="config", arcname="config",
) )
wait_for_start: Awaitable[None] | None = None wait_for_start: asyncio.Task | None = None
data = { data = {
ATTR_USER: self.persist, ATTR_USER: self.persist,
@ -1370,7 +1380,7 @@ class Addon(AddonModel):
Returns a Task that completes when addon has state 'started' (see start) Returns a Task that completes when addon has state 'started' (see start)
if addon is started after restore. Else nothing is returned. if addon is started after restore. Else nothing is returned.
""" """
wait_for_start: Awaitable[None] | None = None wait_for_start: asyncio.Task | None = None
# Extract backup # Extract backup
def _extract_tarfile() -> tuple[TemporaryDirectory, dict[str, Any]]: def _extract_tarfile() -> tuple[TemporaryDirectory, dict[str, Any]]:
@ -1594,6 +1604,6 @@ class Addon(AddonModel):
def refresh_path_cache(self) -> Awaitable[None]: def refresh_path_cache(self) -> Awaitable[None]:
"""Refresh cache of existing paths.""" """Refresh cache of existing paths."""
if self.is_detached: if self.is_detached or not self.addon_store:
return super().refresh_path_cache() return super().refresh_path_cache()
return self.addon_store.refresh_path_cache() return self.addon_store.refresh_path_cache()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
@ -23,7 +23,7 @@ from ..utils.common import FileConfiguration, find_one_filetype
from .validate import SCHEMA_BUILD_CONFIG from .validate import SCHEMA_BUILD_CONFIG
if TYPE_CHECKING: if TYPE_CHECKING:
from . import AnyAddon from .manager import AnyAddon
class AddonBuild(FileConfiguration, CoreSysAttributes): class AddonBuild(FileConfiguration, CoreSysAttributes):
@ -63,7 +63,7 @@ class AddonBuild(FileConfiguration, CoreSysAttributes):
@cached_property @cached_property
def arch(self) -> str: def arch(self) -> str:
"""Return arch of the add-on.""" """Return arch of the add-on."""
return self.sys_arch.match(self.addon.arch) return self.sys_arch.match([self.addon.arch])
@property @property
def base_image(self) -> str: def base_image(self) -> str:
@ -126,7 +126,7 @@ class AddonBuild(FileConfiguration, CoreSysAttributes):
Must be run in executor. Must be run in executor.
""" """
args = { args: dict[str, Any] = {
"path": str(self.addon.path_location), "path": str(self.addon.path_location),
"tag": f"{image or self.addon.image}:{version!s}", "tag": f"{image or self.addon.image}:{version!s}",
"dockerfile": str(self.get_dockerfile()), "dockerfile": str(self.get_dockerfile()),

View File

@ -194,6 +194,7 @@ class AddonManager(CoreSysAttributes):
_LOGGER.info("Add-on '%s' successfully installed", slug) _LOGGER.info("Add-on '%s' successfully installed", slug)
@Job(name="addon_manager_uninstall")
async def uninstall(self, slug: str, *, remove_config: bool = False) -> None: async def uninstall(self, slug: str, *, remove_config: bool = False) -> None:
"""Remove an add-on.""" """Remove an add-on."""
if slug not in self.local: if slug not in self.local:
@ -313,7 +314,7 @@ class AddonManager(CoreSysAttributes):
if slug not in self.local: if slug not in self.local:
_LOGGER.debug("Add-on %s is not local available for restore", slug) _LOGGER.debug("Add-on %s is not local available for restore", slug)
addon = Addon(self.coresys, slug) addon = Addon(self.coresys, slug)
had_ingress = False had_ingress: bool | None = False
else: else:
_LOGGER.debug("Add-on %s is local available for restore", slug) _LOGGER.debug("Add-on %s is local available for restore", slug)
addon = self.local[slug] addon = self.local[slug]

View File

@ -294,7 +294,7 @@ class AddonModel(JobGroup, ABC):
return self.data.get(ATTR_WEBUI) return self.data.get(ATTR_WEBUI)
@property @property
def watchdog(self) -> str | None: def watchdog_url(self) -> str | None:
"""Return URL to for watchdog or None.""" """Return URL to for watchdog or None."""
return self.data.get(ATTR_WATCHDOG) return self.data.get(ATTR_WATCHDOG)
@ -606,7 +606,7 @@ class AddonModel(JobGroup, ABC):
return AddonOptions(self.coresys, raw_schema, self.name, self.slug) return AddonOptions(self.coresys, raw_schema, self.name, self.slug)
@property @property
def schema_ui(self) -> list[dict[any, any]] | None: def schema_ui(self) -> list[dict[Any, Any]] | None:
"""Create a UI schema for add-on options.""" """Create a UI schema for add-on options."""
raw_schema = self.data[ATTR_SCHEMA] raw_schema = self.data[ATTR_SCHEMA]

View File

@ -137,7 +137,7 @@ class AddonOptions(CoreSysAttributes):
) from None ) from None
# prepare range # prepare range
range_args = {} range_args: dict[str, Any] = {}
for group_name in _SCHEMA_LENGTH_PARTS: for group_name in _SCHEMA_LENGTH_PARTS:
group_value = match.group(group_name) group_value = match.group(group_name)
if group_value: if group_value:
@ -390,14 +390,14 @@ class UiOptions(CoreSysAttributes):
multiple: bool = False, multiple: bool = False,
) -> None: ) -> None:
"""UI nested dict items.""" """UI nested dict items."""
ui_node = { ui_node: dict[str, Any] = {
"name": key, "name": key,
"type": "schema", "type": "schema",
"optional": True, "optional": True,
"multiple": multiple, "multiple": multiple,
} }
nested_schema = [] nested_schema: list[dict[str, Any]] = []
for c_key, c_value in option_dict.items(): for c_key, c_value in option_dict.items():
# Nested? # Nested?
if isinstance(c_value, list): if isinstance(c_value, list):
@ -413,7 +413,7 @@ def _create_device_filter(str_filter: str) -> dict[str, Any]:
"""Generate device Filter.""" """Generate device Filter."""
raw_filter = dict(value.split("=") for value in str_filter.split(";")) raw_filter = dict(value.split("=") for value in str_filter.split(";"))
clean_filter = {} clean_filter: dict[str, Any] = {}
for key, value in raw_filter.items(): for key, value in raw_filter.items():
if key == "subsystem": if key == "subsystem":
clean_filter[key] = UdevSubsystem(value) clean_filter[key] = UdevSubsystem(value)

View File

@ -6,7 +6,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from aiohttp import web from aiohttp import hdrs, web
from ..const import AddonState from ..const import AddonState
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
@ -82,15 +82,13 @@ class RestAPI(CoreSysAttributes):
self._site: web.TCPSite | None = None self._site: web.TCPSite | None = None
# share single host API handler for reuse in logging endpoints # share single host API handler for reuse in logging endpoints
self._api_host: APIHost | None = None self._api_host: APIHost = APIHost()
self._api_host.coresys = coresys
async def load(self) -> None: async def load(self) -> None:
"""Register REST API Calls.""" """Register REST API Calls."""
static_resource_configs: list[StaticResourceConfig] = [] static_resource_configs: list[StaticResourceConfig] = []
self._api_host = APIHost()
self._api_host.coresys = self.coresys
self._register_addons() self._register_addons()
self._register_audio() self._register_audio()
self._register_auth() self._register_auth()
@ -526,7 +524,7 @@ class RestAPI(CoreSysAttributes):
self.webapp.add_routes( self.webapp.add_routes(
[ [
web.get("/addons", api_addons.list), web.get("/addons", api_addons.list_addons),
web.post("/addons/{addon}/uninstall", api_addons.uninstall), web.post("/addons/{addon}/uninstall", api_addons.uninstall),
web.post("/addons/{addon}/start", api_addons.start), web.post("/addons/{addon}/start", api_addons.start),
web.post("/addons/{addon}/stop", api_addons.stop), web.post("/addons/{addon}/stop", api_addons.stop),
@ -594,7 +592,9 @@ class RestAPI(CoreSysAttributes):
web.post("/ingress/session", api_ingress.create_session), web.post("/ingress/session", api_ingress.create_session),
web.post("/ingress/validate_session", api_ingress.validate_session), web.post("/ingress/validate_session", api_ingress.validate_session),
web.get("/ingress/panels", api_ingress.panels), web.get("/ingress/panels", api_ingress.panels),
web.view("/ingress/{token}/{path:.*}", api_ingress.handler), web.route(
hdrs.METH_ANY, "/ingress/{token}/{path:.*}", api_ingress.handler
),
] ]
) )
@ -605,7 +605,7 @@ class RestAPI(CoreSysAttributes):
self.webapp.add_routes( self.webapp.add_routes(
[ [
web.get("/backups", api_backups.list), web.get("/backups", api_backups.list_backups),
web.get("/backups/info", api_backups.info), web.get("/backups/info", api_backups.info),
web.post("/backups/options", api_backups.options), web.post("/backups/options", api_backups.options),
web.post("/backups/reload", api_backups.reload), web.post("/backups/reload", api_backups.reload),
@ -632,7 +632,7 @@ class RestAPI(CoreSysAttributes):
self.webapp.add_routes( self.webapp.add_routes(
[ [
web.get("/services", api_services.list), web.get("/services", api_services.list_services),
web.get("/services/{service}", api_services.get_service), web.get("/services/{service}", api_services.get_service),
web.post("/services/{service}", api_services.set_service), web.post("/services/{service}", api_services.set_service),
web.delete("/services/{service}", api_services.del_service), web.delete("/services/{service}", api_services.del_service),
@ -646,7 +646,7 @@ class RestAPI(CoreSysAttributes):
self.webapp.add_routes( self.webapp.add_routes(
[ [
web.get("/discovery", api_discovery.list), web.get("/discovery", api_discovery.list_discovery),
web.get("/discovery/{uuid}", api_discovery.get_discovery), web.get("/discovery/{uuid}", api_discovery.get_discovery),
web.delete("/discovery/{uuid}", api_discovery.del_discovery), web.delete("/discovery/{uuid}", api_discovery.del_discovery),
web.post("/discovery", api_discovery.set_discovery), web.post("/discovery", api_discovery.set_discovery),

View File

@ -3,14 +3,13 @@
import asyncio import asyncio
from collections.abc import Awaitable from collections.abc import Awaitable
import logging import logging
from typing import Any from typing import Any, TypedDict
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from ..addons.addon import Addon from ..addons.addon import Addon
from ..addons.manager import AnyAddon
from ..addons.utils import rating_security from ..addons.utils import rating_security
from ..const import ( from ..const import (
ATTR_ADDONS, ATTR_ADDONS,
@ -63,7 +62,6 @@ from ..const import (
ATTR_MEMORY_LIMIT, ATTR_MEMORY_LIMIT,
ATTR_MEMORY_PERCENT, ATTR_MEMORY_PERCENT,
ATTR_MEMORY_USAGE, ATTR_MEMORY_USAGE,
ATTR_MESSAGE,
ATTR_NAME, ATTR_NAME,
ATTR_NETWORK, ATTR_NETWORK,
ATTR_NETWORK_DESCRIPTION, ATTR_NETWORK_DESCRIPTION,
@ -72,7 +70,6 @@ from ..const import (
ATTR_OPTIONS, ATTR_OPTIONS,
ATTR_PRIVILEGED, ATTR_PRIVILEGED,
ATTR_PROTECTED, ATTR_PROTECTED,
ATTR_PWNED,
ATTR_RATING, ATTR_RATING,
ATTR_REPOSITORY, ATTR_REPOSITORY,
ATTR_SCHEMA, ATTR_SCHEMA,
@ -90,7 +87,6 @@ from ..const import (
ATTR_UPDATE_AVAILABLE, ATTR_UPDATE_AVAILABLE,
ATTR_URL, ATTR_URL,
ATTR_USB, ATTR_USB,
ATTR_VALID,
ATTR_VERSION, ATTR_VERSION,
ATTR_VERSION_LATEST, ATTR_VERSION_LATEST,
ATTR_VIDEO, ATTR_VIDEO,
@ -146,12 +142,20 @@ SCHEMA_UNINSTALL = vol.Schema(
# pylint: enable=no-value-for-parameter # pylint: enable=no-value-for-parameter
class OptionsValidateResponse(TypedDict):
"""Response object for options validate."""
message: str
valid: bool
pwned: bool | None
class APIAddons(CoreSysAttributes): class APIAddons(CoreSysAttributes):
"""Handle RESTful API for add-on functions.""" """Handle RESTful API for add-on functions."""
def get_addon_for_request(self, request: web.Request) -> Addon: def get_addon_for_request(self, request: web.Request) -> Addon:
"""Return addon, throw an exception if it doesn't exist.""" """Return addon, throw an exception if it doesn't exist."""
addon_slug: str = request.match_info.get("addon") addon_slug: str = request.match_info["addon"]
# Lookup itself # Lookup itself
if addon_slug == "self": if addon_slug == "self":
@ -169,7 +173,7 @@ class APIAddons(CoreSysAttributes):
return addon return addon
@api_process @api_process
async def list(self, request: web.Request) -> dict[str, Any]: async def list_addons(self, request: web.Request) -> dict[str, Any]:
"""Return all add-ons or repositories.""" """Return all add-ons or repositories."""
data_addons = [ data_addons = [
{ {
@ -204,7 +208,7 @@ class APIAddons(CoreSysAttributes):
async def info(self, request: web.Request) -> dict[str, Any]: async def info(self, request: web.Request) -> dict[str, Any]:
"""Return add-on information.""" """Return add-on information."""
addon: AnyAddon = self.get_addon_for_request(request) addon: Addon = self.get_addon_for_request(request)
data = { data = {
ATTR_NAME: addon.name, ATTR_NAME: addon.name,
@ -339,10 +343,10 @@ class APIAddons(CoreSysAttributes):
await addon.save_persist() await addon.save_persist()
@api_process @api_process
async def options_validate(self, request: web.Request) -> None: async def options_validate(self, request: web.Request) -> OptionsValidateResponse:
"""Validate user options for add-on.""" """Validate user options for add-on."""
addon = self.get_addon_for_request(request) addon = self.get_addon_for_request(request)
data = {ATTR_MESSAGE: "", ATTR_VALID: True, ATTR_PWNED: False} data = OptionsValidateResponse(message="", valid=True, pwned=False)
options = await request.json(loads=json_loads) or addon.options options = await request.json(loads=json_loads) or addon.options
@ -351,8 +355,8 @@ class APIAddons(CoreSysAttributes):
try: try:
options_schema.validate(options) options_schema.validate(options)
except vol.Invalid as ex: except vol.Invalid as ex:
data[ATTR_MESSAGE] = humanize_error(options, ex) data["message"] = humanize_error(options, ex)
data[ATTR_VALID] = False data["valid"] = False
if not self.sys_security.pwned: if not self.sys_security.pwned:
return data return data
@ -363,24 +367,24 @@ class APIAddons(CoreSysAttributes):
await self.sys_security.verify_secret(secret) await self.sys_security.verify_secret(secret)
continue continue
except PwnedSecret: except PwnedSecret:
data[ATTR_PWNED] = True data["pwned"] = True
except PwnedError: except PwnedError:
data[ATTR_PWNED] = None data["pwned"] = None
break break
if self.sys_security.force and data[ATTR_PWNED] in (None, True): if self.sys_security.force and data["pwned"] in (None, True):
data[ATTR_VALID] = False data["valid"] = False
if data[ATTR_PWNED] is None: if data["pwned"] is None:
data[ATTR_MESSAGE] = "Error happening on pwned secrets check!" data["message"] = "Error happening on pwned secrets check!"
else: else:
data[ATTR_MESSAGE] = "Add-on uses pwned secrets!" data["message"] = "Add-on uses pwned secrets!"
return data return data
@api_process @api_process
async def options_config(self, request: web.Request) -> None: async def options_config(self, request: web.Request) -> None:
"""Validate user options for add-on.""" """Validate user options for add-on."""
slug: str = request.match_info.get("addon") slug: str = request.match_info["addon"]
if slug != "self": if slug != "self":
raise APIForbidden("This can be only read by the Add-on itself!") raise APIForbidden("This can be only read by the Add-on itself!")
addon = self.get_addon_for_request(request) addon = self.get_addon_for_request(request)

View File

@ -124,7 +124,7 @@ class APIAudio(CoreSysAttributes):
@api_process @api_process
async def set_volume(self, request: web.Request) -> None: async def set_volume(self, request: web.Request) -> None:
"""Set audio volume on stream.""" """Set audio volume on stream."""
source: StreamType = StreamType(request.match_info.get("source")) source: StreamType = StreamType(request.match_info["source"])
application: bool = request.path.endswith("application") application: bool = request.path.endswith("application")
body = await api_validate(SCHEMA_VOLUME, request) body = await api_validate(SCHEMA_VOLUME, request)
@ -137,7 +137,7 @@ class APIAudio(CoreSysAttributes):
@api_process @api_process
async def set_mute(self, request: web.Request) -> None: async def set_mute(self, request: web.Request) -> None:
"""Mute audio volume on stream.""" """Mute audio volume on stream."""
source: StreamType = StreamType(request.match_info.get("source")) source: StreamType = StreamType(request.match_info["source"])
application: bool = request.path.endswith("application") application: bool = request.path.endswith("application")
body = await api_validate(SCHEMA_MUTE, request) body = await api_validate(SCHEMA_MUTE, request)
@ -150,7 +150,7 @@ class APIAudio(CoreSysAttributes):
@api_process @api_process
async def set_default(self, request: web.Request) -> None: async def set_default(self, request: web.Request) -> None:
"""Set audio default stream.""" """Set audio default stream."""
source: StreamType = StreamType(request.match_info.get("source")) source: StreamType = StreamType(request.match_info["source"])
body = await api_validate(SCHEMA_DEFAULT, request) body = await api_validate(SCHEMA_DEFAULT, request)
await asyncio.shield(self.sys_host.sound.set_default(source, body[ATTR_NAME])) await asyncio.shield(self.sys_host.sound.set_default(source, body[ATTR_NAME]))

View File

@ -1,6 +1,7 @@
"""Init file for Supervisor auth/SSO RESTful API.""" """Init file for Supervisor auth/SSO RESTful API."""
import asyncio import asyncio
from collections.abc import Awaitable
import logging import logging
from typing import Any from typing import Any
@ -42,7 +43,7 @@ REALM_HEADER: dict[str, str] = {
class APIAuth(CoreSysAttributes): class APIAuth(CoreSysAttributes):
"""Handle RESTful API for auth functions.""" """Handle RESTful API for auth functions."""
def _process_basic(self, request: web.Request, addon: Addon) -> bool: def _process_basic(self, request: web.Request, addon: Addon) -> Awaitable[bool]:
"""Process login request with basic auth. """Process login request with basic auth.
Return a coroutine. Return a coroutine.
@ -52,7 +53,7 @@ class APIAuth(CoreSysAttributes):
def _process_dict( def _process_dict(
self, request: web.Request, addon: Addon, data: dict[str, str] self, request: web.Request, addon: Addon, data: dict[str, str]
) -> bool: ) -> Awaitable[bool]:
"""Process login with dict data. """Process login with dict data.
Return a coroutine. Return a coroutine.

View File

@ -10,9 +10,9 @@ import logging
from pathlib import Path from pathlib import Path
import re import re
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any from typing import Any, cast
from aiohttp import web from aiohttp import BodyPartReader, web
from aiohttp.hdrs import CONTENT_DISPOSITION from aiohttp.hdrs import CONTENT_DISPOSITION
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -52,8 +52,9 @@ from ..const import (
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError, APIForbidden, APINotFound from ..exceptions import APIError, APIForbidden, APINotFound
from ..jobs import JobSchedulerOptions from ..jobs import JobSchedulerOptions, SupervisorJob
from ..mounts.const import MountUsage from ..mounts.const import MountUsage
from ..mounts.mount import Mount
from ..resolution.const import UnhealthyReason from ..resolution.const import UnhealthyReason
from .const import ( from .const import (
ATTR_ADDITIONAL_LOCATIONS, ATTR_ADDITIONAL_LOCATIONS,
@ -187,7 +188,7 @@ class APIBackups(CoreSysAttributes):
] ]
@api_process @api_process
async def list(self, request): async def list_backups(self, request):
"""Return backup list.""" """Return backup list."""
data_backups = self._list_backups() data_backups = self._list_backups()
@ -295,8 +296,11 @@ class APIBackups(CoreSysAttributes):
) -> tuple[asyncio.Task, str]: ) -> tuple[asyncio.Task, str]:
"""Start backup task in background and return task and job ID.""" """Start backup task in background and return task and job ID."""
event = asyncio.Event() event = asyncio.Event()
job, backup_task = self.sys_jobs.schedule_job( job, backup_task = cast(
backup_method, JobSchedulerOptions(), *args, **kwargs tuple[SupervisorJob, asyncio.Task],
self.sys_jobs.schedule_job(
backup_method, JobSchedulerOptions(), *args, **kwargs
),
) )
async def release_on_freeze(new_state: CoreState): async def release_on_freeze(new_state: CoreState):
@ -311,10 +315,7 @@ class APIBackups(CoreSysAttributes):
try: try:
event_task = self.sys_create_task(event.wait()) event_task = self.sys_create_task(event.wait())
_, pending = await asyncio.wait( _, pending = await asyncio.wait(
( (backup_task, event_task),
backup_task,
event_task,
),
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
# It seems backup returned early (error or something), make sure to cancel # It seems backup returned early (error or something), make sure to cancel
@ -497,8 +498,10 @@ class APIBackups(CoreSysAttributes):
locations: list[LOCATION_TYPE] | None = None locations: list[LOCATION_TYPE] | None = None
tmp_path = self.sys_config.path_tmp tmp_path = self.sys_config.path_tmp
if ATTR_LOCATION in request.query: if ATTR_LOCATION in request.query:
location_names: list[str] = request.query.getall(ATTR_LOCATION) location_names: list[str] = request.query.getall(ATTR_LOCATION, [])
self._validate_cloud_backup_location(request, location_names) self._validate_cloud_backup_location(
request, cast(list[str | None], location_names)
)
# Convert empty string to None if necessary # Convert empty string to None if necessary
locations = [ locations = [
self._location_to_mount(location) self._location_to_mount(location)
@ -509,7 +512,7 @@ class APIBackups(CoreSysAttributes):
location = locations.pop(0) location = locations.pop(0)
if location and location != LOCATION_CLOUD_BACKUP: if location and location != LOCATION_CLOUD_BACKUP:
tmp_path = location.local_where tmp_path = cast(Mount, location).local_where or tmp_path
filename: str | None = None filename: str | None = None
if ATTR_FILENAME in request.query: if ATTR_FILENAME in request.query:
@ -540,10 +543,15 @@ class APIBackups(CoreSysAttributes):
try: try:
reader = await request.multipart() reader = await request.multipart()
contents = await reader.next() contents = await reader.next()
if not isinstance(contents, BodyPartReader):
raise APIError("Improperly formatted upload, could not read backup")
tar_file = await self.sys_run_in_executor(open_backup_file) tar_file = await self.sys_run_in_executor(open_backup_file)
while chunk := await contents.read_chunk(size=2**16): while chunk := await contents.read_chunk(size=2**16):
await self.sys_run_in_executor(backup_file_stream.write, chunk) await self.sys_run_in_executor(
await self.sys_run_in_executor(backup_file_stream.close) cast(IOBase, backup_file_stream).write, chunk
)
await self.sys_run_in_executor(cast(IOBase, backup_file_stream).close)
backup = await asyncio.shield( backup = await asyncio.shield(
self.sys_backups.import_backup( self.sys_backups.import_backup(
@ -558,7 +566,9 @@ class APIBackups(CoreSysAttributes):
LOCATION_CLOUD_BACKUP, LOCATION_CLOUD_BACKUP,
None, None,
}: }:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't write new backup file: %s", err) _LOGGER.error("Can't write new backup file: %s", err)
return False return False

View File

@ -1,7 +1,9 @@
"""Init file for Supervisor network RESTful API.""" """Init file for Supervisor network RESTful API."""
import logging import logging
from typing import Any, cast
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from ..addons.addon import Addon from ..addons.addon import Addon
@ -16,6 +18,7 @@ from ..const import (
AddonState, AddonState,
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..discovery import Message
from ..exceptions import APIForbidden, APINotFound from ..exceptions import APIForbidden, APINotFound
from .utils import api_process, api_validate, require_home_assistant from .utils import api_process, api_validate, require_home_assistant
@ -32,16 +35,16 @@ SCHEMA_DISCOVERY = vol.Schema(
class APIDiscovery(CoreSysAttributes): class APIDiscovery(CoreSysAttributes):
"""Handle RESTful API for discovery functions.""" """Handle RESTful API for discovery functions."""
def _extract_message(self, request): def _extract_message(self, request: web.Request) -> Message:
"""Extract discovery message from URL.""" """Extract discovery message from URL."""
message = self.sys_discovery.get(request.match_info.get("uuid")) message = self.sys_discovery.get(request.match_info["uuid"])
if not message: if not message:
raise APINotFound("Discovery message not found") raise APINotFound("Discovery message not found")
return message return message
@api_process @api_process
@require_home_assistant @require_home_assistant
async def list(self, request): async def list_discovery(self, request: web.Request) -> dict[str, Any]:
"""Show registered and available services.""" """Show registered and available services."""
# Get available discovery # Get available discovery
discovery = [ discovery = [
@ -52,12 +55,16 @@ class APIDiscovery(CoreSysAttributes):
ATTR_CONFIG: message.config, ATTR_CONFIG: message.config,
} }
for message in self.sys_discovery.list_messages for message in self.sys_discovery.list_messages
if (addon := self.sys_addons.get(message.addon, local_only=True)) if (
and addon.state == AddonState.STARTED discovered := cast(
Addon, self.sys_addons.get(message.addon, local_only=True)
)
)
and discovered.state == AddonState.STARTED
] ]
# Get available services/add-ons # Get available services/add-ons
services = {} services: dict[str, list[str]] = {}
for addon in self.sys_addons.all: for addon in self.sys_addons.all:
for name in addon.discovery: for name in addon.discovery:
services.setdefault(name, []).append(addon.slug) services.setdefault(name, []).append(addon.slug)
@ -65,7 +72,7 @@ class APIDiscovery(CoreSysAttributes):
return {ATTR_DISCOVERY: discovery, ATTR_SERVICES: services} return {ATTR_DISCOVERY: discovery, ATTR_SERVICES: services}
@api_process @api_process
async def set_discovery(self, request): async def set_discovery(self, request: web.Request) -> dict[str, str]:
"""Write data into a discovery pipeline.""" """Write data into a discovery pipeline."""
body = await api_validate(SCHEMA_DISCOVERY, request) body = await api_validate(SCHEMA_DISCOVERY, request)
addon: Addon = request[REQUEST_FROM] addon: Addon = request[REQUEST_FROM]
@ -89,7 +96,7 @@ class APIDiscovery(CoreSysAttributes):
@api_process @api_process
@require_home_assistant @require_home_assistant
async def get_discovery(self, request): async def get_discovery(self, request: web.Request) -> dict[str, Any]:
"""Read data into a discovery message.""" """Read data into a discovery message."""
message = self._extract_message(request) message = self._extract_message(request)
@ -101,7 +108,7 @@ class APIDiscovery(CoreSysAttributes):
} }
@api_process @api_process
async def del_discovery(self, request): async def del_discovery(self, request: web.Request) -> None:
"""Delete data into a discovery message.""" """Delete data into a discovery message."""
message = self._extract_message(request) message = self._extract_message(request)
addon = request[REQUEST_FROM] addon = request[REQUEST_FROM]
@ -111,4 +118,3 @@ class APIDiscovery(CoreSysAttributes):
raise APIForbidden("Can't remove discovery message") raise APIForbidden("Can't remove discovery message")
await self.sys_discovery.remove(message) await self.sys_discovery.remove(message)
return True

View File

@ -68,7 +68,10 @@ def filesystem_struct(fs_block: UDisks2Block) -> dict[str, Any]:
ATTR_NAME: fs_block.id_label, ATTR_NAME: fs_block.id_label,
ATTR_SYSTEM: fs_block.hint_system, ATTR_SYSTEM: fs_block.hint_system,
ATTR_MOUNT_POINTS: [ ATTR_MOUNT_POINTS: [
str(mount_point) for mount_point in fs_block.filesystem.mount_points str(mount_point)
for mount_point in (
fs_block.filesystem.mount_points if fs_block.filesystem else []
)
], ],
} }

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
from contextlib import suppress from contextlib import suppress
import logging import logging
from typing import Any
from aiohttp import ClientConnectionResetError, web from aiohttp import ClientConnectionResetError, web
from aiohttp.hdrs import ACCEPT, RANGE from aiohttp.hdrs import ACCEPT, RANGE
@ -195,20 +196,18 @@ class APIHost(CoreSysAttributes):
) -> web.StreamResponse: ) -> web.StreamResponse:
"""Return systemd-journald logs.""" """Return systemd-journald logs."""
log_formatter = LogFormatter.PLAIN log_formatter = LogFormatter.PLAIN
params = {} params: dict[str, Any] = {}
if identifier: if identifier:
params[PARAM_SYSLOG_IDENTIFIER] = identifier params[PARAM_SYSLOG_IDENTIFIER] = identifier
elif IDENTIFIER in request.match_info: elif IDENTIFIER in request.match_info:
params[PARAM_SYSLOG_IDENTIFIER] = request.match_info.get(IDENTIFIER) params[PARAM_SYSLOG_IDENTIFIER] = request.match_info[IDENTIFIER]
else: else:
params[PARAM_SYSLOG_IDENTIFIER] = self.sys_host.logs.default_identifiers params[PARAM_SYSLOG_IDENTIFIER] = self.sys_host.logs.default_identifiers
# host logs should be always verbose, no matter what Accept header is used # host logs should be always verbose, no matter what Accept header is used
log_formatter = LogFormatter.VERBOSE log_formatter = LogFormatter.VERBOSE
if BOOTID in request.match_info: if BOOTID in request.match_info:
params[PARAM_BOOT_ID] = await self._get_boot_id( params[PARAM_BOOT_ID] = await self._get_boot_id(request.match_info[BOOTID])
request.match_info.get(BOOTID)
)
if follow: if follow:
params[PARAM_FOLLOW] = "" params[PARAM_FOLLOW] = ""
@ -241,7 +240,7 @@ class APIHost(CoreSysAttributes):
# entries=cursor[[:num_skip]:num_entries] # entries=cursor[[:num_skip]:num_entries]
range_header = f"entries=:-{lines - 1}:{'' if follow else lines}" range_header = f"entries=:-{lines - 1}:{'' if follow else lines}"
elif RANGE in request.headers: elif RANGE in request.headers:
range_header = request.headers.get(RANGE) range_header = request.headers[RANGE]
else: else:
range_header = ( range_header = (
f"entries=:-{DEFAULT_LINES - 1}:{'' if follow else DEFAULT_LINES}" f"entries=:-{DEFAULT_LINES - 1}:{'' if follow else DEFAULT_LINES}"

View File

@ -83,7 +83,7 @@ class APIIngress(CoreSysAttributes):
def _extract_addon(self, request: web.Request) -> Addon: def _extract_addon(self, request: web.Request) -> Addon:
"""Return addon, throw an exception it it doesn't exist.""" """Return addon, throw an exception it it doesn't exist."""
token = request.match_info.get("token") token = request.match_info["token"]
# Find correct add-on # Find correct add-on
addon = self.sys_ingress.get(token) addon = self.sys_ingress.get(token)
@ -132,7 +132,7 @@ class APIIngress(CoreSysAttributes):
@api_process @api_process
@require_home_assistant @require_home_assistant
async def validate_session(self, request: web.Request) -> dict[str, Any]: async def validate_session(self, request: web.Request) -> None:
"""Validate session and extending how long it's valid for.""" """Validate session and extending how long it's valid for."""
data = await api_validate(VALIDATE_SESSION_DATA, request) data = await api_validate(VALIDATE_SESSION_DATA, request)
@ -147,14 +147,14 @@ class APIIngress(CoreSysAttributes):
"""Route data to Supervisor ingress service.""" """Route data to Supervisor ingress service."""
# Check Ingress Session # Check Ingress Session
session = request.cookies.get(COOKIE_INGRESS) session = request.cookies.get(COOKIE_INGRESS, "")
if not self.sys_ingress.validate_session(session): if not self.sys_ingress.validate_session(session):
_LOGGER.warning("No valid ingress session %s", session) _LOGGER.warning("No valid ingress session %s", session)
raise HTTPUnauthorized() raise HTTPUnauthorized()
# Process requests # Process requests
addon = self._extract_addon(request) addon = self._extract_addon(request)
path = request.match_info.get("path") path = request.match_info["path"]
session_data = self.sys_ingress.get_session_data(session) session_data = self.sys_ingress.get_session_data(session)
try: try:
# Websocket # Websocket
@ -183,7 +183,7 @@ class APIIngress(CoreSysAttributes):
for proto in request.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") for proto in request.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
] ]
else: else:
req_protocols = () req_protocols = []
ws_server = web.WebSocketResponse( ws_server = web.WebSocketResponse(
protocols=req_protocols, autoclose=False, autoping=False protocols=req_protocols, autoclose=False, autoping=False
@ -340,9 +340,10 @@ def _init_header(
headers[name] = value headers[name] = value
# Update X-Forwarded-For # Update X-Forwarded-For
forward_for = request.headers.get(hdrs.X_FORWARDED_FOR) if request.transport:
connected_ip = ip_address(request.transport.get_extra_info("peername")[0]) forward_for = request.headers.get(hdrs.X_FORWARDED_FOR)
headers[hdrs.X_FORWARDED_FOR] = f"{forward_for}, {connected_ip!s}" connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
headers[hdrs.X_FORWARDED_FOR] = f"{forward_for}, {connected_ip!s}"
return headers return headers

View File

@ -26,7 +26,7 @@ class APIJobs(CoreSysAttributes):
def _extract_job(self, request: web.Request) -> SupervisorJob: def _extract_job(self, request: web.Request) -> SupervisorJob:
"""Extract job from request or raise.""" """Extract job from request or raise."""
try: try:
return self.sys_jobs.get_job(request.match_info.get("uuid")) return self.sys_jobs.get_job(request.match_info["uuid"])
except JobNotFound: except JobNotFound:
raise APINotFound("Job does not exist") from None raise APINotFound("Job does not exist") from None
@ -71,7 +71,10 @@ class APIJobs(CoreSysAttributes):
if current_job.uuid in jobs_by_parent: if current_job.uuid in jobs_by_parent:
queue.extend( queue.extend(
[(child_jobs, job) for job in jobs_by_parent.get(current_job.uuid)] [
(child_jobs, job)
for job in jobs_by_parent.get(current_job.uuid, [])
]
) )
return job_list return job_list

View File

@ -1,11 +1,12 @@
"""Handle security part of this API.""" """Handle security part of this API."""
from collections.abc import Callable
import logging import logging
import re import re
from typing import Final from typing import Final
from urllib.parse import unquote from urllib.parse import unquote
from aiohttp.web import Request, RequestHandler, Response, middleware from aiohttp.web import Request, Response, middleware
from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPUnauthorized
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
@ -23,7 +24,7 @@ from ...const import (
) )
from ...coresys import CoreSys, CoreSysAttributes from ...coresys import CoreSys, CoreSysAttributes
from ...utils import version_is_new_enough from ...utils import version_is_new_enough
from ..utils import api_return_error, excract_supervisor_token from ..utils import api_return_error, extract_supervisor_token
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
_CORE_VERSION: Final = AwesomeVersion("2023.3.4") _CORE_VERSION: Final = AwesomeVersion("2023.3.4")
@ -179,9 +180,7 @@ class SecurityMiddleware(CoreSysAttributes):
return unquoted return unquoted
@middleware @middleware
async def block_bad_requests( async def block_bad_requests(self, request: Request, handler: Callable) -> Response:
self, request: Request, handler: RequestHandler
) -> Response:
"""Process request and tblock commonly known exploit attempts.""" """Process request and tblock commonly known exploit attempts."""
if FILTERS.search(self._recursive_unquote(request.path)): if FILTERS.search(self._recursive_unquote(request.path)):
_LOGGER.warning( _LOGGER.warning(
@ -199,9 +198,7 @@ class SecurityMiddleware(CoreSysAttributes):
return await handler(request) return await handler(request)
@middleware @middleware
async def system_validation( async def system_validation(self, request: Request, handler: Callable) -> Response:
self, request: Request, handler: RequestHandler
) -> Response:
"""Check if core is ready to response.""" """Check if core is ready to response."""
if self.sys_core.state not in ( if self.sys_core.state not in (
CoreState.STARTUP, CoreState.STARTUP,
@ -215,12 +212,10 @@ class SecurityMiddleware(CoreSysAttributes):
return await handler(request) return await handler(request)
@middleware @middleware
async def token_validation( async def token_validation(self, request: Request, handler: Callable) -> Response:
self, request: Request, handler: RequestHandler
) -> Response:
"""Check security access of this layer.""" """Check security access of this layer."""
request_from = None request_from: CoreSysAttributes | None = None
supervisor_token = excract_supervisor_token(request) supervisor_token = extract_supervisor_token(request)
# Blacklist # Blacklist
if BLACKLIST.match(request.path): if BLACKLIST.match(request.path):
@ -288,7 +283,7 @@ class SecurityMiddleware(CoreSysAttributes):
raise HTTPForbidden() raise HTTPForbidden()
@middleware @middleware
async def core_proxy(self, request: Request, handler: RequestHandler) -> Response: async def core_proxy(self, request: Request, handler: Callable) -> Response:
"""Validate user from Core API proxy.""" """Validate user from Core API proxy."""
if ( if (
request[REQUEST_FROM] != self.sys_homeassistant request[REQUEST_FROM] != self.sys_homeassistant

View File

@ -1,6 +1,6 @@
"""Inits file for supervisor mounts REST API.""" """Inits file for supervisor mounts REST API."""
from typing import Any from typing import Any, cast
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
@ -10,7 +10,7 @@ from ..coresys import CoreSysAttributes
from ..exceptions import APIError, APINotFound from ..exceptions import APIError, APINotFound
from ..mounts.const import ATTR_DEFAULT_BACKUP_MOUNT, MountUsage from ..mounts.const import ATTR_DEFAULT_BACKUP_MOUNT, MountUsage
from ..mounts.mount import Mount from ..mounts.mount import Mount
from ..mounts.validate import SCHEMA_MOUNT_CONFIG from ..mounts.validate import SCHEMA_MOUNT_CONFIG, MountData
from .const import ATTR_MOUNTS, ATTR_USER_PATH from .const import ATTR_MOUNTS, ATTR_USER_PATH
from .utils import api_process, api_validate from .utils import api_process, api_validate
@ -26,7 +26,7 @@ class APIMounts(CoreSysAttributes):
def _extract_mount(self, request: web.Request) -> Mount: def _extract_mount(self, request: web.Request) -> Mount:
"""Extract mount from request or raise.""" """Extract mount from request or raise."""
name = request.match_info.get("mount") name = request.match_info["mount"]
if name not in self.sys_mounts: if name not in self.sys_mounts:
raise APINotFound(f"No mount exists with name {name}") raise APINotFound(f"No mount exists with name {name}")
return self.sys_mounts.get(name) return self.sys_mounts.get(name)
@ -71,10 +71,10 @@ class APIMounts(CoreSysAttributes):
@api_process @api_process
async def create_mount(self, request: web.Request) -> None: async def create_mount(self, request: web.Request) -> None:
"""Create a new mount in supervisor.""" """Create a new mount in supervisor."""
body = await api_validate(SCHEMA_MOUNT_CONFIG, request) body = cast(MountData, await api_validate(SCHEMA_MOUNT_CONFIG, request))
if body[ATTR_NAME] in self.sys_mounts: if body["name"] in self.sys_mounts:
raise APIError(f"A mount already exists with name {body[ATTR_NAME]}") raise APIError(f"A mount already exists with name {body['name']}")
mount = Mount.from_dict(self.coresys, body) mount = Mount.from_dict(self.coresys, body)
await self.sys_mounts.create_mount(mount) await self.sys_mounts.create_mount(mount)
@ -97,7 +97,10 @@ class APIMounts(CoreSysAttributes):
{vol.Optional(ATTR_NAME, default=current.name): current.name}, {vol.Optional(ATTR_NAME, default=current.name): current.name},
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
body = await api_validate(vol.All(name_schema, SCHEMA_MOUNT_CONFIG), request) body = cast(
MountData,
await api_validate(vol.All(name_schema, SCHEMA_MOUNT_CONFIG), request),
)
mount = Mount.from_dict(self.coresys, body) mount = Mount.from_dict(self.coresys, body)
await self.sys_mounts.create_mount(mount) await self.sys_mounts.create_mount(mount)

View File

@ -132,8 +132,12 @@ def interface_struct(interface: Interface) -> dict[str, Any]:
ATTR_CONNECTED: interface.connected, ATTR_CONNECTED: interface.connected,
ATTR_PRIMARY: interface.primary, ATTR_PRIMARY: interface.primary,
ATTR_MAC: interface.mac, ATTR_MAC: interface.mac,
ATTR_IPV4: ipconfig_struct(interface.ipv4, interface.ipv4setting), ATTR_IPV4: ipconfig_struct(interface.ipv4, interface.ipv4setting)
ATTR_IPV6: ipconfig_struct(interface.ipv6, interface.ipv6setting), if interface.ipv4 and interface.ipv4setting
else None,
ATTR_IPV6: ipconfig_struct(interface.ipv6, interface.ipv6setting)
if interface.ipv6 and interface.ipv6setting
else None,
ATTR_WIFI: wifi_struct(interface.wifi) if interface.wifi else None, ATTR_WIFI: wifi_struct(interface.wifi) if interface.wifi else None,
ATTR_VLAN: vlan_struct(interface.vlan) if interface.vlan else None, ATTR_VLAN: vlan_struct(interface.vlan) if interface.vlan else None,
} }
@ -190,14 +194,14 @@ class APINetwork(CoreSysAttributes):
@api_process @api_process
async def interface_info(self, request: web.Request) -> dict[str, Any]: async def interface_info(self, request: web.Request) -> dict[str, Any]:
"""Return network information for a interface.""" """Return network information for a interface."""
interface = self._get_interface(request.match_info.get(ATTR_INTERFACE)) interface = self._get_interface(request.match_info[ATTR_INTERFACE])
return interface_struct(interface) return interface_struct(interface)
@api_process @api_process
async def interface_update(self, request: web.Request) -> None: async def interface_update(self, request: web.Request) -> None:
"""Update the configuration of an interface.""" """Update the configuration of an interface."""
interface = self._get_interface(request.match_info.get(ATTR_INTERFACE)) interface = self._get_interface(request.match_info[ATTR_INTERFACE])
# Validate data # Validate data
body = await api_validate(SCHEMA_UPDATE, request) body = await api_validate(SCHEMA_UPDATE, request)
@ -243,7 +247,7 @@ class APINetwork(CoreSysAttributes):
@api_process @api_process
async def scan_accesspoints(self, request: web.Request) -> dict[str, Any]: async def scan_accesspoints(self, request: web.Request) -> dict[str, Any]:
"""Scan and return a list of available networks.""" """Scan and return a list of available networks."""
interface = self._get_interface(request.match_info.get(ATTR_INTERFACE)) interface = self._get_interface(request.match_info[ATTR_INTERFACE])
# Only wlan is supported # Only wlan is supported
if interface.type != InterfaceType.WIRELESS: if interface.type != InterfaceType.WIRELESS:
@ -256,8 +260,10 @@ class APINetwork(CoreSysAttributes):
@api_process @api_process
async def create_vlan(self, request: web.Request) -> None: async def create_vlan(self, request: web.Request) -> None:
"""Create a new vlan.""" """Create a new vlan."""
interface = self._get_interface(request.match_info.get(ATTR_INTERFACE)) interface = self._get_interface(request.match_info[ATTR_INTERFACE])
vlan = int(request.match_info.get(ATTR_VLAN)) vlan = int(request.match_info.get(ATTR_VLAN, -1))
if vlan < 0:
raise APIError(f"Invalid vlan specified: {vlan}")
# Only ethernet is supported # Only ethernet is supported
if interface.type != InterfaceType.ETHERNET: if interface.type != InterfaceType.ETHERNET:

View File

@ -1,6 +1,7 @@
"""Utils for Home Assistant Proxy.""" """Utils for Home Assistant Proxy."""
import asyncio import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import logging import logging
@ -40,7 +41,7 @@ class APIProxy(CoreSysAttributes):
bearer = request.headers[AUTHORIZATION] bearer = request.headers[AUTHORIZATION]
supervisor_token = bearer.split(" ")[-1] supervisor_token = bearer.split(" ")[-1]
else: else:
supervisor_token = request.headers.get(HEADER_HA_ACCESS) supervisor_token = request.headers.get(HEADER_HA_ACCESS, "")
addon = self.sys_addons.from_token(supervisor_token) addon = self.sys_addons.from_token(supervisor_token)
if not addon: if not addon:
@ -54,7 +55,9 @@ class APIProxy(CoreSysAttributes):
raise HTTPUnauthorized() raise HTTPUnauthorized()
@asynccontextmanager @asynccontextmanager
async def _api_client(self, request: web.Request, path: str, timeout: int = 300): async def _api_client(
self, request: web.Request, path: str, timeout: int | None = 300
) -> AsyncIterator[aiohttp.ClientResponse]:
"""Return a client request with proxy origin for Home Assistant.""" """Return a client request with proxy origin for Home Assistant."""
try: try:
async with self.sys_homeassistant.api.make_request( async with self.sys_homeassistant.api.make_request(
@ -93,7 +96,7 @@ class APIProxy(CoreSysAttributes):
_LOGGER.info("Home Assistant EventStream start") _LOGGER.info("Home Assistant EventStream start")
async with self._api_client(request, "stream", timeout=None) as client: async with self._api_client(request, "stream", timeout=None) as client:
response = web.StreamResponse() response = web.StreamResponse()
response.content_type = request.headers.get(CONTENT_TYPE) response.content_type = request.headers.get(CONTENT_TYPE, "")
try: try:
response.headers["X-Accel-Buffering"] = "no" response.headers["X-Accel-Buffering"] = "no"
await response.prepare(request) await response.prepare(request)
@ -113,7 +116,7 @@ class APIProxy(CoreSysAttributes):
raise HTTPBadGateway() raise HTTPBadGateway()
# Normal request # Normal request
path = request.match_info.get("path", "") path = request.match_info["path"]
async with self._api_client(request, path) as client: async with self._api_client(request, path) as client:
data = await client.read() data = await client.read()
return web.Response( return web.Response(
@ -180,21 +183,19 @@ class APIProxy(CoreSysAttributes):
target: web.WebSocketResponse | ClientWebSocketResponse, target: web.WebSocketResponse | ClientWebSocketResponse,
) -> None: ) -> None:
"""Proxy a message from client to server or vice versa.""" """Proxy a message from client to server or vice versa."""
if read_task.exception():
raise read_task.exception()
msg: WSMessage = read_task.result() msg: WSMessage = read_task.result()
if msg.type == WSMsgType.TEXT: match msg.type:
return await target.send_str(msg.data) case WSMsgType.TEXT:
if msg.type == WSMsgType.BINARY: await target.send_str(msg.data)
return await target.send_bytes(msg.data) case WSMsgType.BINARY:
if msg.type == WSMsgType.CLOSE: await target.send_bytes(msg.data)
_LOGGER.debug("Received close message from WebSocket.") case WSMsgType.CLOSE:
return await target.close() _LOGGER.debug("Received close message from WebSocket.")
await target.close()
raise TypeError( case _:
f"Cannot proxy websocket message of unsupported type: {msg.type}" raise TypeError(
) f"Cannot proxy websocket message of unsupported type: {msg.type}"
)
async def websocket(self, request: web.Request): async def websocket(self, request: web.Request):
"""Initialize a WebSocket API connection.""" """Initialize a WebSocket API connection."""

View File

@ -33,23 +33,21 @@ class APIResoulution(CoreSysAttributes):
def _extract_issue(self, request: web.Request) -> Issue: def _extract_issue(self, request: web.Request) -> Issue:
"""Extract issue from request or raise.""" """Extract issue from request or raise."""
try: try:
return self.sys_resolution.get_issue(request.match_info.get("issue")) return self.sys_resolution.get_issue(request.match_info["issue"])
except ResolutionNotFound: except ResolutionNotFound:
raise APINotFound("The supplied UUID is not a valid issue") from None raise APINotFound("The supplied UUID is not a valid issue") from None
def _extract_suggestion(self, request: web.Request) -> Suggestion: def _extract_suggestion(self, request: web.Request) -> Suggestion:
"""Extract suggestion from request or raise.""" """Extract suggestion from request or raise."""
try: try:
return self.sys_resolution.get_suggestion( return self.sys_resolution.get_suggestion(request.match_info["suggestion"])
request.match_info.get("suggestion")
)
except ResolutionNotFound: except ResolutionNotFound:
raise APINotFound("The supplied UUID is not a valid suggestion") from None raise APINotFound("The supplied UUID is not a valid suggestion") from None
def _extract_check(self, request: web.Request) -> CheckBase: def _extract_check(self, request: web.Request) -> CheckBase:
"""Extract check from request or raise.""" """Extract check from request or raise."""
try: try:
return self.sys_resolution.check.get(request.match_info.get("check")) return self.sys_resolution.check.get(request.match_info["check"])
except ResolutionNotFound: except ResolutionNotFound:
raise APINotFound("The supplied check slug is not available") from None raise APINotFound("The supplied check slug is not available") from None

View File

@ -25,7 +25,7 @@ class APIServices(CoreSysAttributes):
return service return service
@api_process @api_process
async def list(self, request): async def list_services(self, request):
"""Show register services.""" """Show register services."""
services = [] services = []
for service in self.sys_services.list_services: for service in self.sys_services.list_services:

View File

@ -3,11 +3,12 @@
import asyncio import asyncio
from collections.abc import Awaitable from collections.abc import Awaitable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, cast
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
from ..addons.addon import Addon
from ..addons.manager import AnyAddon from ..addons.manager import AnyAddon
from ..addons.utils import rating_security from ..addons.utils import rating_security
from ..api.const import ATTR_SIGNED from ..api.const import ATTR_SIGNED
@ -92,7 +93,7 @@ class APIStore(CoreSysAttributes):
def _extract_addon(self, request: web.Request, installed=False) -> AnyAddon: def _extract_addon(self, request: web.Request, installed=False) -> AnyAddon:
"""Return add-on, throw an exception it it doesn't exist.""" """Return add-on, throw an exception it it doesn't exist."""
addon_slug: str = request.match_info.get("addon") addon_slug: str = request.match_info["addon"]
if not (addon := self.sys_addons.get(addon_slug)): if not (addon := self.sys_addons.get(addon_slug)):
raise APINotFound(f"Addon {addon_slug} does not exist") raise APINotFound(f"Addon {addon_slug} does not exist")
@ -101,6 +102,7 @@ class APIStore(CoreSysAttributes):
raise APIError(f"Addon {addon_slug} is not installed") raise APIError(f"Addon {addon_slug} is not installed")
if not installed and addon.is_installed: if not installed and addon.is_installed:
addon = cast(Addon, addon)
if not addon.addon_store: if not addon.addon_store:
raise APINotFound(f"Addon {addon_slug} does not exist in the store") raise APINotFound(f"Addon {addon_slug} does not exist in the store")
return addon.addon_store return addon.addon_store
@ -109,7 +111,7 @@ class APIStore(CoreSysAttributes):
def _extract_repository(self, request: web.Request) -> Repository: def _extract_repository(self, request: web.Request) -> Repository:
"""Return repository, throw an exception it it doesn't exist.""" """Return repository, throw an exception it it doesn't exist."""
repository_slug: str = request.match_info.get("repository") repository_slug: str = request.match_info["repository"]
if repository_slug not in self.sys_store.repositories: if repository_slug not in self.sys_store.repositories:
raise APINotFound( raise APINotFound(
@ -124,7 +126,7 @@ class APIStore(CoreSysAttributes):
"""Generate addon information.""" """Generate addon information."""
installed = ( installed = (
self.sys_addons.get(addon.slug, local_only=True) cast(Addon, self.sys_addons.get(addon.slug, local_only=True))
if addon.is_installed if addon.is_installed
else None else None
) )
@ -144,12 +146,10 @@ class APIStore(CoreSysAttributes):
ATTR_REPOSITORY: addon.repository, ATTR_REPOSITORY: addon.repository,
ATTR_SLUG: addon.slug, ATTR_SLUG: addon.slug,
ATTR_STAGE: addon.stage, ATTR_STAGE: addon.stage,
ATTR_UPDATE_AVAILABLE: installed.need_update ATTR_UPDATE_AVAILABLE: installed.need_update if installed else False,
if addon.is_installed
else False,
ATTR_URL: addon.url, ATTR_URL: addon.url,
ATTR_VERSION_LATEST: addon.latest_version, ATTR_VERSION_LATEST: addon.latest_version,
ATTR_VERSION: installed.version if addon.is_installed else None, ATTR_VERSION: installed.version if installed else None,
} }
if extended: if extended:
data.update( data.update(
@ -246,7 +246,7 @@ class APIStore(CoreSysAttributes):
# Used by legacy routing for addons/{addon}/info, can be refactored out when that is removed (1/2023) # Used by legacy routing for addons/{addon}/info, can be refactored out when that is removed (1/2023)
async def addons_addon_info_wrapped(self, request: web.Request) -> dict[str, Any]: async def addons_addon_info_wrapped(self, request: web.Request) -> dict[str, Any]:
"""Return add-on information directly (not api).""" """Return add-on information directly (not api)."""
addon: AddonStore = self._extract_addon(request) addon = cast(AddonStore, self._extract_addon(request))
return await self._generate_addon_information(addon, True) return await self._generate_addon_information(addon, True)
@api_process_raw(CONTENT_TYPE_PNG) @api_process_raw(CONTENT_TYPE_PNG)

View File

@ -230,19 +230,12 @@ class APISupervisor(CoreSysAttributes):
await asyncio.shield(self.sys_supervisor.update(version)) await asyncio.shield(self.sys_supervisor.update(version))
@api_process @api_process
def reload(self, request: web.Request) -> Awaitable[None]: def reload(self, request: web.Request) -> Awaitable:
"""Reload add-ons, configuration, etc.""" """Reload add-ons, configuration, etc."""
return asyncio.shield( return asyncio.gather(
asyncio.wait( asyncio.shield(self.sys_updater.reload()),
[ asyncio.shield(self.sys_homeassistant.secrets.reload()),
self.sys_create_task(coro) asyncio.shield(self.sys_resolution.evaluate.evaluate_system()),
for coro in [
self.sys_updater.reload(),
self.sys_homeassistant.secrets.reload(),
self.sys_resolution.evaluate.evaluate_system(),
]
]
)
) )
@api_process @api_process

View File

@ -21,7 +21,7 @@ from ..const import (
RESULT_ERROR, RESULT_ERROR,
RESULT_OK, RESULT_OK,
) )
from ..coresys import CoreSys from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import APIError, BackupFileNotFoundError, DockerAPIError, HassioError from ..exceptions import APIError, BackupFileNotFoundError, DockerAPIError, HassioError
from ..utils import check_exception_chain, get_message_from_exception_chain from ..utils import check_exception_chain, get_message_from_exception_chain
from ..utils.json import json_dumps, json_loads as json_loads_util from ..utils.json import json_dumps, json_loads as json_loads_util
@ -29,7 +29,7 @@ from ..utils.log_format import format_message
from . import const from . import const
def excract_supervisor_token(request: web.Request) -> str | None: def extract_supervisor_token(request: web.Request) -> str | None:
"""Extract Supervisor token from request.""" """Extract Supervisor token from request."""
if supervisor_token := request.headers.get(HEADER_TOKEN): if supervisor_token := request.headers.get(HEADER_TOKEN):
return supervisor_token return supervisor_token
@ -58,7 +58,9 @@ def json_loads(data: Any) -> dict[str, Any]:
def api_process(method): def api_process(method):
"""Wrap function with true/false calls to rest api.""" """Wrap function with true/false calls to rest api."""
async def wrap_api(api, *args, **kwargs): async def wrap_api(
api: CoreSysAttributes, *args, **kwargs
) -> web.Response | web.StreamResponse:
"""Return API information.""" """Return API information."""
try: try:
answer = await method(api, *args, **kwargs) answer = await method(api, *args, **kwargs)
@ -85,7 +87,7 @@ def api_process(method):
def require_home_assistant(method): def require_home_assistant(method):
"""Ensure that the request comes from Home Assistant.""" """Ensure that the request comes from Home Assistant."""
async def wrap_api(api, *args, **kwargs): async def wrap_api(api: CoreSysAttributes, *args, **kwargs) -> Any:
"""Return API information.""" """Return API information."""
coresys: CoreSys = api.coresys coresys: CoreSys = api.coresys
request: Request = args[0] request: Request = args[0]
@ -102,7 +104,9 @@ def api_process_raw(content, *, error_type=None):
def wrap_method(method): def wrap_method(method):
"""Wrap function with raw output to rest api.""" """Wrap function with raw output to rest api."""
async def wrap_api(api, *args, **kwargs): async def wrap_api(
api: CoreSysAttributes, *args, **kwargs
) -> web.Response | web.StreamResponse:
"""Return api information.""" """Return api information."""
try: try:
msg_data = await method(api, *args, **kwargs) msg_data = await method(api, *args, **kwargs)
@ -165,7 +169,7 @@ def api_return_error(
) )
def api_return_ok(data: dict[str, Any] | None = None) -> web.Response: def api_return_ok(data: dict[str, Any] | list[Any] | None = None) -> web.Response:
"""Return an API ok answer.""" """Return an API ok answer."""
return web.json_response( return web.json_response(
{JSON_RESULT: RESULT_OK, JSON_DATA: data or {}}, {JSON_RESULT: RESULT_OK, JSON_DATA: data or {}},
@ -174,7 +178,7 @@ def api_return_ok(data: dict[str, Any] | None = None) -> web.Response:
async def api_validate( async def api_validate(
schema: vol.Schema, schema: vol.Schema | vol.All,
request: web.Request, request: web.Request,
origin: list[str] | None = None, origin: list[str] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:

View File

@ -68,7 +68,9 @@ class Auth(FileConfiguration, CoreSysAttributes):
self._data.pop(username_h, None) self._data.pop(username_h, None)
await self.save_data() await self.save_data()
async def check_login(self, addon: Addon, username: str, password: str) -> bool: async def check_login(
self, addon: Addon, username: str | None, password: str | None
) -> bool:
"""Check username login.""" """Check username login."""
if password is None: if password is None:
raise AuthError("None as password is not supported!", _LOGGER.error) raise AuthError("None as password is not supported!", _LOGGER.error)

View File

@ -196,7 +196,9 @@ class BackupManager(FileConfiguration, JobGroup):
self.sys_config.path_backup, self.sys_config.path_backup,
self.sys_config.path_core_backup, self.sys_config.path_core_backup,
}: }:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Could not list backups from %s: %s", path.as_posix(), err) _LOGGER.error("Could not list backups from %s: %s", path.as_posix(), err)
return [] return []
@ -350,7 +352,9 @@ class BackupManager(FileConfiguration, JobGroup):
None, None,
LOCATION_CLOUD_BACKUP, LOCATION_CLOUD_BACKUP,
}: }:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise BackupError(msg, _LOGGER.error) from err raise BackupError(msg, _LOGGER.error) from err
# If backup has been removed from all locations, remove it from cache # If backup has been removed from all locations, remove it from cache
@ -404,7 +408,9 @@ class BackupManager(FileConfiguration, JobGroup):
copy_to_additional_locations copy_to_additional_locations
) )
except BackupDataDiskBadMessageError: except BackupDataDiskBadMessageError:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise raise
backup.all_locations.update( backup.all_locations.update(
@ -445,7 +451,9 @@ class BackupManager(FileConfiguration, JobGroup):
await self.sys_run_in_executor(backup.tarfile.rename, tar_file) await self.sys_run_in_executor(backup.tarfile.rename, tar_file)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}: if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't move backup file to storage: %s", err) _LOGGER.error("Can't move backup file to storage: %s", err)
return None return None

View File

@ -114,7 +114,7 @@ class Core(CoreSysAttributes):
self.sys_resolution.create_issue( self.sys_resolution.create_issue(
IssueType.UPDATE_ROLLBACK, ContextType.SUPERVISOR IssueType.UPDATE_ROLLBACK, ContextType.SUPERVISOR
) )
self.sys_resolution.unhealthy = UnhealthyReason.SUPERVISOR self.sys_resolution.add_unhealthy_reason(UnhealthyReason.SUPERVISOR)
# Fix wrong version in config / avoid boot loop on OS # Fix wrong version in config / avoid boot loop on OS
self.sys_config.version = self.sys_supervisor.version self.sys_config.version = self.sys_supervisor.version
@ -177,7 +177,7 @@ class Core(CoreSysAttributes):
_LOGGER.critical( _LOGGER.critical(
"Fatal error happening on load Task %s: %s", setup_task, err "Fatal error happening on load Task %s: %s", setup_task, err
) )
self.sys_resolution.unhealthy = UnhealthyReason.SETUP self.sys_resolution.add_unhealthy_reason(UnhealthyReason.SETUP)
await async_capture_exception(err) await async_capture_exception(err)
# Set OS Agent diagnostics if needed # Set OS Agent diagnostics if needed

View File

@ -807,7 +807,7 @@ class CoreSysAttributes:
return self.coresys.now() return self.coresys.now()
def sys_run_in_executor( def sys_run_in_executor(
self, funct: Callable[..., T], *args: tuple[Any], **kwargs: dict[str, Any] self, funct: Callable[..., T], *args, **kwargs
) -> Coroutine[Any, Any, T]: ) -> Coroutine[Any, Any, T]:
"""Add a job to the executor pool.""" """Add a job to the executor pool."""
return self.coresys.run_in_executor(funct, *args, **kwargs) return self.coresys.run_in_executor(funct, *args, **kwargs)
@ -820,8 +820,8 @@ class CoreSysAttributes:
self, self,
delay: float, delay: float,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Coroutine[Any, Any, T]],
*args: tuple[Any], *args,
**kwargs: dict[str, Any], **kwargs,
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:
"""Start a task after a delay.""" """Start a task after a delay."""
return self.coresys.call_later(delay, funct, *args, **kwargs) return self.coresys.call_later(delay, funct, *args, **kwargs)
@ -830,8 +830,8 @@ class CoreSysAttributes:
self, self,
when: datetime, when: datetime,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Coroutine[Any, Any, T]],
*args: tuple[Any], *args,
**kwargs: dict[str, Any], **kwargs,
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:
"""Start a task at the specified datetime.""" """Start a task at the specified datetime."""
return self.coresys.call_at(when, funct, *args, **kwargs) return self.coresys.call_at(when, funct, *args, **kwargs)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from contextlib import suppress from contextlib import suppress
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4 from uuid import uuid4
import attr import attr
@ -31,7 +31,7 @@ class Message:
addon: str = attr.ib() addon: str = attr.ib()
service: str = attr.ib() service: str = attr.ib()
config: dict[str, Any] = attr.ib(eq=False) config: dict[str, Any] = attr.ib(eq=False)
uuid: UUID = attr.ib(factory=lambda: uuid4().hex, eq=False) uuid: str = attr.ib(factory=lambda: uuid4().hex, eq=False)
class Discovery(CoreSysAttributes, FileConfiguration): class Discovery(CoreSysAttributes, FileConfiguration):

View File

@ -53,7 +53,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
IMAGE_WITH_HOST = re.compile(r"^((?:[a-z0-9]+(?:-[a-z0-9]+)*\.)+[a-z]{2,})\/.+") IMAGE_WITH_HOST = re.compile(r"^((?:[a-z0-9]+(?:-[a-z0-9]+)*\.)+[a-z]{2,})\/.+")
DOCKER_HUB = "hub.docker.com" DOCKER_HUB = "hub.docker.com"
MAP_ARCH = { MAP_ARCH: dict[CpuArch | str, str] = {
CpuArch.ARMV7: "linux/arm/v7", CpuArch.ARMV7: "linux/arm/v7",
CpuArch.ARMHF: "linux/arm/v6", CpuArch.ARMHF: "linux/arm/v6",
CpuArch.AARCH64: "linux/arm64", CpuArch.AARCH64: "linux/arm64",

View File

@ -40,7 +40,7 @@ class HwMonitor(CoreSysAttributes):
), ),
) )
except OSError: except OSError:
self.sys_resolution.unhealthy = UnhealthyReason.PRIVILEGED self.sys_resolution.add_unhealthy_reason(UnhealthyReason.PRIVILEGED)
_LOGGER.critical("Not privileged to run udev monitor!") _LOGGER.critical("Not privileged to run udev monitor!")
else: else:
self.observer.start() self.observer.start()

View File

@ -1,7 +1,8 @@
"""Home Assistant control object.""" """Home Assistant control object."""
import asyncio import asyncio
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress from collections.abc import AsyncIterator
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
import logging import logging
@ -10,6 +11,7 @@ from typing import Any
import aiohttp import aiohttp
from aiohttp import hdrs from aiohttp import hdrs
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
from multidict import MultiMapping
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import HomeAssistantAPIError, HomeAssistantAuthError from ..exceptions import HomeAssistantAPIError, HomeAssistantAuthError
@ -84,10 +86,10 @@ class HomeAssistantAPI(CoreSysAttributes):
json: dict[str, Any] | None = None, json: dict[str, Any] | None = None,
content_type: str | None = None, content_type: str | None = None,
data: Any = None, data: Any = None,
timeout: int = 30, timeout: int | None = 30,
params: dict[str, str] | None = None, params: MultiMapping[str] | None = None,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
) -> AbstractAsyncContextManager[aiohttp.ClientResponse]: ) -> AsyncIterator[aiohttp.ClientResponse]:
"""Async context manager to make a request with right auth.""" """Async context manager to make a request with right auth."""
url = f"{self.sys_homeassistant.api_url}/{path}" url = f"{self.sys_homeassistant.api_url}/{path}"
headers = headers or {} headers = headers or {}

View File

@ -329,7 +329,9 @@ class HomeAssistant(FileConfiguration, CoreSysAttributes):
await self.sys_run_in_executor(write_pulse_config) await self.sys_run_in_executor(write_pulse_config)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Home Assistant can't write pulse/client.config: %s", err) _LOGGER.error("Home Assistant can't write pulse/client.config: %s", err)
else: else:
_LOGGER.info("Update pulse/client.config: %s", self.path_pulse) _LOGGER.info("Update pulse/client.config: %s", self.path_pulse)

View File

@ -90,7 +90,9 @@ class AppArmorControl(CoreSysAttributes):
await self.sys_run_in_executor(shutil.copyfile, profile_file, dest_profile) await self.sys_run_in_executor(shutil.copyfile, profile_file, dest_profile)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise HostAppArmorError( raise HostAppArmorError(
f"Can't copy {profile_file}: {err}", _LOGGER.error f"Can't copy {profile_file}: {err}", _LOGGER.error
) from err ) from err
@ -115,7 +117,9 @@ class AppArmorControl(CoreSysAttributes):
await self.sys_run_in_executor(profile_file.unlink) await self.sys_run_in_executor(profile_file.unlink)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise HostAppArmorError( raise HostAppArmorError(
f"Can't remove profile: {err}", _LOGGER.error f"Can't remove profile: {err}", _LOGGER.error
) from err ) from err
@ -131,7 +135,9 @@ class AppArmorControl(CoreSysAttributes):
shutil.copy(profile_file, backup_file) shutil.copy(profile_file, backup_file)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise HostAppArmorError( raise HostAppArmorError(
f"Can't backup profile {profile_name}: {err}", _LOGGER.error f"Can't backup profile {profile_name}: {err}", _LOGGER.error
) from err ) from err

View File

@ -35,7 +35,7 @@ class Job(CoreSysAttributes):
name: str, name: str,
conditions: list[JobCondition] | None = None, conditions: list[JobCondition] | None = None,
cleanup: bool = True, cleanup: bool = True,
on_condition: JobException | None = None, on_condition: type[JobException] | None = None,
limit: JobExecutionLimit | None = None, limit: JobExecutionLimit | None = None,
throttle_period: timedelta throttle_period: timedelta
| Callable[[CoreSys, datetime, list[datetime] | None], timedelta] | Callable[[CoreSys, datetime, list[datetime] | None], timedelta]

View File

@ -237,7 +237,9 @@ class OSManager(CoreSysAttributes):
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise HassOSUpdateError( raise HassOSUpdateError(
f"Can't write OTA file: {err!s}", _LOGGER.error f"Can't write OTA file: {err!s}", _LOGGER.error
) from err ) from err

View File

@ -94,7 +94,9 @@ class PluginAudio(PluginBase):
) )
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't read pulse-client.tmpl: %s", err) _LOGGER.error("Can't read pulse-client.tmpl: %s", err)
@ -111,7 +113,9 @@ class PluginAudio(PluginBase):
await self.sys_run_in_executor(setup_default_asound) await self.sys_run_in_executor(setup_default_asound)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't create default asound: %s", err) _LOGGER.error("Can't create default asound: %s", err)
@Job( @Job(

View File

@ -156,7 +156,9 @@ class PluginDns(PluginBase):
) )
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't read resolve.tmpl: %s", err) _LOGGER.error("Can't read resolve.tmpl: %s", err)
try: try:
@ -165,7 +167,9 @@ class PluginDns(PluginBase):
) )
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.error("Can't read hosts.tmpl: %s", err) _LOGGER.error("Can't read hosts.tmpl: %s", err)
await self._init_hosts() await self._init_hosts()
@ -343,7 +347,9 @@ class PluginDns(PluginBase):
) )
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise CoreDNSError(f"Can't update hosts: {err}", _LOGGER.error) from err raise CoreDNSError(f"Can't update hosts: {err}", _LOGGER.error) from err
async def add_host( async def add_host(
@ -432,7 +438,9 @@ class PluginDns(PluginBase):
await self.sys_run_in_executor(resolv_conf.write_text, data) await self.sys_run_in_executor(resolv_conf.write_text, data)
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
_LOGGER.warning("Can't write/update %s: %s", resolv_conf, err) _LOGGER.warning("Can't write/update %s: %s", resolv_conf, err)
return return

View File

@ -30,7 +30,7 @@ class CheckSupervisorTrust(CheckBase):
try: try:
await self.sys_supervisor.check_trust() await self.sys_supervisor.check_trust()
except CodeNotaryUntrusted: except CodeNotaryUntrusted:
self.sys_resolution.unhealthy = UnhealthyReason.UNTRUSTED self.sys_resolution.add_unhealthy_reason(UnhealthyReason.UNTRUSTED)
self.sys_resolution.create_issue(IssueType.TRUST, ContextType.SUPERVISOR) self.sys_resolution.create_issue(IssueType.TRUST, ContextType.SUPERVISOR)
except CodeNotaryError: except CodeNotaryError:
pass pass

View File

@ -67,6 +67,6 @@ class ResolutionEvaluation(CoreSysAttributes):
await async_capture_exception(err) await async_capture_exception(err)
if any(reason in self.sys_resolution.unsupported for reason in UNHEALTHY): if any(reason in self.sys_resolution.unsupported for reason in UNHEALTHY):
self.sys_resolution.unhealthy = UnhealthyReason.DOCKER self.sys_resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
_LOGGER.info("System evaluation complete") _LOGGER.info("System evaluation complete")

View File

@ -23,7 +23,7 @@ class EvaluateBase(ABC, CoreSysAttributes):
return return
if await self.evaluate(): if await self.evaluate():
if self.reason not in self.sys_resolution.unsupported: if self.reason not in self.sys_resolution.unsupported:
self.sys_resolution.unsupported = self.reason self.sys_resolution.add_unsupported_reason(self.reason)
_LOGGER.warning( _LOGGER.warning(
"%s (more-info: https://www.home-assistant.io/more-info/unsupported/%s)", "%s (more-info: https://www.home-assistant.io/more-info/unsupported/%s)",
self.on_failure, self.on_failure,

View File

@ -101,6 +101,6 @@ class EvaluateContainer(EvaluateBase):
"Found image in unhealthy image list '%s' on the host", "Found image in unhealthy image list '%s' on the host",
image_name, image_name,
) )
self.sys_resolution.unhealthy = UnhealthyReason.DOCKER self.sys_resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
return len(self._images) != 0 return len(self._images) != 0

View File

@ -51,7 +51,9 @@ class EvaluateSourceMods(EvaluateBase):
) )
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
self.sys_resolution.create_issue( self.sys_resolution.create_issue(
IssueType.CORRUPT_FILESYSTEM, ContextType.SYSTEM IssueType.CORRUPT_FILESYSTEM, ContextType.SYSTEM

View File

@ -87,28 +87,12 @@ class ResolutionManager(FileConfiguration, CoreSysAttributes):
"""Return a list of issues.""" """Return a list of issues."""
return self._issues return self._issues
@issues.setter
def issues(self, issue: Issue) -> None:
"""Add issues."""
if issue in self._issues:
return
_LOGGER.info(
"Create new issue %s - %s / %s", issue.type, issue.context, issue.reference
)
self._issues.append(issue)
# Event on issue creation
self.sys_homeassistant.websocket.supervisor_event(
WSEvent.ISSUE_CHANGED, self._make_issue_message(issue)
)
@property @property
def suggestions(self) -> list[Suggestion]: def suggestions(self) -> list[Suggestion]:
"""Return a list of suggestions that can handled.""" """Return a list of suggestions that can handled."""
return self._suggestions return self._suggestions
@suggestions.setter def add_suggestion(self, suggestion: Suggestion) -> None:
def suggestions(self, suggestion: Suggestion) -> None:
"""Add suggestion.""" """Add suggestion."""
if suggestion in self._suggestions: if suggestion in self._suggestions:
return return
@ -132,8 +116,7 @@ class ResolutionManager(FileConfiguration, CoreSysAttributes):
"""Return a list of unsupported reasons.""" """Return a list of unsupported reasons."""
return self._unsupported return self._unsupported
@unsupported.setter def add_unsupported_reason(self, reason: UnsupportedReason) -> None:
def unsupported(self, reason: UnsupportedReason) -> None:
"""Add a reason for unsupported.""" """Add a reason for unsupported."""
if reason not in self._unsupported: if reason not in self._unsupported:
self._unsupported.append(reason) self._unsupported.append(reason)
@ -144,12 +127,11 @@ class ResolutionManager(FileConfiguration, CoreSysAttributes):
@property @property
def unhealthy(self) -> list[UnhealthyReason]: def unhealthy(self) -> list[UnhealthyReason]:
"""Return a list of unsupported reasons.""" """Return a list of unhealthy reasons."""
return self._unhealthy return self._unhealthy
@unhealthy.setter def add_unhealthy_reason(self, reason: UnhealthyReason) -> None:
def unhealthy(self, reason: UnhealthyReason) -> None: """Add a reason for unhealthy."""
"""Add a reason for unsupported."""
if reason not in self._unhealthy: if reason not in self._unhealthy:
self._unhealthy.append(reason) self._unhealthy.append(reason)
self.sys_homeassistant.websocket.supervisor_event( self.sys_homeassistant.websocket.supervisor_event(
@ -198,11 +180,21 @@ class ResolutionManager(FileConfiguration, CoreSysAttributes):
"""Add an issue and suggestions.""" """Add an issue and suggestions."""
if suggestions: if suggestions:
for suggestion in suggestions: for suggestion in suggestions:
self.suggestions = Suggestion( self.add_suggestion(
suggestion, issue.context, issue.reference Suggestion(suggestion, issue.context, issue.reference)
) )
self.issues = issue if issue in self._issues:
return
_LOGGER.info(
"Create new issue %s - %s / %s", issue.type, issue.context, issue.reference
)
self._issues.append(issue)
# Event on issue creation
self.sys_homeassistant.websocket.supervisor_event(
WSEvent.ISSUE_CHANGED, self._make_issue_message(issue)
)
async def load(self): async def load(self):
"""Load the resoulution manager.""" """Load the resoulution manager."""

View File

@ -179,7 +179,9 @@ class StoreData(CoreSysAttributes):
except OSError as err: except OSError as err:
suggestion = None suggestion = None
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
elif path.stem != StoreType.LOCAL: elif path.stem != StoreType.LOCAL:
suggestion = [SuggestionType.EXECUTE_RESET] suggestion = [SuggestionType.EXECUTE_RESET]
self.sys_resolution.create_issue( self.sys_resolution.create_issue(

View File

@ -174,7 +174,9 @@ class Supervisor(CoreSysAttributes):
except OSError as err: except OSError as err:
if err.errno == errno.EBADMSG: if err.errno == errno.EBADMSG:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE self.sys_resolution.add_unhealthy_reason(
UnhealthyReason.OSERROR_BAD_MESSAGE
)
raise SupervisorAppArmorError( raise SupervisorAppArmorError(
f"Can't write temporary profile: {err!s}", _LOGGER.error f"Can't write temporary profile: {err!s}", _LOGGER.error
) from err ) from err

View File

@ -61,7 +61,7 @@ def journal_verbose_formatter(entries: dict[str, str]) -> str:
async def journal_logs_reader( async def journal_logs_reader(
journal_logs: ClientResponse, log_formatter: LogFormatter = LogFormatter.PLAIN journal_logs: ClientResponse, log_formatter: LogFormatter = LogFormatter.PLAIN
) -> AsyncGenerator[str | None, str]: ) -> AsyncGenerator[tuple[str | None, str]]:
"""Read logs from systemd journal line by line, formatted using the given formatter. """Read logs from systemd journal line by line, formatted using the given formatter.
Returns a generator of (cursor, formatted_entry) tuples. Returns a generator of (cursor, formatted_entry) tuples.

View File

@ -27,9 +27,9 @@ from supervisor.resolution.data import Issue, Suggestion
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_resolution_base(coresys: CoreSys, api_client: TestClient): async def test_api_resolution_base(coresys: CoreSys, api_client: TestClient):
"""Test resolution manager api.""" """Test resolution manager api."""
coresys.resolution.unsupported = UnsupportedReason.OS coresys.resolution.add_unsupported_reason(UnsupportedReason.OS)
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
coresys.resolution.create_issue(IssueType.FREE_SPACE, ContextType.SYSTEM) coresys.resolution.create_issue(IssueType.FREE_SPACE, ContextType.SYSTEM)
@ -47,8 +47,8 @@ async def test_api_resolution_dismiss_suggestion(
coresys: CoreSys, api_client: TestClient coresys: CoreSys, api_client: TestClient
): ):
"""Test resolution manager suggestion apply api.""" """Test resolution manager suggestion apply api."""
coresys.resolution.suggestions = clear_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM clear_backup := Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
assert coresys.resolution.suggestions[-1].type == SuggestionType.CLEAR_FULL_BACKUP assert coresys.resolution.suggestions[-1].type == SuggestionType.CLEAR_FULL_BACKUP
@ -61,11 +61,13 @@ async def test_api_resolution_apply_suggestion(
coresys: CoreSys, api_client: TestClient coresys: CoreSys, api_client: TestClient
): ):
"""Test resolution manager suggestion apply api.""" """Test resolution manager suggestion apply api."""
coresys.resolution.suggestions = clear_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM clear_backup := Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
coresys.resolution.suggestions = create_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM create_backup := Suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM
)
) )
mock_backups = AsyncMock() mock_backups = AsyncMock()
@ -89,8 +91,8 @@ async def test_api_resolution_apply_suggestion(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_resolution_dismiss_issue(coresys: CoreSys, api_client: TestClient): async def test_api_resolution_dismiss_issue(coresys: CoreSys, api_client: TestClient):
"""Test resolution manager issue apply api.""" """Test resolution manager issue apply api."""
coresys.resolution.issues = updated_failed = Issue( coresys.resolution.add_issue(
IssueType.UPDATE_FAILED, ContextType.SYSTEM updated_failed := Issue(IssueType.UPDATE_FAILED, ContextType.SYSTEM)
) )
assert coresys.resolution.issues[-1].type == IssueType.UPDATE_FAILED assert coresys.resolution.issues[-1].type == IssueType.UPDATE_FAILED
@ -101,7 +103,7 @@ async def test_api_resolution_dismiss_issue(coresys: CoreSys, api_client: TestCl
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_resolution_unhealthy(coresys: CoreSys, api_client: TestClient): async def test_api_resolution_unhealthy(coresys: CoreSys, api_client: TestClient):
"""Test resolution manager api.""" """Test resolution manager api."""
coresys.resolution.unhealthy = UnhealthyReason.DOCKER coresys.resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
resp = await api_client.get("/resolution/info") resp = await api_client.get("/resolution/info")
result = await resp.json() result = await resp.json()
@ -142,8 +144,8 @@ async def test_api_resolution_suggestions_for_issue(
coresys: CoreSys, api_client: TestClient coresys: CoreSys, api_client: TestClient
): ):
"""Test getting suggestions that fix an issue.""" """Test getting suggestions that fix an issue."""
coresys.resolution.issues = corrupt_repo = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "repo_1" corrupt_repo := Issue(IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "repo_1")
) )
resp = await api_client.get(f"/resolution/issue/{corrupt_repo.uuid}/suggestions") resp = await api_client.get(f"/resolution/issue/{corrupt_repo.uuid}/suggestions")
@ -151,11 +153,15 @@ async def test_api_resolution_suggestions_for_issue(
assert result["data"]["suggestions"] == [] assert result["data"]["suggestions"] == []
coresys.resolution.suggestions = execute_reset = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "repo_1" execute_reset := Suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "repo_1"
)
) )
coresys.resolution.suggestions = execute_remove = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_1" execute_remove := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_1"
)
) )
resp = await api_client.get(f"/resolution/issue/{corrupt_repo.uuid}/suggestions") resp = await api_client.get(f"/resolution/issue/{corrupt_repo.uuid}/suggestions")

View File

@ -50,7 +50,7 @@ async def test_healthy(coresys: CoreSys, caplog: pytest.LogCaptureFixture):
test = TestClass(coresys) test = TestClass(coresys)
assert await test.execute() assert await test.execute()
coresys.resolution.unhealthy = UnhealthyReason.DOCKER coresys.resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
assert not await test.execute() assert not await test.execute()
assert "blocked from execution, system is not healthy - docker" in caplog.text assert "blocked from execution, system is not healthy - docker" in caplog.text

View File

@ -96,7 +96,7 @@ def test_diagnostics_disabled(coresys):
def test_not_supported(coresys): def test_not_supported(coresys):
"""Test if not supported.""" """Test if not supported."""
coresys.config.diagnostics = True coresys.config.diagnostics = True
coresys.resolution.unsupported = UnsupportedReason.DOCKER_VERSION coresys.resolution.add_unsupported_reason(UnsupportedReason.DOCKER_VERSION)
assert filter_data(coresys, SAMPLE_EVENT, {}) is None assert filter_data(coresys, SAMPLE_EVENT, {}) is None
@ -215,7 +215,7 @@ async def test_unhealthy_on_report(coresys):
coresys.config.diagnostics = True coresys.config.diagnostics = True
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
coresys.resolution.unhealthy = UnhealthyReason.DOCKER coresys.resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
with patch("shutil.disk_usage", return_value=(42, 42, 2 * (1024.0**3))): with patch("shutil.disk_usage", return_value=(42, 42, 2 * (1024.0**3))):
event = filter_data(coresys, SAMPLE_EVENT, {}) event = filter_data(coresys, SAMPLE_EVENT, {})

View File

@ -15,15 +15,19 @@ async def test_fixup(coresys: CoreSys, install_addon_ssh: Addon):
assert addon_execute_remove.auto is False assert addon_execute_remove.auto is False
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, Suggestion(
ContextType.ADDON, SuggestionType.EXECUTE_REMOVE,
reference=install_addon_ssh.slug, ContextType.ADDON,
reference=install_addon_ssh.slug,
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.DETACHED_ADDON_REMOVED, Issue(
ContextType.ADDON, IssueType.DETACHED_ADDON_REMOVED,
reference=install_addon_ssh.slug, ContextType.ADDON,
reference=install_addon_ssh.slug,
)
) )
with patch.object(Addon, "uninstall") as uninstall: with patch.object(Addon, "uninstall") as uninstall:

View File

@ -28,8 +28,8 @@ async def test_check_autofix(coresys: CoreSys):
"system_create_full_backup" "system_create_full_backup"
].process_fixup.assert_not_called() ].process_fixup.assert_not_called()
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM Suggestion(SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM)
) )
with patch( with patch(
"supervisor.resolution.fixups.system_create_full_backup.FixupSystemCreateFullBackup.auto", "supervisor.resolution.fixups.system_create_full_backup.FixupSystemCreateFullBackup.auto",

View File

@ -15,11 +15,11 @@ async def test_fixup(coresys: CoreSys):
assert store_execute_reload.auto assert store_execute_reload.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RELOAD, ContextType.STORE, reference="test" Suggestion(SuggestionType.EXECUTE_RELOAD, ContextType.STORE, reference="test")
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.FATAL_ERROR, ContextType.STORE, reference="test" Issue(IssueType.FATAL_ERROR, ContextType.STORE, reference="test")
) )
mock_repositorie = AsyncMock() mock_repositorie = AsyncMock()

View File

@ -16,11 +16,15 @@ async def test_fixup(coresys: CoreSys, repository: Repository):
assert store_execute_remove.auto is False assert store_execute_remove.auto is False
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, reference=repository.slug Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, reference=repository.slug
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, reference=repository.slug Issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, reference=repository.slug
)
) )
with patch.object(type(repository), "remove") as remove_repo: with patch.object(type(repository), "remove") as remove_repo:

View File

@ -18,11 +18,11 @@ async def test_fixup(coresys: CoreSys, tmp_path):
assert store_execute_reset.auto assert store_execute_reset.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, reference="test" Suggestion(SuggestionType.EXECUTE_RESET, ContextType.STORE, reference="test")
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, reference="test" Issue(IssueType.CORRUPT_REPOSITORY, ContextType.STORE, reference="test")
) )
test_repo.mkdir() test_repo.mkdir()

View File

@ -85,11 +85,13 @@ async def test_fixup(
assert not system_adopt_data_disk.auto assert not system_adopt_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = [ udisks2_service.resolved_devices = [
["/org/freedesktop/UDisks2/block_devices/sda1"], ["/org/freedesktop/UDisks2/block_devices/sda1"],
@ -124,11 +126,13 @@ async def test_fixup_device_removed(
assert not system_adopt_data_disk.auto assert not system_adopt_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = [] udisks2_service.resolved_devices = []
@ -159,11 +163,13 @@ async def test_fixup_reboot_failed(
assert not system_adopt_data_disk.auto assert not system_adopt_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = [ udisks2_service.resolved_devices = [
["/org/freedesktop/UDisks2/block_devices/sda1"], ["/org/freedesktop/UDisks2/block_devices/sda1"],
@ -209,11 +215,13 @@ async def test_fixup_disabled_data_disk(
assert not system_adopt_data_disk.auto assert not system_adopt_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.ADOPT_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.DISABLED_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.DISABLED_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = [ udisks2_service.resolved_devices = [
["/org/freedesktop/UDisks2/block_devices/sda1"], ["/org/freedesktop/UDisks2/block_devices/sda1"],

View File

@ -17,8 +17,8 @@ async def test_fixup(coresys: CoreSys, backups: list[Backup]):
assert not clear_full_backup.auto assert not clear_full_backup.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
newest_full_backup = coresys.backups._backups["sn4"] newest_full_backup = coresys.backups._backups["sn4"]

View File

@ -17,8 +17,8 @@ async def test_fixup(coresys: CoreSys):
assert not create_full_backup.auto assert not create_full_backup.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM Suggestion(SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM)
) )
mock_backups = AsyncMock() mock_backups = AsyncMock()

View File

@ -22,10 +22,10 @@ async def test_fixup(coresys: CoreSys):
assert system_execute_integrity.auto assert system_execute_integrity.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_INTEGRITY, ContextType.SYSTEM Suggestion(SuggestionType.EXECUTE_INTEGRITY, ContextType.SYSTEM)
) )
coresys.resolution.issues = Issue(IssueType.TRUST, ContextType.SYSTEM) coresys.resolution.add_issue(Issue(IssueType.TRUST, ContextType.SYSTEM))
coresys.security.integrity_check = AsyncMock( coresys.security.integrity_check = AsyncMock(
return_value=IntegrityResult( return_value=IntegrityResult(
@ -48,10 +48,10 @@ async def test_fixup_error(coresys: CoreSys):
assert system_execute_integrity.auto assert system_execute_integrity.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_INTEGRITY, ContextType.SYSTEM Suggestion(SuggestionType.EXECUTE_INTEGRITY, ContextType.SYSTEM)
) )
coresys.resolution.issues = Issue(IssueType.TRUST, ContextType.SYSTEM) coresys.resolution.add_issue(Issue(IssueType.TRUST, ContextType.SYSTEM))
coresys.security.integrity_check = AsyncMock( coresys.security.integrity_check = AsyncMock(
return_value=IntegrityResult( return_value=IntegrityResult(

View File

@ -20,10 +20,10 @@ async def test_fixup(
system_execute_reboot = FixupSystemExecuteReboot(coresys) system_execute_reboot = FixupSystemExecuteReboot(coresys)
assert system_execute_reboot.auto is False assert system_execute_reboot.auto is False
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REBOOT, ContextType.SYSTEM Suggestion(SuggestionType.EXECUTE_REBOOT, ContextType.SYSTEM)
) )
coresys.resolution.issues = Issue(IssueType.REBOOT_REQUIRED, ContextType.SYSTEM) coresys.resolution.add_issue(Issue(IssueType.REBOOT_REQUIRED, ContextType.SYSTEM))
await system_execute_reboot() await system_execute_reboot()

View File

@ -43,11 +43,13 @@ async def test_fixup(coresys: CoreSys, sda1_filesystem_service: FilesystemServic
assert not system_rename_data_disk.auto assert not system_rename_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
await system_rename_data_disk() await system_rename_data_disk()
@ -73,11 +75,13 @@ async def test_fixup_device_removed(
assert not system_rename_data_disk.auto assert not system_rename_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = [] udisks2_service.resolved_devices = []
@ -98,11 +102,13 @@ async def test_fixup_device_not_filesystem(
assert not system_rename_data_disk.auto assert not system_rename_data_disk.auto
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1" Suggestion(
SuggestionType.RENAME_DATA_DISK, ContextType.SYSTEM, reference="/dev/sda1"
)
) )
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1" Issue(IssueType.MULTIPLE_DATA_DISKS, ContextType.SYSTEM, reference="/dev/sda1")
) )
udisks2_service.resolved_devices = ["/org/freedesktop/UDisks2/block_devices/sda"] udisks2_service.resolved_devices = ["/org/freedesktop/UDisks2/block_devices/sda"]

View File

@ -22,7 +22,7 @@ def test_properies_unsupported(coresys: CoreSys):
"""Test resolution manager properties unsupported.""" """Test resolution manager properties unsupported."""
assert coresys.core.supported assert coresys.core.supported
coresys.resolution.unsupported = UnsupportedReason.OS coresys.resolution.add_unsupported_reason(UnsupportedReason.OS)
assert not coresys.core.supported assert not coresys.core.supported
@ -30,15 +30,15 @@ def test_properies_unhealthy(coresys: CoreSys):
"""Test resolution manager properties unhealthy.""" """Test resolution manager properties unhealthy."""
assert coresys.core.healthy assert coresys.core.healthy
coresys.resolution.unhealthy = UnhealthyReason.SUPERVISOR coresys.resolution.add_unhealthy_reason(UnhealthyReason.SUPERVISOR)
assert not coresys.core.healthy assert not coresys.core.healthy
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolution_dismiss_suggestion(coresys: CoreSys): async def test_resolution_dismiss_suggestion(coresys: CoreSys):
"""Test resolution manager suggestion apply api.""" """Test resolution manager suggestion apply api."""
coresys.resolution.suggestions = clear_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM clear_backup := Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
assert coresys.resolution.suggestions[-1].type == SuggestionType.CLEAR_FULL_BACKUP assert coresys.resolution.suggestions[-1].type == SuggestionType.CLEAR_FULL_BACKUP
@ -52,11 +52,13 @@ async def test_resolution_dismiss_suggestion(coresys: CoreSys):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolution_apply_suggestion(coresys: CoreSys): async def test_resolution_apply_suggestion(coresys: CoreSys):
"""Test resolution manager suggestion apply api.""" """Test resolution manager suggestion apply api."""
coresys.resolution.suggestions = clear_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM clear_backup := Suggestion(SuggestionType.CLEAR_FULL_BACKUP, ContextType.SYSTEM)
) )
coresys.resolution.suggestions = create_backup = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM create_backup := Suggestion(
SuggestionType.CREATE_FULL_BACKUP, ContextType.SYSTEM
)
) )
mock_backups = AsyncMock() mock_backups = AsyncMock()
@ -80,8 +82,8 @@ async def test_resolution_apply_suggestion(coresys: CoreSys):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolution_dismiss_issue(coresys: CoreSys): async def test_resolution_dismiss_issue(coresys: CoreSys):
"""Test resolution manager issue apply api.""" """Test resolution manager issue apply api."""
coresys.resolution.issues = updated_failed = Issue( coresys.resolution.add_issue(
IssueType.UPDATE_FAILED, ContextType.SYSTEM updated_failed := Issue(IssueType.UPDATE_FAILED, ContextType.SYSTEM)
) )
assert coresys.resolution.issues[-1].type == IssueType.UPDATE_FAILED assert coresys.resolution.issues[-1].type == IssueType.UPDATE_FAILED
@ -113,7 +115,7 @@ async def test_resolution_create_issue_suggestion(coresys: CoreSys):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resolution_dismiss_unsupported(coresys: CoreSys): async def test_resolution_dismiss_unsupported(coresys: CoreSys):
"""Test resolution manager dismiss unsupported reason.""" """Test resolution manager dismiss unsupported reason."""
coresys.resolution.unsupported = UnsupportedReason.SOFTWARE coresys.resolution.add_unsupported_reason(UnsupportedReason.SOFTWARE)
coresys.resolution.dismiss_unsupported(UnsupportedReason.SOFTWARE) coresys.resolution.dismiss_unsupported(UnsupportedReason.SOFTWARE)
assert UnsupportedReason.SOFTWARE not in coresys.resolution.unsupported assert UnsupportedReason.SOFTWARE not in coresys.resolution.unsupported
@ -124,26 +126,32 @@ async def test_resolution_dismiss_unsupported(coresys: CoreSys):
async def test_suggestions_for_issue(coresys: CoreSys): async def test_suggestions_for_issue(coresys: CoreSys):
"""Test getting suggestions that fix an issue.""" """Test getting suggestions that fix an issue."""
coresys.resolution.issues = corrupt_repo = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "test_repo" corrupt_repo := Issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "test_repo"
)
) )
# Unrelated suggestions don't appear # Unrelated suggestions don't appear
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RESET, ContextType.SUPERVISOR Suggestion(SuggestionType.EXECUTE_RESET, ContextType.SUPERVISOR)
) )
coresys.resolution.suggestions = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "other_repo" Suggestion(SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "other_repo")
) )
assert coresys.resolution.suggestions_for_issue(corrupt_repo) == set() assert coresys.resolution.suggestions_for_issue(corrupt_repo) == set()
# Related suggestions do # Related suggestions do
coresys.resolution.suggestions = execute_remove = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "test_repo" execute_remove := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "test_repo"
)
) )
coresys.resolution.suggestions = execute_reset = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "test_repo" execute_reset := Suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "test_repo"
)
) )
assert coresys.resolution.suggestions_for_issue(corrupt_repo) == { assert coresys.resolution.suggestions_for_issue(corrupt_repo) == {
@ -154,24 +162,28 @@ async def test_suggestions_for_issue(coresys: CoreSys):
async def test_issues_for_suggestion(coresys: CoreSys): async def test_issues_for_suggestion(coresys: CoreSys):
"""Test getting issues fixed by a suggestion.""" """Test getting issues fixed by a suggestion."""
coresys.resolution.suggestions = execute_reset = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "test_repo" execute_reset := Suggestion(
SuggestionType.EXECUTE_RESET, ContextType.STORE, "test_repo"
)
) )
# Unrelated issues don't appear # Unrelated issues don't appear
coresys.resolution.issues = Issue(IssueType.FATAL_ERROR, ContextType.CORE) coresys.resolution.add_issue(Issue(IssueType.FATAL_ERROR, ContextType.CORE))
coresys.resolution.issues = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "other_repo" Issue(IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "other_repo")
) )
assert coresys.resolution.issues_for_suggestion(execute_reset) == set() assert coresys.resolution.issues_for_suggestion(execute_reset) == set()
# Related issues do # Related issues do
coresys.resolution.issues = fatal_error = Issue( coresys.resolution.add_issue(
IssueType.FATAL_ERROR, ContextType.STORE, "test_repo" fatal_error := Issue(IssueType.FATAL_ERROR, ContextType.STORE, "test_repo")
) )
coresys.resolution.issues = corrupt_repo = Issue( coresys.resolution.add_issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "test_repo" corrupt_repo := Issue(
IssueType.CORRUPT_REPOSITORY, ContextType.STORE, "test_repo"
)
) )
assert coresys.resolution.issues_for_suggestion(execute_reset) == { assert coresys.resolution.issues_for_suggestion(execute_reset) == {
@ -226,8 +238,10 @@ async def test_events_on_issue_changes(coresys: CoreSys, ha_ws_client: AsyncMock
# Adding a suggestion that fixes the issue changes it # Adding a suggestion that fixes the issue changes it
ha_ws_client.async_send_command.reset_mock() ha_ws_client.async_send_command.reset_mock()
coresys.resolution.suggestions = execute_remove = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "test_repo" execute_remove := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "test_repo"
)
) )
await asyncio.sleep(0) await asyncio.sleep(0)
messages = [ messages = [
@ -270,14 +284,20 @@ async def test_events_on_issue_changes(coresys: CoreSys, ha_ws_client: AsyncMock
async def test_resolution_apply_suggestion_multiple_copies(coresys: CoreSys): async def test_resolution_apply_suggestion_multiple_copies(coresys: CoreSys):
"""Test resolution manager applies correct suggestion when has multiple that differ by reference.""" """Test resolution manager applies correct suggestion when has multiple that differ by reference."""
coresys.resolution.suggestions = remove_store_1 = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_1" remove_store_1 := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_1"
)
) )
coresys.resolution.suggestions = remove_store_2 = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_2" remove_store_2 := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_2"
)
) )
coresys.resolution.suggestions = remove_store_3 = Suggestion( coresys.resolution.add_suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_3" remove_store_3 := Suggestion(
SuggestionType.EXECUTE_REMOVE, ContextType.STORE, "repo_3"
)
) )
await coresys.resolution.apply_suggestion(remove_store_2) await coresys.resolution.apply_suggestion(remove_store_2)
@ -294,7 +314,7 @@ async def test_events_on_unsupported_changed(coresys: CoreSys):
) as send_message: ) as send_message:
# Marking system as unsupported tells HA # Marking system as unsupported tells HA
assert coresys.resolution.unsupported == [] assert coresys.resolution.unsupported == []
coresys.resolution.unsupported = UnsupportedReason.CONNECTIVITY_CHECK coresys.resolution.add_unsupported_reason(UnsupportedReason.CONNECTIVITY_CHECK)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unsupported == [UnsupportedReason.CONNECTIVITY_CHECK] assert coresys.resolution.unsupported == [UnsupportedReason.CONNECTIVITY_CHECK]
send_message.assert_called_once_with( send_message.assert_called_once_with(
@ -306,13 +326,13 @@ async def test_events_on_unsupported_changed(coresys: CoreSys):
# Adding the same reason again does nothing # Adding the same reason again does nothing
send_message.reset_mock() send_message.reset_mock()
coresys.resolution.unsupported = UnsupportedReason.CONNECTIVITY_CHECK coresys.resolution.add_unsupported_reason(UnsupportedReason.CONNECTIVITY_CHECK)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unsupported == [UnsupportedReason.CONNECTIVITY_CHECK] assert coresys.resolution.unsupported == [UnsupportedReason.CONNECTIVITY_CHECK]
send_message.assert_not_called() send_message.assert_not_called()
# Adding and removing additional reasons tells HA unsupported reasons changed # Adding and removing additional reasons tells HA unsupported reasons changed
coresys.resolution.unsupported = UnsupportedReason.JOB_CONDITIONS coresys.resolution.add_unsupported_reason(UnsupportedReason.JOB_CONDITIONS)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unsupported == [ assert coresys.resolution.unsupported == [
UnsupportedReason.CONNECTIVITY_CHECK, UnsupportedReason.CONNECTIVITY_CHECK,
@ -358,7 +378,7 @@ async def test_events_on_unhealthy_changed(coresys: CoreSys):
) as send_message: ) as send_message:
# Marking system as unhealthy tells HA # Marking system as unhealthy tells HA
assert coresys.resolution.unhealthy == [] assert coresys.resolution.unhealthy == []
coresys.resolution.unhealthy = UnhealthyReason.DOCKER coresys.resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unhealthy == [UnhealthyReason.DOCKER] assert coresys.resolution.unhealthy == [UnhealthyReason.DOCKER]
send_message.assert_called_once_with( send_message.assert_called_once_with(
@ -370,13 +390,13 @@ async def test_events_on_unhealthy_changed(coresys: CoreSys):
# Adding the same reason again does nothing # Adding the same reason again does nothing
send_message.reset_mock() send_message.reset_mock()
coresys.resolution.unhealthy = UnhealthyReason.DOCKER coresys.resolution.add_unhealthy_reason(UnhealthyReason.DOCKER)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unhealthy == [UnhealthyReason.DOCKER] assert coresys.resolution.unhealthy == [UnhealthyReason.DOCKER]
send_message.assert_not_called() send_message.assert_not_called()
# Adding an additional reason tells HA unhealthy reasons changed # Adding an additional reason tells HA unhealthy reasons changed
coresys.resolution.unhealthy = UnhealthyReason.UNTRUSTED coresys.resolution.add_unhealthy_reason(UnhealthyReason.UNTRUSTED)
await asyncio.sleep(0) await asyncio.sleep(0)
assert coresys.resolution.unhealthy == [ assert coresys.resolution.unhealthy == [
UnhealthyReason.DOCKER, UnhealthyReason.DOCKER,