Fix mypy issues in host and jobs (#5939)

* Fix mypy issues in host

* Fix mypy issues in job module

* Fix mypy issues introduced in previously fixed modules

* Apply suggestions from code review

Co-authored-by: Stefan Agner <stefan@agner.ch>

---------

Co-authored-by: Stefan Agner <stefan@agner.ch>
This commit is contained in:
Mike Degatano 2025-06-11 12:04:25 -04:00 committed by GitHub
parent fd0b894d6a
commit 9682870c2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 192 additions and 134 deletions

View File

@ -630,7 +630,7 @@ class CoreSys:
def call_later( def call_later(
self, self,
delay: float, delay: float,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Any],
*args: tuple[Any], *args: tuple[Any],
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:
@ -643,7 +643,7 @@ class CoreSys:
def call_at( def call_at(
self, self,
when: datetime, when: datetime,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Any],
*args: tuple[Any], *args: tuple[Any],
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:
@ -843,7 +843,7 @@ class CoreSysAttributes:
def sys_call_later( def sys_call_later(
self, self,
delay: float, delay: float,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Any],
*args, *args,
**kwargs, **kwargs,
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:
@ -853,7 +853,7 @@ class CoreSysAttributes:
def sys_call_at( def sys_call_at(
self, self,
when: datetime, when: datetime,
funct: Callable[..., Coroutine[Any, Any, T]], funct: Callable[..., Any],
*args, *args,
**kwargs, **kwargs,
) -> asyncio.TimerHandle: ) -> asyncio.TimerHandle:

View File

@ -135,6 +135,7 @@ DBUS_ATTR_LAST_ERROR = "LastError"
DBUS_ATTR_LLMNR = "LLMNR" DBUS_ATTR_LLMNR = "LLMNR"
DBUS_ATTR_LLMNR_HOSTNAME = "LLMNRHostname" DBUS_ATTR_LLMNR_HOSTNAME = "LLMNRHostname"
DBUS_ATTR_LOADER_TIMESTAMP_MONOTONIC = "LoaderTimestampMonotonic" DBUS_ATTR_LOADER_TIMESTAMP_MONOTONIC = "LoaderTimestampMonotonic"
DBUS_ATTR_LOCAL_RTC = "LocalRTC"
DBUS_ATTR_MANAGED = "Managed" DBUS_ATTR_MANAGED = "Managed"
DBUS_ATTR_MODE = "Mode" DBUS_ATTR_MODE = "Mode"
DBUS_ATTR_MODEL = "Model" DBUS_ATTR_MODEL = "Model"

View File

@ -1,5 +1,6 @@
"""NetworkConnection objects for Network Manager.""" """NetworkConnection objects for Network Manager."""
from abc import ABC
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
@ -29,7 +30,7 @@ class ConnectionProperties:
class WirelessProperties: class WirelessProperties:
"""Wireless Properties object for Network Manager.""" """Wireless Properties object for Network Manager."""
ssid: str | None ssid: str
assigned_mac: str | None assigned_mac: str | None
mode: str | None mode: str | None
powersave: int | None powersave: int | None
@ -55,7 +56,7 @@ class EthernetProperties:
class VlanProperties: class VlanProperties:
"""Ethernet properties object for Network Manager.""" """Ethernet properties object for Network Manager."""
id: int | None id: int
parent: str | None parent: str | None
@ -67,14 +68,20 @@ class IpAddress:
prefix: int prefix: int
@dataclass(slots=True) @dataclass
class IpProperties: class IpProperties(ABC):
"""IP properties object for Network Manager.""" """IP properties object for Network Manager."""
method: str | None method: str | None
address_data: list[IpAddress] | None address_data: list[IpAddress] | None
gateway: str | None gateway: str | None
dns: list[bytes | int] | None
@dataclass(slots=True)
class Ip4Properties(IpProperties):
"""IPv4 properties object."""
dns: list[int] | None
@dataclass(slots=True) @dataclass(slots=True)
@ -83,6 +90,7 @@ class Ip6Properties(IpProperties):
addr_gen_mode: int addr_gen_mode: int
ip6_privacy: int ip6_privacy: int
dns: list[bytes] | None
@dataclass(slots=True) @dataclass(slots=True)

View File

@ -12,9 +12,9 @@ from ...utils import dbus_connected
from ..configuration import ( from ..configuration import (
ConnectionProperties, ConnectionProperties,
EthernetProperties, EthernetProperties,
Ip4Properties,
Ip6Properties, Ip6Properties,
IpAddress, IpAddress,
IpProperties,
MatchProperties, MatchProperties,
VlanProperties, VlanProperties,
WirelessProperties, WirelessProperties,
@ -115,7 +115,7 @@ class NetworkSetting(DBusInterface):
self._wireless_security: WirelessSecurityProperties | None = None self._wireless_security: WirelessSecurityProperties | None = None
self._ethernet: EthernetProperties | None = None self._ethernet: EthernetProperties | None = None
self._vlan: VlanProperties | None = None self._vlan: VlanProperties | None = None
self._ipv4: IpProperties | None = None self._ipv4: Ip4Properties | None = None
self._ipv6: Ip6Properties | None = None self._ipv6: Ip6Properties | None = None
self._match: MatchProperties | None = None self._match: MatchProperties | None = None
super().__init__() super().__init__()
@ -151,7 +151,7 @@ class NetworkSetting(DBusInterface):
return self._vlan return self._vlan
@property @property
def ipv4(self) -> IpProperties | None: def ipv4(self) -> Ip4Properties | None:
"""Return ipv4 properties if any.""" """Return ipv4 properties if any."""
return self._ipv4 return self._ipv4
@ -271,16 +271,23 @@ class NetworkSetting(DBusInterface):
) )
if CONF_ATTR_VLAN in data: if CONF_ATTR_VLAN in data:
self._vlan = VlanProperties( if CONF_ATTR_VLAN_ID in data[CONF_ATTR_VLAN]:
id=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_ID), self._vlan = VlanProperties(
parent=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT), data[CONF_ATTR_VLAN][CONF_ATTR_VLAN_ID],
) data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT),
)
else:
self._vlan = None
_LOGGER.warning(
"Network settings for vlan connection %s missing required vlan id, cannot process it",
self.connection.interface_name,
)
if CONF_ATTR_IPV4 in data: if CONF_ATTR_IPV4 in data:
address_data = None address_data = None
if ips := data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_ADDRESS_DATA): if ips := data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_ADDRESS_DATA):
address_data = [IpAddress(ip["address"], ip["prefix"]) for ip in ips] address_data = [IpAddress(ip["address"], ip["prefix"]) for ip in ips]
self._ipv4 = IpProperties( self._ipv4 = Ip4Properties(
method=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_METHOD), method=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_METHOD),
address_data=address_data, address_data=address_data,
gateway=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_GATEWAY), gateway=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_GATEWAY),

View File

@ -222,8 +222,10 @@ def get_connection_from_interface(
} }
elif interface.type == "vlan": elif interface.type == "vlan":
parent = cast(VlanConfig, interface.vlan).interface parent = cast(VlanConfig, interface.vlan).interface
if parent in network_manager and ( if (
parent_connection := network_manager.get(parent).connection parent
and parent in network_manager
and (parent_connection := network_manager.get(parent).connection)
): ):
parent = parent_connection.uuid parent = parent_connection.uuid

View File

@ -10,6 +10,7 @@ from dbus_fast.aio.message_bus import MessageBus
from ..exceptions import DBusError, DBusInterfaceError, DBusServiceUnkownError from ..exceptions import DBusError, DBusInterfaceError, DBusServiceUnkownError
from ..utils.dt import get_time_zone, utc_from_timestamp from ..utils.dt import get_time_zone, utc_from_timestamp
from .const import ( from .const import (
DBUS_ATTR_LOCAL_RTC,
DBUS_ATTR_NTP, DBUS_ATTR_NTP,
DBUS_ATTR_NTPSYNCHRONIZED, DBUS_ATTR_NTPSYNCHRONIZED,
DBUS_ATTR_TIMEUSEC, DBUS_ATTR_TIMEUSEC,
@ -46,6 +47,12 @@ class TimeDate(DBusInterfaceProxy):
"""Return host timezone.""" """Return host timezone."""
return self.properties[DBUS_ATTR_TIMEZONE] return self.properties[DBUS_ATTR_TIMEZONE]
@property
@dbus_property
def local_rtc(self) -> bool:
"""Return whether rtc is local time or utc."""
return self.properties[DBUS_ATTR_LOCAL_RTC]
@property @property
@dbus_property @dbus_property
def ntp(self) -> bool: def ntp(self) -> bool:

View File

@ -2,6 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface
import logging
import socket import socket
from ..dbus.const import ( from ..dbus.const import (
@ -23,6 +24,8 @@ from .const import (
WifiMode, WifiMode,
) )
_LOGGER: logging.Logger = logging.getLogger(__name__)
@dataclass(slots=True) @dataclass(slots=True)
class AccessPoint: class AccessPoint:
@ -79,7 +82,7 @@ class VlanConfig:
"""Represent a vlan configuration.""" """Represent a vlan configuration."""
id: int id: int
interface: str interface: str | None
@dataclass(slots=True) @dataclass(slots=True)
@ -108,7 +111,10 @@ class Interface:
if inet.settings.match and inet.settings.match.path: if inet.settings.match and inet.settings.match.path:
return inet.settings.match.path == [self.path] return inet.settings.match.path == [self.path]
return inet.settings.connection.interface_name == self.name return (
inet.settings.connection is not None
and inet.settings.connection.interface_name == self.name
)
@staticmethod @staticmethod
def from_dbus_interface(inet: NetworkInterface) -> "Interface": def from_dbus_interface(inet: NetworkInterface) -> "Interface":
@ -160,23 +166,23 @@ class Interface:
ipv6_setting = Ip6Setting(InterfaceMethod.DISABLED, [], None, []) ipv6_setting = Ip6Setting(InterfaceMethod.DISABLED, [], None, [])
ipv4_ready = ( ipv4_ready = (
bool(inet.connection) inet.connection is not None
and ConnectionStateFlags.IP4_READY in inet.connection.state_flags and ConnectionStateFlags.IP4_READY in inet.connection.state_flags
) )
ipv6_ready = ( ipv6_ready = (
bool(inet.connection) inet.connection is not None
and ConnectionStateFlags.IP6_READY in inet.connection.state_flags and ConnectionStateFlags.IP6_READY in inet.connection.state_flags
) )
return Interface( return Interface(
inet.name, name=inet.name,
inet.hw_address, mac=inet.hw_address,
inet.path, path=inet.path,
inet.settings is not None, enabled=inet.settings is not None,
Interface._map_nm_connected(inet.connection), connected=Interface._map_nm_connected(inet.connection),
inet.primary, primary=inet.primary,
Interface._map_nm_type(inet.type), type=Interface._map_nm_type(inet.type),
IpConfig( ipv4=IpConfig(
address=inet.connection.ipv4.address address=inet.connection.ipv4.address
if inet.connection.ipv4.address if inet.connection.ipv4.address
else [], else [],
@ -188,8 +194,8 @@ class Interface:
) )
if inet.connection and inet.connection.ipv4 if inet.connection and inet.connection.ipv4
else IpConfig([], None, [], ipv4_ready), else IpConfig([], None, [], ipv4_ready),
ipv4_setting, ipv4setting=ipv4_setting,
IpConfig( ipv6=IpConfig(
address=inet.connection.ipv6.address address=inet.connection.ipv6.address
if inet.connection.ipv6.address if inet.connection.ipv6.address
else [], else [],
@ -201,30 +207,28 @@ class Interface:
) )
if inet.connection and inet.connection.ipv6 if inet.connection and inet.connection.ipv6
else IpConfig([], None, [], ipv6_ready), else IpConfig([], None, [], ipv6_ready),
ipv6_setting, ipv6setting=ipv6_setting,
Interface._map_nm_wifi(inet), wifi=Interface._map_nm_wifi(inet),
Interface._map_nm_vlan(inet), vlan=Interface._map_nm_vlan(inet),
) )
@staticmethod @staticmethod
def _map_nm_method(method: str) -> InterfaceMethod: def _map_nm_method(method: str | None) -> InterfaceMethod:
"""Map IP interface method.""" """Map IP interface method."""
mapping = { match method:
NMInterfaceMethod.AUTO: InterfaceMethod.AUTO, case NMInterfaceMethod.AUTO.value:
NMInterfaceMethod.DISABLED: InterfaceMethod.DISABLED, return InterfaceMethod.AUTO
NMInterfaceMethod.MANUAL: InterfaceMethod.STATIC, case NMInterfaceMethod.MANUAL:
NMInterfaceMethod.LINK_LOCAL: InterfaceMethod.DISABLED, return InterfaceMethod.STATIC
} return InterfaceMethod.DISABLED
return mapping.get(method, InterfaceMethod.DISABLED)
@staticmethod @staticmethod
def _map_nm_addr_gen_mode(addr_gen_mode: int) -> InterfaceAddrGenMode: def _map_nm_addr_gen_mode(addr_gen_mode: int) -> InterfaceAddrGenMode:
"""Map IPv6 interface addr_gen_mode.""" """Map IPv6 interface addr_gen_mode."""
mapping = { mapping = {
NMInterfaceAddrGenMode.EUI64: InterfaceAddrGenMode.EUI64, NMInterfaceAddrGenMode.EUI64.value: InterfaceAddrGenMode.EUI64,
NMInterfaceAddrGenMode.STABLE_PRIVACY: InterfaceAddrGenMode.STABLE_PRIVACY, NMInterfaceAddrGenMode.STABLE_PRIVACY.value: InterfaceAddrGenMode.STABLE_PRIVACY,
NMInterfaceAddrGenMode.DEFAULT_OR_EUI64: InterfaceAddrGenMode.DEFAULT_OR_EUI64, NMInterfaceAddrGenMode.DEFAULT_OR_EUI64.value: InterfaceAddrGenMode.DEFAULT_OR_EUI64,
} }
return mapping.get(addr_gen_mode, InterfaceAddrGenMode.DEFAULT) return mapping.get(addr_gen_mode, InterfaceAddrGenMode.DEFAULT)
@ -233,9 +237,9 @@ class Interface:
def _map_nm_ip6_privacy(ip6_privacy: int) -> InterfaceIp6Privacy: def _map_nm_ip6_privacy(ip6_privacy: int) -> InterfaceIp6Privacy:
"""Map IPv6 interface ip6_privacy.""" """Map IPv6 interface ip6_privacy."""
mapping = { mapping = {
NMInterfaceIp6Privacy.DISABLED: InterfaceIp6Privacy.DISABLED, NMInterfaceIp6Privacy.DISABLED.value: InterfaceIp6Privacy.DISABLED,
NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC, NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC.value: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC,
NMInterfaceIp6Privacy.ENABLED: InterfaceIp6Privacy.ENABLED, NMInterfaceIp6Privacy.ENABLED.value: InterfaceIp6Privacy.ENABLED,
} }
return mapping.get(ip6_privacy, InterfaceIp6Privacy.DEFAULT) return mapping.get(ip6_privacy, InterfaceIp6Privacy.DEFAULT)
@ -253,12 +257,14 @@ class Interface:
@staticmethod @staticmethod
def _map_nm_type(device_type: int) -> InterfaceType: def _map_nm_type(device_type: int) -> InterfaceType:
mapping = { match device_type:
DeviceType.ETHERNET: InterfaceType.ETHERNET, case DeviceType.ETHERNET.value:
DeviceType.WIRELESS: InterfaceType.WIRELESS, return InterfaceType.ETHERNET
DeviceType.VLAN: InterfaceType.VLAN, case DeviceType.WIRELESS.value:
} return InterfaceType.WIRELESS
return mapping[device_type] case DeviceType.VLAN.value:
return InterfaceType.VLAN
raise ValueError(f"Invalid device type: {device_type}")
@staticmethod @staticmethod
def _map_nm_wifi(inet: NetworkInterface) -> WifiConfig | None: def _map_nm_wifi(inet: NetworkInterface) -> WifiConfig | None:
@ -267,15 +273,22 @@ class Interface:
return None return None
# Authentication and PSK # Authentication and PSK
auth = None auth = AuthMethod.OPEN
psk = None psk = None
if not inet.settings.wireless_security: if inet.settings.wireless_security:
auth = AuthMethod.OPEN match inet.settings.wireless_security.key_mgmt:
elif inet.settings.wireless_security.key_mgmt == "none": case "none":
auth = AuthMethod.WEP auth = AuthMethod.WEP
elif inet.settings.wireless_security.key_mgmt == "wpa-psk": case "wpa-psk":
auth = AuthMethod.WPA_PSK auth = AuthMethod.WPA_PSK
psk = inet.settings.wireless_security.psk psk = inet.settings.wireless_security.psk
case _:
_LOGGER.warning(
"Auth method %s for network interface %s unsupported, skipping",
inet.settings.wireless_security.key_mgmt,
inet.name,
)
return None
# WifiMode # WifiMode
mode = WifiMode.INFRASTRUCTURE mode = WifiMode.INFRASTRUCTURE
@ -289,17 +302,17 @@ class Interface:
signal = None signal = None
return WifiConfig( return WifiConfig(
mode, mode=mode,
inet.settings.wireless.ssid, ssid=inet.settings.wireless.ssid if inet.settings.wireless else "",
auth, auth=auth,
psk, psk=psk,
signal, signal=signal,
) )
@staticmethod @staticmethod
def _map_nm_vlan(inet: NetworkInterface) -> WifiConfig | None: def _map_nm_vlan(inet: NetworkInterface) -> VlanConfig | None:
"""Create mapping to nm vlan property.""" """Create mapping to nm vlan property."""
if inet.type != DeviceType.VLAN or not inet.settings: if inet.type != DeviceType.VLAN or not inet.settings or not inet.settings.vlan:
return None return None
return VlanConfig(inet.settings.vlan.id, inet.settings.vlan.parent) return VlanConfig(inet.settings.vlan.id, inet.settings.vlan.parent)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import json import json
import logging import logging
@ -205,7 +205,7 @@ class LogsControl(CoreSysAttributes):
async def journald_logs( async def journald_logs(
self, self,
path: str = "/entries", path: str = "/entries",
params: dict[str, str | list[str]] | None = None, params: Mapping[str, str | list[str]] | None = None,
range_header: str | None = None, range_header: str | None = None,
accept: LogFormat = LogFormat.TEXT, accept: LogFormat = LogFormat.TEXT,
timeout: ClientTimeout | None = None, timeout: ClientTimeout | None = None,
@ -226,7 +226,7 @@ class LogsControl(CoreSysAttributes):
base_url = "http://localhost/" base_url = "http://localhost/"
connector = UnixConnector(path=str(SYSTEMD_JOURNAL_GATEWAYD_SOCKET)) connector = UnixConnector(path=str(SYSTEMD_JOURNAL_GATEWAYD_SOCKET))
async with ClientSession(base_url=base_url, connector=connector) as session: async with ClientSession(base_url=base_url, connector=connector) as session:
headers = {ACCEPT: accept} headers = {ACCEPT: accept.value}
if range_header: if range_header:
if range_header.endswith(":"): if range_header.endswith(":"):
# Make sure that num_entries is always set - before Systemd v256 it was # Make sure that num_entries is always set - before Systemd v256 it was

View File

@ -87,7 +87,7 @@ class NetworkManager(CoreSysAttributes):
for config in self.sys_dbus.network.dns.configuration: for config in self.sys_dbus.network.dns.configuration:
if config.vpn or not config.nameservers: if config.vpn or not config.nameservers:
continue continue
servers.extend(config.nameservers) servers.extend([str(ns) for ns in config.nameservers])
return list(dict.fromkeys(servers)) return list(dict.fromkeys(servers))
@ -197,10 +197,16 @@ class NetworkManager(CoreSysAttributes):
with suppress(NetworkInterfaceNotFound): with suppress(NetworkInterfaceNotFound):
inet = self.sys_dbus.network.get(interface.name) inet = self.sys_dbus.network.get(interface.name)
con: NetworkConnection = None con: NetworkConnection | None = None
# Update exist configuration # Update exist configuration
if inet and interface.equals_dbus_interface(inet) and interface.enabled: if (
inet
and inet.settings
and inet.settings.connection
and interface.equals_dbus_interface(inet)
and interface.enabled
):
_LOGGER.debug("Updating existing configuration for %s", interface.name) _LOGGER.debug("Updating existing configuration for %s", interface.name)
settings = get_connection_from_interface( settings = get_connection_from_interface(
interface, interface,
@ -211,12 +217,12 @@ class NetworkManager(CoreSysAttributes):
try: try:
await inet.settings.update(settings) await inet.settings.update(settings)
con = await self.sys_dbus.network.activate_connection( con = activated = await self.sys_dbus.network.activate_connection(
inet.settings.object_path, inet.object_path inet.settings.object_path, inet.object_path
) )
_LOGGER.debug( _LOGGER.debug(
"activate_connection returns %s", "activate_connection returns %s",
con.object_path, activated.object_path,
) )
except DBusError as err: except DBusError as err:
raise HostNetworkError( raise HostNetworkError(
@ -236,12 +242,16 @@ class NetworkManager(CoreSysAttributes):
settings = get_connection_from_interface(interface, self.sys_dbus.network) settings = get_connection_from_interface(interface, self.sys_dbus.network)
try: try:
settings, con = await self.sys_dbus.network.add_and_activate_connection( (
settings,
activated,
) = await self.sys_dbus.network.add_and_activate_connection(
settings, inet.object_path settings, inet.object_path
) )
con = activated
_LOGGER.debug( _LOGGER.debug(
"add_and_activate_connection returns %s", "add_and_activate_connection returns %s",
con.object_path, activated.object_path,
) )
except DBusError as err: except DBusError as err:
raise HostNetworkError( raise HostNetworkError(
@ -277,7 +287,7 @@ class NetworkManager(CoreSysAttributes):
) )
if con: if con:
async with con.dbus.signal( async with con.connected_dbus.signal(
DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED
) as signal: ) as signal:
# From this point we monitor signals. However, it might be that # From this point we monitor signals. However, it might be that
@ -303,7 +313,7 @@ class NetworkManager(CoreSysAttributes):
"""Scan on Interface for AccessPoint.""" """Scan on Interface for AccessPoint."""
inet = self.sys_dbus.network.get(interface.name) inet = self.sys_dbus.network.get(interface.name)
if inet.type != DeviceType.WIRELESS: if inet.type != DeviceType.WIRELESS or not inet.wireless:
raise HostNotSupportedError( raise HostNotSupportedError(
f"Can only scan with wireless card - {interface.name}", _LOGGER.error f"Can only scan with wireless card - {interface.name}", _LOGGER.error
) )

View File

@ -1,13 +1,13 @@
"""Supervisor job manager.""" """Supervisor job manager."""
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager from contextlib import contextmanager, suppress
from contextvars import Context, ContextVar, Token from contextvars import Context, ContextVar, Token
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import logging import logging
from typing import Any from typing import Any, Self
from uuid import uuid4 from uuid import uuid4
from attrs import Attribute, define, field from attrs import Attribute, define, field
@ -27,7 +27,7 @@ from .validate import SCHEMA_JOBS_CONFIG
# When a new asyncio task is started the current context is copied over. # When a new asyncio task is started the current context is copied over.
# Modifications to it in one task are not visible to others though. # Modifications to it in one task are not visible to others though.
# This allows us to track what job is currently in progress in each task. # This allows us to track what job is currently in progress in each task.
_CURRENT_JOB: ContextVar[str] = ContextVar("current_job") _CURRENT_JOB: ContextVar[str | None] = ContextVar("current_job", default=None)
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -75,7 +75,7 @@ class SupervisorJobError:
message: str = "Unknown error, see supervisor logs" message: str = "Unknown error, see supervisor logs"
stage: str | None = None stage: str | None = None
def as_dict(self) -> dict[str, str]: def as_dict(self) -> dict[str, str | None]:
"""Return dictionary representation.""" """Return dictionary representation."""
return { return {
"type": self.type_.__name__, "type": self.type_.__name__,
@ -101,9 +101,7 @@ class SupervisorJob:
stage: str | None = field( stage: str | None = field(
default=None, validator=[_invalid_if_done], on_setattr=_on_change default=None, validator=[_invalid_if_done], on_setattr=_on_change
) )
parent_id: str | None = field( parent_id: str | None = field(factory=_CURRENT_JOB.get, on_setattr=frozen)
factory=lambda: _CURRENT_JOB.get(None), on_setattr=frozen
)
done: bool | None = field(init=False, default=None, on_setattr=_on_change) done: bool | None = field(init=False, default=None, on_setattr=_on_change)
on_change: Callable[["SupervisorJob", Attribute, Any], None] | None = field( on_change: Callable[["SupervisorJob", Attribute, Any], None] | None = field(
default=None, on_setattr=frozen default=None, on_setattr=frozen
@ -137,7 +135,7 @@ class SupervisorJob:
self.errors += [new_error] self.errors += [new_error]
@contextmanager @contextmanager
def start(self): def start(self) -> Generator[Self]:
"""Start the job in the current task. """Start the job in the current task.
This can only be called if the parent ID matches the job running in the current task. This can only be called if the parent ID matches the job running in the current task.
@ -146,11 +144,11 @@ class SupervisorJob:
""" """
if self.done is not None: if self.done is not None:
raise JobStartException("Job has already been started") raise JobStartException("Job has already been started")
if _CURRENT_JOB.get(None) != self.parent_id: if _CURRENT_JOB.get() != self.parent_id:
raise JobStartException("Job has a different parent from current job") raise JobStartException("Job has a different parent from current job")
self.done = False self.done = False
token: Token[str] | None = None token: Token[str | None] | None = None
try: try:
token = _CURRENT_JOB.set(self.uuid) token = _CURRENT_JOB.set(self.uuid)
yield self yield self
@ -193,17 +191,15 @@ class JobManager(FileConfiguration, CoreSysAttributes):
Must be called from within a job. Raises RuntimeError if there is no current job. Must be called from within a job. Raises RuntimeError if there is no current job.
""" """
try: if job_id := _CURRENT_JOB.get():
return self.get_job(_CURRENT_JOB.get()) with suppress(JobNotFound):
except (LookupError, JobNotFound): return self.get_job(job_id)
raise RuntimeError( raise RuntimeError("No job for the current asyncio task!", _LOGGER.critical)
"No job for the current asyncio task!", _LOGGER.critical
) from None
@property @property
def is_job(self) -> bool: def is_job(self) -> bool:
"""Return true if there is an active job for the current asyncio task.""" """Return true if there is an active job for the current asyncio task."""
return bool(_CURRENT_JOB.get(None)) return _CURRENT_JOB.get() is not None
def _notify_on_job_change( def _notify_on_job_change(
self, job: SupervisorJob, attribute: Attribute, value: Any self, job: SupervisorJob, attribute: Attribute, value: Any
@ -265,7 +261,7 @@ class JobManager(FileConfiguration, CoreSysAttributes):
def schedule_job( def schedule_job(
self, self,
job_method: Callable[..., Awaitable[Any]], job_method: Callable[..., Coroutine],
options: JobSchedulerOptions, options: JobSchedulerOptions,
*args, *args,
**kwargs, **kwargs,

View File

@ -1,12 +1,12 @@
"""Job decorator.""" """Job decorator."""
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Awaitable, Callable
from contextlib import suppress from contextlib import suppress
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
import logging import logging
from typing import Any from typing import Any, cast
from ..const import CoreState from ..const import CoreState
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
@ -54,11 +54,10 @@ class Job(CoreSysAttributes):
self.on_condition = on_condition self.on_condition = on_condition
self.limit = limit self.limit = limit
self._throttle_period = throttle_period self._throttle_period = throttle_period
self.throttle_max_calls = throttle_max_calls self._throttle_max_calls = throttle_max_calls
self._lock: asyncio.Semaphore | None = None self._lock: asyncio.Semaphore | None = None
self._method = None
self._last_call: dict[str | None, datetime] = {} self._last_call: dict[str | None, datetime] = {}
self._rate_limited_calls: dict[str, list[datetime]] | None = None self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
self._internal = internal self._internal = internal
# Validate Options # Validate Options
@ -82,13 +81,29 @@ class Job(CoreSysAttributes):
JobExecutionLimit.THROTTLE_RATE_LIMIT, JobExecutionLimit.THROTTLE_RATE_LIMIT,
JobExecutionLimit.GROUP_THROTTLE_RATE_LIMIT, JobExecutionLimit.GROUP_THROTTLE_RATE_LIMIT,
): ):
if self.throttle_max_calls is None: if self._throttle_max_calls is None:
raise RuntimeError( raise RuntimeError(
f"Job {name} is using execution limit {limit} without throttle max calls!" f"Job {name} is using execution limit {limit} without throttle max calls!"
) )
self._rate_limited_calls = {} self._rate_limited_calls = {}
@property
def throttle_max_calls(self) -> int:
"""Return max calls for throttle."""
if self._throttle_max_calls is None:
raise RuntimeError("No throttle max calls set for job!")
return self._throttle_max_calls
@property
def lock(self) -> asyncio.Semaphore:
"""Return lock for limits."""
# asyncio.Semaphore objects must be created in event loop
# Since this is sync code it is not safe to create if missing here
if not self._lock:
raise RuntimeError("Lock has not been created yet!")
return self._lock
def last_call(self, group_name: str | None = None) -> datetime: def last_call(self, group_name: str | None = None) -> datetime:
"""Return last call datetime.""" """Return last call datetime."""
return self._last_call.get(group_name, datetime.min) return self._last_call.get(group_name, datetime.min)
@ -97,12 +112,12 @@ class Job(CoreSysAttributes):
"""Set last call datetime.""" """Set last call datetime."""
self._last_call[group_name] = value self._last_call[group_name] = value
def rate_limited_calls( def rate_limited_calls(self, group_name: str | None = None) -> list[datetime]:
self, group_name: str | None = None
) -> list[datetime] | None:
"""Return rate limited calls if used.""" """Return rate limited calls if used."""
if self._rate_limited_calls is None: if self._rate_limited_calls is None:
return None raise RuntimeError(
f"Rate limited calls not available for limit type {self.limit}"
)
return self._rate_limited_calls.get(group_name, []) return self._rate_limited_calls.get(group_name, [])
@ -131,10 +146,10 @@ class Job(CoreSysAttributes):
self._rate_limited_calls[group_name] = value self._rate_limited_calls[group_name] = value
def throttle_period(self, group_name: str | None = None) -> timedelta | None: def throttle_period(self, group_name: str | None = None) -> timedelta:
"""Return throttle period.""" """Return throttle period."""
if self._throttle_period is None: if self._throttle_period is None:
return None raise RuntimeError("No throttle period set for Job!")
if isinstance(self._throttle_period, timedelta): if isinstance(self._throttle_period, timedelta):
return self._throttle_period return self._throttle_period
@ -142,7 +157,7 @@ class Job(CoreSysAttributes):
return self._throttle_period( return self._throttle_period(
self.coresys, self.coresys,
self.last_call(group_name), self.last_call(group_name),
self.rate_limited_calls(group_name), self.rate_limited_calls(group_name) if self._rate_limited_calls else None,
) )
def _post_init(self, obj: JobGroup | CoreSysAttributes) -> JobGroup | None: def _post_init(self, obj: JobGroup | CoreSysAttributes) -> JobGroup | None:
@ -158,12 +173,12 @@ class Job(CoreSysAttributes):
self._lock = asyncio.Semaphore() self._lock = asyncio.Semaphore()
# Job groups # Job groups
try: job_group: JobGroup | None = None
is_job_group = obj.acquire and obj.release with suppress(AttributeError):
except AttributeError: if obj.acquire and obj.release: # type: ignore
is_job_group = False job_group = cast(JobGroup, obj)
if not is_job_group and self.limit in ( if not job_group and self.limit in (
JobExecutionLimit.GROUP_ONCE, JobExecutionLimit.GROUP_ONCE,
JobExecutionLimit.GROUP_WAIT, JobExecutionLimit.GROUP_WAIT,
JobExecutionLimit.GROUP_THROTTLE, JobExecutionLimit.GROUP_THROTTLE,
@ -174,7 +189,7 @@ class Job(CoreSysAttributes):
f"Job on {self.name} need to be a JobGroup to use group based limits!" f"Job on {self.name} need to be a JobGroup to use group based limits!"
) from None ) from None
return obj if is_job_group else None return job_group
def _handle_job_condition_exception(self, err: JobConditionException) -> None: def _handle_job_condition_exception(self, err: JobConditionException) -> None:
"""Handle a job condition failure.""" """Handle a job condition failure."""
@ -184,9 +199,8 @@ class Job(CoreSysAttributes):
return return
raise self.on_condition(error_msg, _LOGGER.warning) from None raise self.on_condition(error_msg, _LOGGER.warning) from None
def __call__(self, method): def __call__(self, method: Callable[..., Awaitable]):
"""Call the wrapper logic.""" """Call the wrapper logic."""
self._method = method
@wraps(method) @wraps(method)
async def wrapper( async def wrapper(
@ -221,7 +235,7 @@ class Job(CoreSysAttributes):
if self.conditions: if self.conditions:
try: try:
await Job.check_conditions( await Job.check_conditions(
self, set(self.conditions), self._method.__qualname__ self, set(self.conditions), method.__qualname__
) )
except JobConditionException as err: except JobConditionException as err:
return self._handle_job_condition_exception(err) return self._handle_job_condition_exception(err)
@ -237,7 +251,7 @@ class Job(CoreSysAttributes):
JobExecutionLimit.GROUP_WAIT, JobExecutionLimit.GROUP_WAIT,
): ):
try: try:
await obj.acquire( await cast(JobGroup, job_group).acquire(
job, self.limit == JobExecutionLimit.GROUP_WAIT job, self.limit == JobExecutionLimit.GROUP_WAIT
) )
except JobGroupExecutionLimitExceeded as err: except JobGroupExecutionLimitExceeded as err:
@ -296,12 +310,12 @@ class Job(CoreSysAttributes):
with job.start(): with job.start():
try: try:
self.set_last_call(datetime.now(), group_name) self.set_last_call(datetime.now(), group_name)
if self.rate_limited_calls(group_name) is not None: if self._rate_limited_calls is not None:
self.add_rate_limited_call( self.add_rate_limited_call(
self.last_call(group_name), group_name self.last_call(group_name), group_name
) )
return await self._method(obj, *args, **kwargs) return await method(obj, *args, **kwargs)
# If a method has a conditional JobCondition, they must check it in the method # If a method has a conditional JobCondition, they must check it in the method
# These should be handled like normal JobConditions as much as possible # These should be handled like normal JobConditions as much as possible
@ -317,11 +331,11 @@ class Job(CoreSysAttributes):
raise JobException() from err raise JobException() from err
finally: finally:
self._release_exception_limits() self._release_exception_limits()
if self.limit in ( if job_group and self.limit in (
JobExecutionLimit.GROUP_ONCE, JobExecutionLimit.GROUP_ONCE,
JobExecutionLimit.GROUP_WAIT, JobExecutionLimit.GROUP_WAIT,
): ):
obj.release() job_group.release()
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required # Jobs that weren't started are always cleaned up. Also clean up done jobs if required
finally: finally:
@ -473,13 +487,13 @@ class Job(CoreSysAttributes):
): ):
return return
if self.limit == JobExecutionLimit.ONCE and self._lock.locked(): if self.limit == JobExecutionLimit.ONCE and self.lock.locked():
on_condition = ( on_condition = (
JobException if self.on_condition is None else self.on_condition JobException if self.on_condition is None else self.on_condition
) )
raise on_condition("Another job is running") raise on_condition("Another job is running")
await self._lock.acquire() await self.lock.acquire()
def _release_exception_limits(self) -> None: def _release_exception_limits(self) -> None:
"""Release possible exception limits.""" """Release possible exception limits."""
@ -490,4 +504,4 @@ class Job(CoreSysAttributes):
JobExecutionLimit.GROUP_THROTTLE_WAIT, JobExecutionLimit.GROUP_THROTTLE_WAIT,
): ):
return return
self._lock.release() self.lock.release()

View File

@ -41,7 +41,7 @@ class JobGroup(CoreSysAttributes):
def has_lock(self) -> bool: def has_lock(self) -> bool:
"""Return true if current task has the lock on this job group.""" """Return true if current task has the lock on this job group."""
return ( return (
self.active_job self.active_job is not None
and self.sys_jobs.is_job and self.sys_jobs.is_job
and self.active_job == self.sys_jobs.current and self.active_job == self.sys_jobs.current
) )