Fix mypy issues in host and jobs (#5939)

* Fix mypy issues in host

* Fix mypy issues in job module

* Fix mypy issues introduced in previously fixed modules

* Apply suggestions from code review

Co-authored-by: Stefan Agner <stefan@agner.ch>

---------

Co-authored-by: Stefan Agner <stefan@agner.ch>
This commit is contained in:
Mike Degatano 2025-06-11 12:04:25 -04:00 committed by GitHub
parent fd0b894d6a
commit 9682870c2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 192 additions and 134 deletions

View File

@ -630,7 +630,7 @@ class CoreSys:
def call_later(
self,
delay: float,
funct: Callable[..., Coroutine[Any, Any, T]],
funct: Callable[..., Any],
*args: tuple[Any],
**kwargs: dict[str, Any],
) -> asyncio.TimerHandle:
@ -643,7 +643,7 @@ class CoreSys:
def call_at(
self,
when: datetime,
funct: Callable[..., Coroutine[Any, Any, T]],
funct: Callable[..., Any],
*args: tuple[Any],
**kwargs: dict[str, Any],
) -> asyncio.TimerHandle:
@ -843,7 +843,7 @@ class CoreSysAttributes:
def sys_call_later(
self,
delay: float,
funct: Callable[..., Coroutine[Any, Any, T]],
funct: Callable[..., Any],
*args,
**kwargs,
) -> asyncio.TimerHandle:
@ -853,7 +853,7 @@ class CoreSysAttributes:
def sys_call_at(
self,
when: datetime,
funct: Callable[..., Coroutine[Any, Any, T]],
funct: Callable[..., Any],
*args,
**kwargs,
) -> asyncio.TimerHandle:

View File

@ -135,6 +135,7 @@ DBUS_ATTR_LAST_ERROR = "LastError"
DBUS_ATTR_LLMNR = "LLMNR"
DBUS_ATTR_LLMNR_HOSTNAME = "LLMNRHostname"
DBUS_ATTR_LOADER_TIMESTAMP_MONOTONIC = "LoaderTimestampMonotonic"
DBUS_ATTR_LOCAL_RTC = "LocalRTC"
DBUS_ATTR_MANAGED = "Managed"
DBUS_ATTR_MODE = "Mode"
DBUS_ATTR_MODEL = "Model"

View File

@ -1,5 +1,6 @@
"""NetworkConnection objects for Network Manager."""
from abc import ABC
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
@ -29,7 +30,7 @@ class ConnectionProperties:
class WirelessProperties:
"""Wireless Properties object for Network Manager."""
ssid: str | None
ssid: str
assigned_mac: str | None
mode: str | None
powersave: int | None
@ -55,7 +56,7 @@ class EthernetProperties:
class VlanProperties:
"""Ethernet properties object for Network Manager."""
id: int | None
id: int
parent: str | None
@ -67,14 +68,20 @@ class IpAddress:
prefix: int
@dataclass(slots=True)
class IpProperties:
@dataclass
class IpProperties(ABC):
"""IP properties object for Network Manager."""
method: str | None
address_data: list[IpAddress] | None
gateway: str | None
dns: list[bytes | int] | None
@dataclass(slots=True)
class Ip4Properties(IpProperties):
"""IPv4 properties object."""
dns: list[int] | None
@dataclass(slots=True)
@ -83,6 +90,7 @@ class Ip6Properties(IpProperties):
addr_gen_mode: int
ip6_privacy: int
dns: list[bytes] | None
@dataclass(slots=True)

View File

@ -12,9 +12,9 @@ from ...utils import dbus_connected
from ..configuration import (
ConnectionProperties,
EthernetProperties,
Ip4Properties,
Ip6Properties,
IpAddress,
IpProperties,
MatchProperties,
VlanProperties,
WirelessProperties,
@ -115,7 +115,7 @@ class NetworkSetting(DBusInterface):
self._wireless_security: WirelessSecurityProperties | None = None
self._ethernet: EthernetProperties | None = None
self._vlan: VlanProperties | None = None
self._ipv4: IpProperties | None = None
self._ipv4: Ip4Properties | None = None
self._ipv6: Ip6Properties | None = None
self._match: MatchProperties | None = None
super().__init__()
@ -151,7 +151,7 @@ class NetworkSetting(DBusInterface):
return self._vlan
@property
def ipv4(self) -> IpProperties | None:
def ipv4(self) -> Ip4Properties | None:
"""Return ipv4 properties if any."""
return self._ipv4
@ -271,16 +271,23 @@ class NetworkSetting(DBusInterface):
)
if CONF_ATTR_VLAN in data:
self._vlan = VlanProperties(
id=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_ID),
parent=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT),
)
if CONF_ATTR_VLAN_ID in data[CONF_ATTR_VLAN]:
self._vlan = VlanProperties(
data[CONF_ATTR_VLAN][CONF_ATTR_VLAN_ID],
data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT),
)
else:
self._vlan = None
_LOGGER.warning(
"Network settings for vlan connection %s missing required vlan id, cannot process it",
self.connection.interface_name,
)
if CONF_ATTR_IPV4 in data:
address_data = None
if ips := data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_ADDRESS_DATA):
address_data = [IpAddress(ip["address"], ip["prefix"]) for ip in ips]
self._ipv4 = IpProperties(
self._ipv4 = Ip4Properties(
method=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_METHOD),
address_data=address_data,
gateway=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_GATEWAY),

View File

@ -222,8 +222,10 @@ def get_connection_from_interface(
}
elif interface.type == "vlan":
parent = cast(VlanConfig, interface.vlan).interface
if parent in network_manager and (
parent_connection := network_manager.get(parent).connection
if (
parent
and parent in network_manager
and (parent_connection := network_manager.get(parent).connection)
):
parent = parent_connection.uuid

View File

@ -10,6 +10,7 @@ from dbus_fast.aio.message_bus import MessageBus
from ..exceptions import DBusError, DBusInterfaceError, DBusServiceUnkownError
from ..utils.dt import get_time_zone, utc_from_timestamp
from .const import (
DBUS_ATTR_LOCAL_RTC,
DBUS_ATTR_NTP,
DBUS_ATTR_NTPSYNCHRONIZED,
DBUS_ATTR_TIMEUSEC,
@ -46,6 +47,12 @@ class TimeDate(DBusInterfaceProxy):
"""Return host timezone."""
return self.properties[DBUS_ATTR_TIMEZONE]
@property
@dbus_property
def local_rtc(self) -> bool:
"""Return whether rtc is local time or utc."""
return self.properties[DBUS_ATTR_LOCAL_RTC]
@property
@dbus_property
def ntp(self) -> bool:

View File

@ -2,6 +2,7 @@
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface
import logging
import socket
from ..dbus.const import (
@ -23,6 +24,8 @@ from .const import (
WifiMode,
)
_LOGGER: logging.Logger = logging.getLogger(__name__)
@dataclass(slots=True)
class AccessPoint:
@ -79,7 +82,7 @@ class VlanConfig:
"""Represent a vlan configuration."""
id: int
interface: str
interface: str | None
@dataclass(slots=True)
@ -108,7 +111,10 @@ class Interface:
if inet.settings.match and inet.settings.match.path:
return inet.settings.match.path == [self.path]
return inet.settings.connection.interface_name == self.name
return (
inet.settings.connection is not None
and inet.settings.connection.interface_name == self.name
)
@staticmethod
def from_dbus_interface(inet: NetworkInterface) -> "Interface":
@ -160,23 +166,23 @@ class Interface:
ipv6_setting = Ip6Setting(InterfaceMethod.DISABLED, [], None, [])
ipv4_ready = (
bool(inet.connection)
inet.connection is not None
and ConnectionStateFlags.IP4_READY in inet.connection.state_flags
)
ipv6_ready = (
bool(inet.connection)
inet.connection is not None
and ConnectionStateFlags.IP6_READY in inet.connection.state_flags
)
return Interface(
inet.name,
inet.hw_address,
inet.path,
inet.settings is not None,
Interface._map_nm_connected(inet.connection),
inet.primary,
Interface._map_nm_type(inet.type),
IpConfig(
name=inet.name,
mac=inet.hw_address,
path=inet.path,
enabled=inet.settings is not None,
connected=Interface._map_nm_connected(inet.connection),
primary=inet.primary,
type=Interface._map_nm_type(inet.type),
ipv4=IpConfig(
address=inet.connection.ipv4.address
if inet.connection.ipv4.address
else [],
@ -188,8 +194,8 @@ class Interface:
)
if inet.connection and inet.connection.ipv4
else IpConfig([], None, [], ipv4_ready),
ipv4_setting,
IpConfig(
ipv4setting=ipv4_setting,
ipv6=IpConfig(
address=inet.connection.ipv6.address
if inet.connection.ipv6.address
else [],
@ -201,30 +207,28 @@ class Interface:
)
if inet.connection and inet.connection.ipv6
else IpConfig([], None, [], ipv6_ready),
ipv6_setting,
Interface._map_nm_wifi(inet),
Interface._map_nm_vlan(inet),
ipv6setting=ipv6_setting,
wifi=Interface._map_nm_wifi(inet),
vlan=Interface._map_nm_vlan(inet),
)
@staticmethod
def _map_nm_method(method: str) -> InterfaceMethod:
def _map_nm_method(method: str | None) -> InterfaceMethod:
"""Map IP interface method."""
mapping = {
NMInterfaceMethod.AUTO: InterfaceMethod.AUTO,
NMInterfaceMethod.DISABLED: InterfaceMethod.DISABLED,
NMInterfaceMethod.MANUAL: InterfaceMethod.STATIC,
NMInterfaceMethod.LINK_LOCAL: InterfaceMethod.DISABLED,
}
return mapping.get(method, InterfaceMethod.DISABLED)
match method:
case NMInterfaceMethod.AUTO.value:
return InterfaceMethod.AUTO
case NMInterfaceMethod.MANUAL:
return InterfaceMethod.STATIC
return InterfaceMethod.DISABLED
@staticmethod
def _map_nm_addr_gen_mode(addr_gen_mode: int) -> InterfaceAddrGenMode:
"""Map IPv6 interface addr_gen_mode."""
mapping = {
NMInterfaceAddrGenMode.EUI64: InterfaceAddrGenMode.EUI64,
NMInterfaceAddrGenMode.STABLE_PRIVACY: InterfaceAddrGenMode.STABLE_PRIVACY,
NMInterfaceAddrGenMode.DEFAULT_OR_EUI64: InterfaceAddrGenMode.DEFAULT_OR_EUI64,
NMInterfaceAddrGenMode.EUI64.value: InterfaceAddrGenMode.EUI64,
NMInterfaceAddrGenMode.STABLE_PRIVACY.value: InterfaceAddrGenMode.STABLE_PRIVACY,
NMInterfaceAddrGenMode.DEFAULT_OR_EUI64.value: InterfaceAddrGenMode.DEFAULT_OR_EUI64,
}
return mapping.get(addr_gen_mode, InterfaceAddrGenMode.DEFAULT)
@ -233,9 +237,9 @@ class Interface:
def _map_nm_ip6_privacy(ip6_privacy: int) -> InterfaceIp6Privacy:
"""Map IPv6 interface ip6_privacy."""
mapping = {
NMInterfaceIp6Privacy.DISABLED: InterfaceIp6Privacy.DISABLED,
NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC,
NMInterfaceIp6Privacy.ENABLED: InterfaceIp6Privacy.ENABLED,
NMInterfaceIp6Privacy.DISABLED.value: InterfaceIp6Privacy.DISABLED,
NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC.value: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC,
NMInterfaceIp6Privacy.ENABLED.value: InterfaceIp6Privacy.ENABLED,
}
return mapping.get(ip6_privacy, InterfaceIp6Privacy.DEFAULT)
@ -253,12 +257,14 @@ class Interface:
@staticmethod
def _map_nm_type(device_type: int) -> InterfaceType:
mapping = {
DeviceType.ETHERNET: InterfaceType.ETHERNET,
DeviceType.WIRELESS: InterfaceType.WIRELESS,
DeviceType.VLAN: InterfaceType.VLAN,
}
return mapping[device_type]
match device_type:
case DeviceType.ETHERNET.value:
return InterfaceType.ETHERNET
case DeviceType.WIRELESS.value:
return InterfaceType.WIRELESS
case DeviceType.VLAN.value:
return InterfaceType.VLAN
raise ValueError(f"Invalid device type: {device_type}")
@staticmethod
def _map_nm_wifi(inet: NetworkInterface) -> WifiConfig | None:
@ -267,15 +273,22 @@ class Interface:
return None
# Authentication and PSK
auth = None
auth = AuthMethod.OPEN
psk = None
if not inet.settings.wireless_security:
auth = AuthMethod.OPEN
elif inet.settings.wireless_security.key_mgmt == "none":
auth = AuthMethod.WEP
elif inet.settings.wireless_security.key_mgmt == "wpa-psk":
auth = AuthMethod.WPA_PSK
psk = inet.settings.wireless_security.psk
if inet.settings.wireless_security:
match inet.settings.wireless_security.key_mgmt:
case "none":
auth = AuthMethod.WEP
case "wpa-psk":
auth = AuthMethod.WPA_PSK
psk = inet.settings.wireless_security.psk
case _:
_LOGGER.warning(
"Auth method %s for network interface %s unsupported, skipping",
inet.settings.wireless_security.key_mgmt,
inet.name,
)
return None
# WifiMode
mode = WifiMode.INFRASTRUCTURE
@ -289,17 +302,17 @@ class Interface:
signal = None
return WifiConfig(
mode,
inet.settings.wireless.ssid,
auth,
psk,
signal,
mode=mode,
ssid=inet.settings.wireless.ssid if inet.settings.wireless else "",
auth=auth,
psk=psk,
signal=signal,
)
@staticmethod
def _map_nm_vlan(inet: NetworkInterface) -> WifiConfig | None:
def _map_nm_vlan(inet: NetworkInterface) -> VlanConfig | None:
"""Create mapping to nm vlan property."""
if inet.type != DeviceType.VLAN or not inet.settings:
if inet.type != DeviceType.VLAN or not inet.settings or not inet.settings.vlan:
return None
return VlanConfig(inet.settings.vlan.id, inet.settings.vlan.parent)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager
import json
import logging
@ -205,7 +205,7 @@ class LogsControl(CoreSysAttributes):
async def journald_logs(
self,
path: str = "/entries",
params: dict[str, str | list[str]] | None = None,
params: Mapping[str, str | list[str]] | None = None,
range_header: str | None = None,
accept: LogFormat = LogFormat.TEXT,
timeout: ClientTimeout | None = None,
@ -226,7 +226,7 @@ class LogsControl(CoreSysAttributes):
base_url = "http://localhost/"
connector = UnixConnector(path=str(SYSTEMD_JOURNAL_GATEWAYD_SOCKET))
async with ClientSession(base_url=base_url, connector=connector) as session:
headers = {ACCEPT: accept}
headers = {ACCEPT: accept.value}
if range_header:
if range_header.endswith(":"):
# Make sure that num_entries is always set - before Systemd v256 it was

View File

@ -87,7 +87,7 @@ class NetworkManager(CoreSysAttributes):
for config in self.sys_dbus.network.dns.configuration:
if config.vpn or not config.nameservers:
continue
servers.extend(config.nameservers)
servers.extend([str(ns) for ns in config.nameservers])
return list(dict.fromkeys(servers))
@ -197,10 +197,16 @@ class NetworkManager(CoreSysAttributes):
with suppress(NetworkInterfaceNotFound):
inet = self.sys_dbus.network.get(interface.name)
con: NetworkConnection = None
con: NetworkConnection | None = None
# Update exist configuration
if inet and interface.equals_dbus_interface(inet) and interface.enabled:
if (
inet
and inet.settings
and inet.settings.connection
and interface.equals_dbus_interface(inet)
and interface.enabled
):
_LOGGER.debug("Updating existing configuration for %s", interface.name)
settings = get_connection_from_interface(
interface,
@ -211,12 +217,12 @@ class NetworkManager(CoreSysAttributes):
try:
await inet.settings.update(settings)
con = await self.sys_dbus.network.activate_connection(
con = activated = await self.sys_dbus.network.activate_connection(
inet.settings.object_path, inet.object_path
)
_LOGGER.debug(
"activate_connection returns %s",
con.object_path,
activated.object_path,
)
except DBusError as err:
raise HostNetworkError(
@ -236,12 +242,16 @@ class NetworkManager(CoreSysAttributes):
settings = get_connection_from_interface(interface, self.sys_dbus.network)
try:
settings, con = await self.sys_dbus.network.add_and_activate_connection(
(
settings,
activated,
) = await self.sys_dbus.network.add_and_activate_connection(
settings, inet.object_path
)
con = activated
_LOGGER.debug(
"add_and_activate_connection returns %s",
con.object_path,
activated.object_path,
)
except DBusError as err:
raise HostNetworkError(
@ -277,7 +287,7 @@ class NetworkManager(CoreSysAttributes):
)
if con:
async with con.dbus.signal(
async with con.connected_dbus.signal(
DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED
) as signal:
# From this point we monitor signals. However, it might be that
@ -303,7 +313,7 @@ class NetworkManager(CoreSysAttributes):
"""Scan on Interface for AccessPoint."""
inet = self.sys_dbus.network.get(interface.name)
if inet.type != DeviceType.WIRELESS:
if inet.type != DeviceType.WIRELESS or not inet.wireless:
raise HostNotSupportedError(
f"Can only scan with wireless card - {interface.name}", _LOGGER.error
)

View File

@ -1,13 +1,13 @@
"""Supervisor job manager."""
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import contextmanager
from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager, suppress
from contextvars import Context, ContextVar, Token
from dataclasses import dataclass
from datetime import datetime
import logging
from typing import Any
from typing import Any, Self
from uuid import uuid4
from attrs import Attribute, define, field
@ -27,7 +27,7 @@ from .validate import SCHEMA_JOBS_CONFIG
# When a new asyncio task is started the current context is copied over.
# Modifications to it in one task are not visible to others though.
# This allows us to track what job is currently in progress in each task.
_CURRENT_JOB: ContextVar[str] = ContextVar("current_job")
_CURRENT_JOB: ContextVar[str | None] = ContextVar("current_job", default=None)
_LOGGER: logging.Logger = logging.getLogger(__name__)
@ -75,7 +75,7 @@ class SupervisorJobError:
message: str = "Unknown error, see supervisor logs"
stage: str | None = None
def as_dict(self) -> dict[str, str]:
def as_dict(self) -> dict[str, str | None]:
"""Return dictionary representation."""
return {
"type": self.type_.__name__,
@ -101,9 +101,7 @@ class SupervisorJob:
stage: str | None = field(
default=None, validator=[_invalid_if_done], on_setattr=_on_change
)
parent_id: str | None = field(
factory=lambda: _CURRENT_JOB.get(None), on_setattr=frozen
)
parent_id: str | None = field(factory=_CURRENT_JOB.get, on_setattr=frozen)
done: bool | None = field(init=False, default=None, on_setattr=_on_change)
on_change: Callable[["SupervisorJob", Attribute, Any], None] | None = field(
default=None, on_setattr=frozen
@ -137,7 +135,7 @@ class SupervisorJob:
self.errors += [new_error]
@contextmanager
def start(self):
def start(self) -> Generator[Self]:
"""Start the job in the current task.
This can only be called if the parent ID matches the job running in the current task.
@ -146,11 +144,11 @@ class SupervisorJob:
"""
if self.done is not None:
raise JobStartException("Job has already been started")
if _CURRENT_JOB.get(None) != self.parent_id:
if _CURRENT_JOB.get() != self.parent_id:
raise JobStartException("Job has a different parent from current job")
self.done = False
token: Token[str] | None = None
token: Token[str | None] | None = None
try:
token = _CURRENT_JOB.set(self.uuid)
yield self
@ -193,17 +191,15 @@ class JobManager(FileConfiguration, CoreSysAttributes):
Must be called from within a job. Raises RuntimeError if there is no current job.
"""
try:
return self.get_job(_CURRENT_JOB.get())
except (LookupError, JobNotFound):
raise RuntimeError(
"No job for the current asyncio task!", _LOGGER.critical
) from None
if job_id := _CURRENT_JOB.get():
with suppress(JobNotFound):
return self.get_job(job_id)
raise RuntimeError("No job for the current asyncio task!", _LOGGER.critical)
@property
def is_job(self) -> bool:
"""Return true if there is an active job for the current asyncio task."""
return bool(_CURRENT_JOB.get(None))
return _CURRENT_JOB.get() is not None
def _notify_on_job_change(
self, job: SupervisorJob, attribute: Attribute, value: Any
@ -265,7 +261,7 @@ class JobManager(FileConfiguration, CoreSysAttributes):
def schedule_job(
self,
job_method: Callable[..., Awaitable[Any]],
job_method: Callable[..., Coroutine],
options: JobSchedulerOptions,
*args,
**kwargs,

View File

@ -1,12 +1,12 @@
"""Job decorator."""
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import datetime, timedelta
from functools import wraps
import logging
from typing import Any
from typing import Any, cast
from ..const import CoreState
from ..coresys import CoreSys, CoreSysAttributes
@ -54,11 +54,10 @@ class Job(CoreSysAttributes):
self.on_condition = on_condition
self.limit = limit
self._throttle_period = throttle_period
self.throttle_max_calls = throttle_max_calls
self._throttle_max_calls = throttle_max_calls
self._lock: asyncio.Semaphore | None = None
self._method = None
self._last_call: dict[str | None, datetime] = {}
self._rate_limited_calls: dict[str, list[datetime]] | None = None
self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
self._internal = internal
# Validate Options
@ -82,13 +81,29 @@ class Job(CoreSysAttributes):
JobExecutionLimit.THROTTLE_RATE_LIMIT,
JobExecutionLimit.GROUP_THROTTLE_RATE_LIMIT,
):
if self.throttle_max_calls is None:
if self._throttle_max_calls is None:
raise RuntimeError(
f"Job {name} is using execution limit {limit} without throttle max calls!"
)
self._rate_limited_calls = {}
@property
def throttle_max_calls(self) -> int:
"""Return max calls for throttle."""
if self._throttle_max_calls is None:
raise RuntimeError("No throttle max calls set for job!")
return self._throttle_max_calls
@property
def lock(self) -> asyncio.Semaphore:
"""Return lock for limits."""
# asyncio.Semaphore objects must be created in event loop
# Since this is sync code it is not safe to create if missing here
if not self._lock:
raise RuntimeError("Lock has not been created yet!")
return self._lock
def last_call(self, group_name: str | None = None) -> datetime:
"""Return last call datetime."""
return self._last_call.get(group_name, datetime.min)
@ -97,12 +112,12 @@ class Job(CoreSysAttributes):
"""Set last call datetime."""
self._last_call[group_name] = value
def rate_limited_calls(
self, group_name: str | None = None
) -> list[datetime] | None:
def rate_limited_calls(self, group_name: str | None = None) -> list[datetime]:
"""Return rate limited calls if used."""
if self._rate_limited_calls is None:
return None
raise RuntimeError(
f"Rate limited calls not available for limit type {self.limit}"
)
return self._rate_limited_calls.get(group_name, [])
@ -131,10 +146,10 @@ class Job(CoreSysAttributes):
self._rate_limited_calls[group_name] = value
def throttle_period(self, group_name: str | None = None) -> timedelta | None:
def throttle_period(self, group_name: str | None = None) -> timedelta:
"""Return throttle period."""
if self._throttle_period is None:
return None
raise RuntimeError("No throttle period set for Job!")
if isinstance(self._throttle_period, timedelta):
return self._throttle_period
@ -142,7 +157,7 @@ class Job(CoreSysAttributes):
return self._throttle_period(
self.coresys,
self.last_call(group_name),
self.rate_limited_calls(group_name),
self.rate_limited_calls(group_name) if self._rate_limited_calls else None,
)
def _post_init(self, obj: JobGroup | CoreSysAttributes) -> JobGroup | None:
@ -158,12 +173,12 @@ class Job(CoreSysAttributes):
self._lock = asyncio.Semaphore()
# Job groups
try:
is_job_group = obj.acquire and obj.release
except AttributeError:
is_job_group = False
job_group: JobGroup | None = None
with suppress(AttributeError):
if obj.acquire and obj.release: # type: ignore
job_group = cast(JobGroup, obj)
if not is_job_group and self.limit in (
if not job_group and self.limit in (
JobExecutionLimit.GROUP_ONCE,
JobExecutionLimit.GROUP_WAIT,
JobExecutionLimit.GROUP_THROTTLE,
@ -174,7 +189,7 @@ class Job(CoreSysAttributes):
f"Job on {self.name} need to be a JobGroup to use group based limits!"
) from None
return obj if is_job_group else None
return job_group
def _handle_job_condition_exception(self, err: JobConditionException) -> None:
"""Handle a job condition failure."""
@ -184,9 +199,8 @@ class Job(CoreSysAttributes):
return
raise self.on_condition(error_msg, _LOGGER.warning) from None
def __call__(self, method):
def __call__(self, method: Callable[..., Awaitable]):
"""Call the wrapper logic."""
self._method = method
@wraps(method)
async def wrapper(
@ -221,7 +235,7 @@ class Job(CoreSysAttributes):
if self.conditions:
try:
await Job.check_conditions(
self, set(self.conditions), self._method.__qualname__
self, set(self.conditions), method.__qualname__
)
except JobConditionException as err:
return self._handle_job_condition_exception(err)
@ -237,7 +251,7 @@ class Job(CoreSysAttributes):
JobExecutionLimit.GROUP_WAIT,
):
try:
await obj.acquire(
await cast(JobGroup, job_group).acquire(
job, self.limit == JobExecutionLimit.GROUP_WAIT
)
except JobGroupExecutionLimitExceeded as err:
@ -296,12 +310,12 @@ class Job(CoreSysAttributes):
with job.start():
try:
self.set_last_call(datetime.now(), group_name)
if self.rate_limited_calls(group_name) is not None:
if self._rate_limited_calls is not None:
self.add_rate_limited_call(
self.last_call(group_name), group_name
)
return await self._method(obj, *args, **kwargs)
return await method(obj, *args, **kwargs)
# If a method has a conditional JobCondition, they must check it in the method
# These should be handled like normal JobConditions as much as possible
@ -317,11 +331,11 @@ class Job(CoreSysAttributes):
raise JobException() from err
finally:
self._release_exception_limits()
if self.limit in (
if job_group and self.limit in (
JobExecutionLimit.GROUP_ONCE,
JobExecutionLimit.GROUP_WAIT,
):
obj.release()
job_group.release()
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required
finally:
@ -473,13 +487,13 @@ class Job(CoreSysAttributes):
):
return
if self.limit == JobExecutionLimit.ONCE and self._lock.locked():
if self.limit == JobExecutionLimit.ONCE and self.lock.locked():
on_condition = (
JobException if self.on_condition is None else self.on_condition
)
raise on_condition("Another job is running")
await self._lock.acquire()
await self.lock.acquire()
def _release_exception_limits(self) -> None:
"""Release possible exception limits."""
@ -490,4 +504,4 @@ class Job(CoreSysAttributes):
JobExecutionLimit.GROUP_THROTTLE_WAIT,
):
return
self._lock.release()
self.lock.release()

View File

@ -41,7 +41,7 @@ class JobGroup(CoreSysAttributes):
def has_lock(self) -> bool:
"""Return true if current task has the lock on this job group."""
return (
self.active_job
self.active_job is not None
and self.sys_jobs.is_job
and self.active_job == self.sys_jobs.current
)