Fix mypy issues in misc, mounts and os modules (#5942)

* Fix mypy errors in misc and mounts

* Fix mypy issues in os module

* Fix typing of capture_exception

* avoid unnecessary property call

* Fixes from feedback
This commit is contained in:
Mike Degatano 2025-06-12 18:06:57 -04:00 committed by GitHub
parent bdbd09733a
commit 82ee4bc441
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 103 additions and 128 deletions

View File

@ -8,7 +8,7 @@ from dbus_fast.aio.message_bus import MessageBus
from ..const import SOCKET_DBUS
from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import DBusFatalError
from ..exceptions import DBusFatalError, DBusNotConnectedError
from .agent import OSAgent
from .hostname import Hostname
from .interface import DBusInterface
@ -91,6 +91,13 @@ class DBusManager(CoreSysAttributes):
"""Return the message bus."""
return self._bus
@property
def connected_bus(self) -> MessageBus:
"""Return the message bus. Raise if not connected."""
if not self._bus:
raise DBusNotConnectedError()
return self._bus
@property
def all(self) -> list[DBusInterface]:
"""Return all managed dbus interfaces."""

View File

@ -2,9 +2,10 @@
from datetime import datetime, timedelta
import logging
from typing import cast
from ..addons.const import ADDON_UPDATE_CONDITIONS
from ..backups.const import LOCATION_CLOUD_BACKUP
from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE
from ..const import ATTR_TYPE, AddonState
from ..coresys import CoreSysAttributes
from ..exceptions import (
@ -378,6 +379,8 @@ class Tasks(CoreSysAttributes):
]
for backup in old_backups:
try:
await self.sys_backups.remove(backup, [LOCATION_CLOUD_BACKUP])
await self.sys_backups.remove(
backup, [cast(LOCATION_TYPE, LOCATION_CLOUD_BACKUP)]
)
except BackupFileNotFoundError as err:
_LOGGER.debug("Can't remove backup %s: %s", backup.slug, err)

View File

@ -56,7 +56,7 @@ class MountManager(FileConfiguration, CoreSysAttributes):
async def load_config(self) -> Self:
"""Load config in executor."""
await super().load_config()
self._mounts: dict[str, Mount] = {
self._mounts = {
mount[ATTR_NAME]: Mount.from_dict(self.coresys, mount)
for mount in self._data[ATTR_MOUNTS]
}
@ -172,12 +172,12 @@ class MountManager(FileConfiguration, CoreSysAttributes):
errors = await asyncio.gather(*mount_tasks, return_exceptions=True)
for i in range(len(errors)): # pylint: disable=consider-using-enumerate
if not errors[i]:
if not (err := errors[i]):
continue
if mounts[i].failed_issue in self.sys_resolution.issues:
continue
if not isinstance(errors[i], MountError):
await async_capture_exception(errors[i])
if not isinstance(err, MountError):
await async_capture_exception(err)
self.sys_resolution.add_issue(
evolve(mounts[i].failed_issue),
@ -219,7 +219,7 @@ class MountManager(FileConfiguration, CoreSysAttributes):
conditions=[JobCondition.MOUNT_AVAILABLE],
on_condition=MountJobError,
)
async def remove_mount(self, name: str, *, retain_entry: bool = False) -> None:
async def remove_mount(self, name: str, *, retain_entry: bool = False) -> Mount:
"""Remove a mount."""
# Add mount name to job
self.sys_jobs.current.reference = name

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from functools import cached_property
import logging
from pathlib import Path, PurePath
@ -9,14 +10,6 @@ from pathlib import Path, PurePath
from dbus_fast import Variant
from voluptuous import Coerce
from ..const import (
ATTR_NAME,
ATTR_PASSWORD,
ATTR_PORT,
ATTR_TYPE,
ATTR_USERNAME,
ATTR_VERSION,
)
from ..coresys import CoreSys, CoreSysAttributes
from ..dbus.const import (
DBUS_ATTR_ACTIVE_STATE,
@ -41,22 +34,13 @@ from ..exceptions import (
from ..resolution.const import ContextType, IssueType
from ..resolution.data import Issue
from ..utils.sentry import async_capture_exception
from .const import (
ATTR_PATH,
ATTR_READ_ONLY,
ATTR_SERVER,
ATTR_SHARE,
ATTR_USAGE,
MountCifsVersion,
MountType,
MountUsage,
)
from .const import MountCifsVersion, MountType, MountUsage
from .validate import MountData
_LOGGER: logging.Logger = logging.getLogger(__name__)
COERCE_MOUNT_TYPE = Coerce(MountType)
COERCE_MOUNT_USAGE = Coerce(MountUsage)
COERCE_MOUNT_TYPE: Callable[[str], MountType] = Coerce(MountType)
COERCE_MOUNT_USAGE: Callable[[str], MountUsage] = Coerce(MountUsage)
class Mount(CoreSysAttributes, ABC):
@ -80,7 +64,7 @@ class Mount(CoreSysAttributes, ABC):
if cls not in [Mount, NetworkMount]:
return cls(coresys, data)
type_ = COERCE_MOUNT_TYPE(data[ATTR_TYPE])
type_ = COERCE_MOUNT_TYPE(data["type"])
if type_ == MountType.CIFS:
return CIFSMount(coresys, data)
if type_ == MountType.NFS:
@ -90,32 +74,33 @@ class Mount(CoreSysAttributes, ABC):
def to_dict(self, *, skip_secrets: bool = True) -> MountData:
"""Return dictionary representation."""
return MountData(
name=self.name, type=self.type, usage=self.usage, read_only=self.read_only
name=self.name,
type=self.type,
usage=self.usage and self.usage.value,
read_only=self.read_only,
)
@property
def name(self) -> str:
"""Get name."""
return self._data[ATTR_NAME]
return self._data["name"]
@property
def type(self) -> MountType:
"""Get mount type."""
return COERCE_MOUNT_TYPE(self._data[ATTR_TYPE])
return COERCE_MOUNT_TYPE(self._data["type"])
@property
def usage(self) -> MountUsage | None:
"""Get mount usage."""
return (
COERCE_MOUNT_USAGE(self._data[ATTR_USAGE])
if ATTR_USAGE in self._data
else None
)
if self._data["usage"] is None:
return None
return COERCE_MOUNT_USAGE(self._data["usage"])
@property
def read_only(self) -> bool:
"""Is mount read-only."""
return self._data.get(ATTR_READ_ONLY, False)
return self._data.get("read_only", False)
@property
@abstractmethod
@ -186,20 +171,20 @@ class Mount(CoreSysAttributes, ABC):
async def load(self) -> None:
"""Initialize object."""
# If there's no mount unit, mount it to make one
if not await self._update_unit():
if not (unit := await self._update_unit()):
await self.mount()
return
await self._update_state_await(not_state=UnitActiveState.ACTIVATING)
await self._update_state_await(unit, not_state=UnitActiveState.ACTIVATING)
# If mount is not available, try to reload it
if not await self.is_mounted():
await self.reload()
async def _update_state(self) -> UnitActiveState | None:
async def _update_state(self, unit: SystemdUnit) -> None:
"""Update mount unit state."""
try:
self._state = await self.unit.get_active_state()
self._state = await unit.get_active_state()
except DBusError as err:
await async_capture_exception(err)
raise MountError(
@ -220,10 +205,10 @@ class Mount(CoreSysAttributes, ABC):
async def update(self) -> bool:
"""Update info about mount from dbus. Return true if it is mounted and available."""
if not await self._update_unit():
if not (unit := await self._update_unit()):
return False
await self._update_state()
await self._update_state(unit)
# If active, dismiss corresponding failed mount issue if found
if (
@ -235,16 +220,14 @@ class Mount(CoreSysAttributes, ABC):
async def _update_state_await(
self,
unit: SystemdUnit,
expected_states: list[UnitActiveState] | None = None,
not_state: UnitActiveState = UnitActiveState.ACTIVATING,
) -> None:
"""Update state info about mount from dbus. Wait for one of expected_states to appear or state to change from not_state."""
if not self.unit:
return
try:
async with asyncio.timeout(30), self.unit.properties_changed() as signal:
await self._update_state()
async with asyncio.timeout(30), unit.properties_changed() as signal:
await self._update_state(unit)
while (
expected_states
and self.state not in expected_states
@ -312,8 +295,8 @@ class Mount(CoreSysAttributes, ABC):
f"Could not mount {self.name} due to: {err!s}", _LOGGER.error
) from err
if await self._update_unit():
await self._update_state_await(not_state=UnitActiveState.ACTIVATING)
if unit := await self._update_unit():
await self._update_state_await(unit, not_state=UnitActiveState.ACTIVATING)
if not await self.is_mounted():
raise MountActivationError(
@ -323,17 +306,17 @@ class Mount(CoreSysAttributes, ABC):
async def unmount(self) -> None:
"""Unmount using systemd."""
if not await self._update_unit():
if not (unit := await self._update_unit()):
_LOGGER.info("Mount %s is not mounted, skipping unmount", self.name)
return
await self._update_state()
await self._update_state(unit)
try:
if self.state != UnitActiveState.FAILED:
await self.sys_dbus.systemd.stop_unit(self.unit_name, StopUnitMode.FAIL)
await self._update_state_await(
[UnitActiveState.INACTIVE, UnitActiveState.FAILED]
unit, [UnitActiveState.INACTIVE, UnitActiveState.FAILED]
)
if self.state == UnitActiveState.FAILED:
@ -360,8 +343,10 @@ class Mount(CoreSysAttributes, ABC):
f"Could not reload mount {self.name} due to: {err!s}", _LOGGER.error
) from err
else:
if await self._update_unit():
await self._update_state_await(not_state=UnitActiveState.ACTIVATING)
if unit := await self._update_unit():
await self._update_state_await(
unit, not_state=UnitActiveState.ACTIVATING
)
if not await self.is_mounted():
raise MountActivationError(
@ -381,18 +366,18 @@ class NetworkMount(Mount, ABC):
"""Return dictionary representation."""
out = MountData(server=self.server, **super().to_dict())
if self.port is not None:
out[ATTR_PORT] = self.port
out["port"] = self.port
return out
@property
def server(self) -> str:
"""Get server."""
return self._data[ATTR_SERVER]
return self._data["server"]
@property
def port(self) -> int | None:
"""Get port, returns none if using the protocol default."""
return self._data.get(ATTR_PORT)
return self._data.get("port")
@property
def where(self) -> PurePath:
@ -420,31 +405,31 @@ class CIFSMount(NetworkMount):
def to_dict(self, *, skip_secrets: bool = True) -> MountData:
"""Return dictionary representation."""
out = MountData(share=self.share, **super().to_dict())
if not skip_secrets and self.username is not None:
out[ATTR_USERNAME] = self.username
out[ATTR_PASSWORD] = self.password
out[ATTR_VERSION] = self.version
if not skip_secrets and self.username is not None and self.password is not None:
out["username"] = self.username
out["password"] = self.password
out["version"] = self.version
return out
@property
def share(self) -> str:
"""Get share."""
return self._data[ATTR_SHARE]
return self._data["share"]
@property
def username(self) -> str | None:
"""Get username, returns none if auth is not used."""
return self._data.get(ATTR_USERNAME)
return self._data.get("username")
@property
def password(self) -> str | None:
"""Get password, returns none if auth is not used."""
return self._data.get(ATTR_PASSWORD)
return self._data.get("password")
@property
def version(self) -> str | None:
"""Get password, returns none if auth is not used."""
version = self._data.get(ATTR_VERSION)
"""Get cifs version, returns none if using default."""
version = self._data.get("version")
if version == MountCifsVersion.LEGACY_1_0:
return "1.0"
if version == MountCifsVersion.LEGACY_2_0:
@ -513,7 +498,7 @@ class NFSMount(NetworkMount):
@property
def path(self) -> PurePath:
"""Get path."""
return PurePath(self._data[ATTR_PATH])
return PurePath(self._data["path"])
@property
def what(self) -> str:
@ -543,7 +528,7 @@ class BindMount(Mount):
def create(
coresys: CoreSys,
name: str,
path: Path,
path: PurePath,
usage: MountUsage | None = None,
where: PurePath | None = None,
read_only: bool = False,
@ -568,7 +553,7 @@ class BindMount(Mount):
@property
def path(self) -> PurePath:
"""Get path."""
return PurePath(self._data[ATTR_PATH])
return PurePath(self._data["path"])
@property
def what(self) -> str:

View File

@ -103,7 +103,7 @@ class MountData(TypedDict):
name: str
type: str
read_only: bool
usage: NotRequired[str]
usage: str | None
# CIFS and NFS fields
server: NotRequired[str]
@ -113,6 +113,7 @@ class MountData(TypedDict):
share: NotRequired[str]
username: NotRequired[str]
password: NotRequired[str]
version: NotRequired[str | None]
# NFS and Bind fields
path: NotRequired[str]

View File

@ -5,7 +5,7 @@ from contextlib import suppress
from dataclasses import dataclass
import logging
from pathlib import Path
from typing import Any, Final
from typing import Any, Final, cast
from awesomeversion import AwesomeVersion
@ -24,6 +24,7 @@ from ..exceptions import (
)
from ..jobs.const import JobCondition, JobExecutionLimit
from ..jobs.decorator import Job
from ..resolution.checks.base import CheckBase
from ..resolution.checks.disabled_data_disk import CheckDisabledDataDisk
from ..resolution.checks.multiple_data_disks import CheckMultipleDataDisks
from ..utils.sentry import async_capture_exception
@ -149,7 +150,7 @@ class DataDisk(CoreSysAttributes):
Available disks are drives where nothing on it has been mounted
and it can be formatted.
"""
available: list[UDisks2Drive] = []
available: list[Disk] = []
for drive in self.sys_dbus.udisks2.drives:
block_devices = self._get_block_devices_for_drive(drive)
primary = _get_primary_block_device(block_devices)
@ -166,12 +167,16 @@ class DataDisk(CoreSysAttributes):
@property
def check_multiple_data_disks(self) -> CheckMultipleDataDisks:
"""Resolution center check for multiple data disks."""
return self.sys_resolution.check.get("multiple_data_disks")
return cast(
CheckMultipleDataDisks, self.sys_resolution.check.get("multiple_data_disks")
)
@property
def check_disabled_data_disk(self) -> CheckDisabledDataDisk:
"""Resolution center check for disabled data disk."""
return self.sys_resolution.check.get("disabled_data_disk")
return cast(
CheckDisabledDataDisk, self.sys_resolution.check.get("disabled_data_disk")
)
def _get_block_devices_for_drive(self, drive: UDisks2Drive) -> list[UDisks2Block]:
"""Get block devices for a drive."""
@ -361,7 +366,7 @@ class DataDisk(CoreSysAttributes):
try:
partition_block = await UDisks2Block.new(
partition, self.sys_dbus.bus, sync_properties=False
partition, self.sys_dbus.connected_bus, sync_properties=False
)
except DBusError as err:
raise HassOSDataDiskError(
@ -388,7 +393,7 @@ class DataDisk(CoreSysAttributes):
properties[DBUS_IFACE_BLOCK][DBUS_ATTR_ID_LABEL]
== FILESYSTEM_LABEL_DATA_DISK
):
check = self.check_multiple_data_disks
check: CheckBase = self.check_multiple_data_disks
elif (
properties[DBUS_IFACE_BLOCK][DBUS_ATTR_ID_LABEL]
== FILESYSTEM_LABEL_DISABLED_DATA_DISK
@ -411,7 +416,7 @@ class DataDisk(CoreSysAttributes):
and issue.context == self.check_multiple_data_disks.context
for issue in self.sys_resolution.issues
):
check = self.check_multiple_data_disks
check: CheckBase = self.check_multiple_data_disks
elif any(
issue.type == self.check_disabled_data_disk.issue
and issue.context == self.check_disabled_data_disk.context

View File

@ -1,11 +1,11 @@
"""OS support on supervisor."""
from collections.abc import Awaitable
from dataclasses import dataclass
from datetime import datetime
import errno
import logging
from pathlib import Path, PurePath
from typing import cast
import aiohttp
from awesomeversion import AwesomeVersion, AwesomeVersionException
@ -61,8 +61,8 @@ class SlotStatus:
device=PurePath(data["device"]),
bundle_compatible=data.get("bundle.compatible"),
sha256=data.get("sha256"),
size=data.get("size"),
installed_count=data.get("installed.count"),
size=cast(int | None, data.get("size")),
installed_count=cast(int | None, data.get("installed.count")),
bundle_version=AwesomeVersion(data["bundle.version"])
if "bundle.version" in data
else None,
@ -70,51 +70,17 @@ class SlotStatus:
if "installed.timestamp" in data
else None,
status=data.get("status"),
activated_count=data.get("activated.count"),
activated_count=cast(int | None, data.get("activated.count")),
activated_timestamp=datetime.fromisoformat(data["activated.timestamp"])
if "activated.timestamp" in data
else None,
boot_status=data.get("boot-status"),
boot_status=RaucState(data["boot-status"])
if "boot-status" in data
else None,
bootname=data.get("bootname"),
parent=data.get("parent"),
)
def to_dict(self) -> SlotStatusDataType:
"""Get dictionary representation."""
out: SlotStatusDataType = {
"class": self.class_,
"type": self.type_,
"state": self.state,
"device": self.device.as_posix(),
}
if self.bundle_compatible is not None:
out["bundle.compatible"] = self.bundle_compatible
if self.sha256 is not None:
out["sha256"] = self.sha256
if self.size is not None:
out["size"] = self.size
if self.installed_count is not None:
out["installed.count"] = self.installed_count
if self.bundle_version is not None:
out["bundle.version"] = str(self.bundle_version)
if self.installed_timestamp is not None:
out["installed.timestamp"] = str(self.installed_timestamp)
if self.status is not None:
out["status"] = self.status
if self.activated_count is not None:
out["activated.count"] = self.activated_count
if self.activated_timestamp:
out["activated.timestamp"] = str(self.activated_timestamp)
if self.boot_status:
out["boot-status"] = self.boot_status
if self.bootname is not None:
out["bootname"] = self.bootname
if self.parent is not None:
out["parent"] = self.parent
return out
class OSManager(CoreSysAttributes):
"""OS interface inside supervisor."""
@ -148,7 +114,11 @@ class OSManager(CoreSysAttributes):
def need_update(self) -> bool:
"""Return true if a HassOS update is available."""
try:
return self.version < self.latest_version
return (
self.version is not None
and self.latest_version is not None
and self.version < self.latest_version
)
except (AwesomeVersionException, TypeError):
return False
@ -176,6 +146,9 @@ class OSManager(CoreSysAttributes):
def get_slot_name(self, boot_name: str) -> str:
"""Get slot name from boot name."""
if not self._slots:
raise HassOSSlotNotFound()
for name, status in self._slots.items():
if status.bootname == boot_name:
return name
@ -288,11 +261,8 @@ class OSManager(CoreSysAttributes):
conditions=[JobCondition.HAOS],
on_condition=HassOSJobError,
)
async def config_sync(self) -> Awaitable[None]:
"""Trigger a host config reload from usb.
Return a coroutine.
"""
async def config_sync(self) -> None:
"""Trigger a host config reload from usb."""
_LOGGER.info(
"Synchronizing configuration from USB with Home Assistant Operating System."
)
@ -314,6 +284,10 @@ class OSManager(CoreSysAttributes):
version = version or self.latest_version
# Check installed version
if not version:
raise HassOSUpdateError(
"No version information available, cannot update", _LOGGER.error
)
if version == self.version:
raise HassOSUpdateError(
f"Version {version!s} is already installed", _LOGGER.warning

View File

@ -78,7 +78,7 @@ async def async_capture_event(event: dict[str, Any], only_once: str | None = Non
)
def capture_exception(err: Exception) -> None:
def capture_exception(err: BaseException) -> None:
"""Capture an exception and send to sentry.
Must be called in executor.
@ -87,7 +87,7 @@ def capture_exception(err: Exception) -> None:
sentry_sdk.capture_exception(err)
async def async_capture_exception(err: Exception) -> None:
async def async_capture_exception(err: BaseException) -> None:
"""Capture an exception and send to sentry.
Safe to call in event loop.