diff --git a/supervisor/coresys.py b/supervisor/coresys.py index 41bdfdc26..f69a3e82a 100644 --- a/supervisor/coresys.py +++ b/supervisor/coresys.py @@ -630,7 +630,7 @@ class CoreSys: def call_later( self, delay: float, - funct: Callable[..., Coroutine[Any, Any, T]], + funct: Callable[..., Any], *args: tuple[Any], **kwargs: dict[str, Any], ) -> asyncio.TimerHandle: @@ -643,7 +643,7 @@ class CoreSys: def call_at( self, when: datetime, - funct: Callable[..., Coroutine[Any, Any, T]], + funct: Callable[..., Any], *args: tuple[Any], **kwargs: dict[str, Any], ) -> asyncio.TimerHandle: @@ -843,7 +843,7 @@ class CoreSysAttributes: def sys_call_later( self, delay: float, - funct: Callable[..., Coroutine[Any, Any, T]], + funct: Callable[..., Any], *args, **kwargs, ) -> asyncio.TimerHandle: @@ -853,7 +853,7 @@ class CoreSysAttributes: def sys_call_at( self, when: datetime, - funct: Callable[..., Coroutine[Any, Any, T]], + funct: Callable[..., Any], *args, **kwargs, ) -> asyncio.TimerHandle: diff --git a/supervisor/dbus/const.py b/supervisor/dbus/const.py index fc324d3b9..d78850aec 100644 --- a/supervisor/dbus/const.py +++ b/supervisor/dbus/const.py @@ -135,6 +135,7 @@ DBUS_ATTR_LAST_ERROR = "LastError" DBUS_ATTR_LLMNR = "LLMNR" DBUS_ATTR_LLMNR_HOSTNAME = "LLMNRHostname" DBUS_ATTR_LOADER_TIMESTAMP_MONOTONIC = "LoaderTimestampMonotonic" +DBUS_ATTR_LOCAL_RTC = "LocalRTC" DBUS_ATTR_MANAGED = "Managed" DBUS_ATTR_MODE = "Mode" DBUS_ATTR_MODEL = "Model" diff --git a/supervisor/dbus/network/configuration.py b/supervisor/dbus/network/configuration.py index 066fad41d..6a58aef21 100644 --- a/supervisor/dbus/network/configuration.py +++ b/supervisor/dbus/network/configuration.py @@ -1,5 +1,6 @@ """NetworkConnection objects for Network Manager.""" +from abc import ABC from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address @@ -29,7 +30,7 @@ class ConnectionProperties: class WirelessProperties: """Wireless Properties object for Network Manager.""" - ssid: str | None + ssid: str assigned_mac: str | None mode: str | None powersave: int | None @@ -55,7 +56,7 @@ class EthernetProperties: class VlanProperties: """Ethernet properties object for Network Manager.""" - id: int | None + id: int parent: str | None @@ -67,14 +68,20 @@ class IpAddress: prefix: int -@dataclass(slots=True) -class IpProperties: +@dataclass +class IpProperties(ABC): """IP properties object for Network Manager.""" method: str | None address_data: list[IpAddress] | None gateway: str | None - dns: list[bytes | int] | None + + +@dataclass(slots=True) +class Ip4Properties(IpProperties): + """IPv4 properties object.""" + + dns: list[int] | None @dataclass(slots=True) @@ -83,6 +90,7 @@ class Ip6Properties(IpProperties): addr_gen_mode: int ip6_privacy: int + dns: list[bytes] | None @dataclass(slots=True) diff --git a/supervisor/dbus/network/setting/__init__.py b/supervisor/dbus/network/setting/__init__.py index 8ad1c6ec4..8e5404f1e 100644 --- a/supervisor/dbus/network/setting/__init__.py +++ b/supervisor/dbus/network/setting/__init__.py @@ -12,9 +12,9 @@ from ...utils import dbus_connected from ..configuration import ( ConnectionProperties, EthernetProperties, + Ip4Properties, Ip6Properties, IpAddress, - IpProperties, MatchProperties, VlanProperties, WirelessProperties, @@ -115,7 +115,7 @@ class NetworkSetting(DBusInterface): self._wireless_security: WirelessSecurityProperties | None = None self._ethernet: EthernetProperties | None = None self._vlan: VlanProperties | None = None - self._ipv4: IpProperties | None = None + self._ipv4: Ip4Properties | None = None self._ipv6: Ip6Properties | None = None self._match: MatchProperties | None = None super().__init__() @@ -151,7 +151,7 @@ class NetworkSetting(DBusInterface): return self._vlan @property - def ipv4(self) -> IpProperties | None: + def ipv4(self) -> Ip4Properties | None: """Return ipv4 properties if any.""" return self._ipv4 @@ -271,16 +271,23 @@ class NetworkSetting(DBusInterface): ) if CONF_ATTR_VLAN in data: - self._vlan = VlanProperties( - id=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_ID), - parent=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT), - ) + if CONF_ATTR_VLAN_ID in data[CONF_ATTR_VLAN]: + self._vlan = VlanProperties( + data[CONF_ATTR_VLAN][CONF_ATTR_VLAN_ID], + data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT), + ) + else: + self._vlan = None + _LOGGER.warning( + "Network settings for vlan connection %s missing required vlan id, cannot process it", + self.connection.interface_name, + ) if CONF_ATTR_IPV4 in data: address_data = None if ips := data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_ADDRESS_DATA): address_data = [IpAddress(ip["address"], ip["prefix"]) for ip in ips] - self._ipv4 = IpProperties( + self._ipv4 = Ip4Properties( method=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_METHOD), address_data=address_data, gateway=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_GATEWAY), diff --git a/supervisor/dbus/network/setting/generate.py b/supervisor/dbus/network/setting/generate.py index 31849989f..b15a29e1e 100644 --- a/supervisor/dbus/network/setting/generate.py +++ b/supervisor/dbus/network/setting/generate.py @@ -222,8 +222,10 @@ def get_connection_from_interface( } elif interface.type == "vlan": parent = cast(VlanConfig, interface.vlan).interface - if parent in network_manager and ( - parent_connection := network_manager.get(parent).connection + if ( + parent + and parent in network_manager + and (parent_connection := network_manager.get(parent).connection) ): parent = parent_connection.uuid diff --git a/supervisor/dbus/timedate.py b/supervisor/dbus/timedate.py index cfee27a0d..407847e06 100644 --- a/supervisor/dbus/timedate.py +++ b/supervisor/dbus/timedate.py @@ -10,6 +10,7 @@ from dbus_fast.aio.message_bus import MessageBus from ..exceptions import DBusError, DBusInterfaceError, DBusServiceUnkownError from ..utils.dt import get_time_zone, utc_from_timestamp from .const import ( + DBUS_ATTR_LOCAL_RTC, DBUS_ATTR_NTP, DBUS_ATTR_NTPSYNCHRONIZED, DBUS_ATTR_TIMEUSEC, @@ -46,6 +47,12 @@ class TimeDate(DBusInterfaceProxy): """Return host timezone.""" return self.properties[DBUS_ATTR_TIMEZONE] + @property + @dbus_property + def local_rtc(self) -> bool: + """Return whether rtc is local time or utc.""" + return self.properties[DBUS_ATTR_LOCAL_RTC] + @property @dbus_property def ntp(self) -> bool: diff --git a/supervisor/host/configuration.py b/supervisor/host/configuration.py index ddf1d950e..a81f33b94 100644 --- a/supervisor/host/configuration.py +++ b/supervisor/host/configuration.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface +import logging import socket from ..dbus.const import ( @@ -23,6 +24,8 @@ from .const import ( WifiMode, ) +_LOGGER: logging.Logger = logging.getLogger(__name__) + @dataclass(slots=True) class AccessPoint: @@ -79,7 +82,7 @@ class VlanConfig: """Represent a vlan configuration.""" id: int - interface: str + interface: str | None @dataclass(slots=True) @@ -108,7 +111,10 @@ class Interface: if inet.settings.match and inet.settings.match.path: return inet.settings.match.path == [self.path] - return inet.settings.connection.interface_name == self.name + return ( + inet.settings.connection is not None + and inet.settings.connection.interface_name == self.name + ) @staticmethod def from_dbus_interface(inet: NetworkInterface) -> "Interface": @@ -160,23 +166,23 @@ class Interface: ipv6_setting = Ip6Setting(InterfaceMethod.DISABLED, [], None, []) ipv4_ready = ( - bool(inet.connection) + inet.connection is not None and ConnectionStateFlags.IP4_READY in inet.connection.state_flags ) ipv6_ready = ( - bool(inet.connection) + inet.connection is not None and ConnectionStateFlags.IP6_READY in inet.connection.state_flags ) return Interface( - inet.name, - inet.hw_address, - inet.path, - inet.settings is not None, - Interface._map_nm_connected(inet.connection), - inet.primary, - Interface._map_nm_type(inet.type), - IpConfig( + name=inet.name, + mac=inet.hw_address, + path=inet.path, + enabled=inet.settings is not None, + connected=Interface._map_nm_connected(inet.connection), + primary=inet.primary, + type=Interface._map_nm_type(inet.type), + ipv4=IpConfig( address=inet.connection.ipv4.address if inet.connection.ipv4.address else [], @@ -188,8 +194,8 @@ class Interface: ) if inet.connection and inet.connection.ipv4 else IpConfig([], None, [], ipv4_ready), - ipv4_setting, - IpConfig( + ipv4setting=ipv4_setting, + ipv6=IpConfig( address=inet.connection.ipv6.address if inet.connection.ipv6.address else [], @@ -201,30 +207,28 @@ class Interface: ) if inet.connection and inet.connection.ipv6 else IpConfig([], None, [], ipv6_ready), - ipv6_setting, - Interface._map_nm_wifi(inet), - Interface._map_nm_vlan(inet), + ipv6setting=ipv6_setting, + wifi=Interface._map_nm_wifi(inet), + vlan=Interface._map_nm_vlan(inet), ) @staticmethod - def _map_nm_method(method: str) -> InterfaceMethod: + def _map_nm_method(method: str | None) -> InterfaceMethod: """Map IP interface method.""" - mapping = { - NMInterfaceMethod.AUTO: InterfaceMethod.AUTO, - NMInterfaceMethod.DISABLED: InterfaceMethod.DISABLED, - NMInterfaceMethod.MANUAL: InterfaceMethod.STATIC, - NMInterfaceMethod.LINK_LOCAL: InterfaceMethod.DISABLED, - } - - return mapping.get(method, InterfaceMethod.DISABLED) + match method: + case NMInterfaceMethod.AUTO.value: + return InterfaceMethod.AUTO + case NMInterfaceMethod.MANUAL: + return InterfaceMethod.STATIC + return InterfaceMethod.DISABLED @staticmethod def _map_nm_addr_gen_mode(addr_gen_mode: int) -> InterfaceAddrGenMode: """Map IPv6 interface addr_gen_mode.""" mapping = { - NMInterfaceAddrGenMode.EUI64: InterfaceAddrGenMode.EUI64, - NMInterfaceAddrGenMode.STABLE_PRIVACY: InterfaceAddrGenMode.STABLE_PRIVACY, - NMInterfaceAddrGenMode.DEFAULT_OR_EUI64: InterfaceAddrGenMode.DEFAULT_OR_EUI64, + NMInterfaceAddrGenMode.EUI64.value: InterfaceAddrGenMode.EUI64, + NMInterfaceAddrGenMode.STABLE_PRIVACY.value: InterfaceAddrGenMode.STABLE_PRIVACY, + NMInterfaceAddrGenMode.DEFAULT_OR_EUI64.value: InterfaceAddrGenMode.DEFAULT_OR_EUI64, } return mapping.get(addr_gen_mode, InterfaceAddrGenMode.DEFAULT) @@ -233,9 +237,9 @@ class Interface: def _map_nm_ip6_privacy(ip6_privacy: int) -> InterfaceIp6Privacy: """Map IPv6 interface ip6_privacy.""" mapping = { - NMInterfaceIp6Privacy.DISABLED: InterfaceIp6Privacy.DISABLED, - NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC, - NMInterfaceIp6Privacy.ENABLED: InterfaceIp6Privacy.ENABLED, + NMInterfaceIp6Privacy.DISABLED.value: InterfaceIp6Privacy.DISABLED, + NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC.value: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC, + NMInterfaceIp6Privacy.ENABLED.value: InterfaceIp6Privacy.ENABLED, } return mapping.get(ip6_privacy, InterfaceIp6Privacy.DEFAULT) @@ -253,12 +257,14 @@ class Interface: @staticmethod def _map_nm_type(device_type: int) -> InterfaceType: - mapping = { - DeviceType.ETHERNET: InterfaceType.ETHERNET, - DeviceType.WIRELESS: InterfaceType.WIRELESS, - DeviceType.VLAN: InterfaceType.VLAN, - } - return mapping[device_type] + match device_type: + case DeviceType.ETHERNET.value: + return InterfaceType.ETHERNET + case DeviceType.WIRELESS.value: + return InterfaceType.WIRELESS + case DeviceType.VLAN.value: + return InterfaceType.VLAN + raise ValueError(f"Invalid device type: {device_type}") @staticmethod def _map_nm_wifi(inet: NetworkInterface) -> WifiConfig | None: @@ -267,15 +273,22 @@ class Interface: return None # Authentication and PSK - auth = None + auth = AuthMethod.OPEN psk = None - if not inet.settings.wireless_security: - auth = AuthMethod.OPEN - elif inet.settings.wireless_security.key_mgmt == "none": - auth = AuthMethod.WEP - elif inet.settings.wireless_security.key_mgmt == "wpa-psk": - auth = AuthMethod.WPA_PSK - psk = inet.settings.wireless_security.psk + if inet.settings.wireless_security: + match inet.settings.wireless_security.key_mgmt: + case "none": + auth = AuthMethod.WEP + case "wpa-psk": + auth = AuthMethod.WPA_PSK + psk = inet.settings.wireless_security.psk + case _: + _LOGGER.warning( + "Auth method %s for network interface %s unsupported, skipping", + inet.settings.wireless_security.key_mgmt, + inet.name, + ) + return None # WifiMode mode = WifiMode.INFRASTRUCTURE @@ -289,17 +302,17 @@ class Interface: signal = None return WifiConfig( - mode, - inet.settings.wireless.ssid, - auth, - psk, - signal, + mode=mode, + ssid=inet.settings.wireless.ssid if inet.settings.wireless else "", + auth=auth, + psk=psk, + signal=signal, ) @staticmethod - def _map_nm_vlan(inet: NetworkInterface) -> WifiConfig | None: + def _map_nm_vlan(inet: NetworkInterface) -> VlanConfig | None: """Create mapping to nm vlan property.""" - if inet.type != DeviceType.VLAN or not inet.settings: + if inet.type != DeviceType.VLAN or not inet.settings or not inet.settings.vlan: return None return VlanConfig(inet.settings.vlan.id, inet.settings.vlan.parent) diff --git a/supervisor/host/logs.py b/supervisor/host/logs.py index 3c1174eac..f1ed7315d 100644 --- a/supervisor/host/logs.py +++ b/supervisor/host/logs.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from contextlib import asynccontextmanager import json import logging @@ -205,7 +205,7 @@ class LogsControl(CoreSysAttributes): async def journald_logs( self, path: str = "/entries", - params: dict[str, str | list[str]] | None = None, + params: Mapping[str, str | list[str]] | None = None, range_header: str | None = None, accept: LogFormat = LogFormat.TEXT, timeout: ClientTimeout | None = None, @@ -226,7 +226,7 @@ class LogsControl(CoreSysAttributes): base_url = "http://localhost/" connector = UnixConnector(path=str(SYSTEMD_JOURNAL_GATEWAYD_SOCKET)) async with ClientSession(base_url=base_url, connector=connector) as session: - headers = {ACCEPT: accept} + headers = {ACCEPT: accept.value} if range_header: if range_header.endswith(":"): # Make sure that num_entries is always set - before Systemd v256 it was diff --git a/supervisor/host/network.py b/supervisor/host/network.py index cd37a7d36..d7e9b9538 100644 --- a/supervisor/host/network.py +++ b/supervisor/host/network.py @@ -87,7 +87,7 @@ class NetworkManager(CoreSysAttributes): for config in self.sys_dbus.network.dns.configuration: if config.vpn or not config.nameservers: continue - servers.extend(config.nameservers) + servers.extend([str(ns) for ns in config.nameservers]) return list(dict.fromkeys(servers)) @@ -197,10 +197,16 @@ class NetworkManager(CoreSysAttributes): with suppress(NetworkInterfaceNotFound): inet = self.sys_dbus.network.get(interface.name) - con: NetworkConnection = None + con: NetworkConnection | None = None # Update exist configuration - if inet and interface.equals_dbus_interface(inet) and interface.enabled: + if ( + inet + and inet.settings + and inet.settings.connection + and interface.equals_dbus_interface(inet) + and interface.enabled + ): _LOGGER.debug("Updating existing configuration for %s", interface.name) settings = get_connection_from_interface( interface, @@ -211,12 +217,12 @@ class NetworkManager(CoreSysAttributes): try: await inet.settings.update(settings) - con = await self.sys_dbus.network.activate_connection( + con = activated = await self.sys_dbus.network.activate_connection( inet.settings.object_path, inet.object_path ) _LOGGER.debug( "activate_connection returns %s", - con.object_path, + activated.object_path, ) except DBusError as err: raise HostNetworkError( @@ -236,12 +242,16 @@ class NetworkManager(CoreSysAttributes): settings = get_connection_from_interface(interface, self.sys_dbus.network) try: - settings, con = await self.sys_dbus.network.add_and_activate_connection( + ( + settings, + activated, + ) = await self.sys_dbus.network.add_and_activate_connection( settings, inet.object_path ) + con = activated _LOGGER.debug( "add_and_activate_connection returns %s", - con.object_path, + activated.object_path, ) except DBusError as err: raise HostNetworkError( @@ -277,7 +287,7 @@ class NetworkManager(CoreSysAttributes): ) if con: - async with con.dbus.signal( + async with con.connected_dbus.signal( DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED ) as signal: # From this point we monitor signals. However, it might be that @@ -303,7 +313,7 @@ class NetworkManager(CoreSysAttributes): """Scan on Interface for AccessPoint.""" inet = self.sys_dbus.network.get(interface.name) - if inet.type != DeviceType.WIRELESS: + if inet.type != DeviceType.WIRELESS or not inet.wireless: raise HostNotSupportedError( f"Can only scan with wireless card - {interface.name}", _LOGGER.error ) diff --git a/supervisor/jobs/__init__.py b/supervisor/jobs/__init__.py index 6a84026c4..5ce52de1d 100644 --- a/supervisor/jobs/__init__.py +++ b/supervisor/jobs/__init__.py @@ -1,13 +1,13 @@ """Supervisor job manager.""" import asyncio -from collections.abc import Awaitable, Callable -from contextlib import contextmanager +from collections.abc import Callable, Coroutine, Generator +from contextlib import contextmanager, suppress from contextvars import Context, ContextVar, Token from dataclasses import dataclass from datetime import datetime import logging -from typing import Any +from typing import Any, Self from uuid import uuid4 from attrs import Attribute, define, field @@ -27,7 +27,7 @@ from .validate import SCHEMA_JOBS_CONFIG # When a new asyncio task is started the current context is copied over. # Modifications to it in one task are not visible to others though. # This allows us to track what job is currently in progress in each task. -_CURRENT_JOB: ContextVar[str] = ContextVar("current_job") +_CURRENT_JOB: ContextVar[str | None] = ContextVar("current_job", default=None) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -75,7 +75,7 @@ class SupervisorJobError: message: str = "Unknown error, see supervisor logs" stage: str | None = None - def as_dict(self) -> dict[str, str]: + def as_dict(self) -> dict[str, str | None]: """Return dictionary representation.""" return { "type": self.type_.__name__, @@ -101,9 +101,7 @@ class SupervisorJob: stage: str | None = field( default=None, validator=[_invalid_if_done], on_setattr=_on_change ) - parent_id: str | None = field( - factory=lambda: _CURRENT_JOB.get(None), on_setattr=frozen - ) + parent_id: str | None = field(factory=_CURRENT_JOB.get, on_setattr=frozen) done: bool | None = field(init=False, default=None, on_setattr=_on_change) on_change: Callable[["SupervisorJob", Attribute, Any], None] | None = field( default=None, on_setattr=frozen @@ -137,7 +135,7 @@ class SupervisorJob: self.errors += [new_error] @contextmanager - def start(self): + def start(self) -> Generator[Self]: """Start the job in the current task. This can only be called if the parent ID matches the job running in the current task. @@ -146,11 +144,11 @@ class SupervisorJob: """ if self.done is not None: raise JobStartException("Job has already been started") - if _CURRENT_JOB.get(None) != self.parent_id: + if _CURRENT_JOB.get() != self.parent_id: raise JobStartException("Job has a different parent from current job") self.done = False - token: Token[str] | None = None + token: Token[str | None] | None = None try: token = _CURRENT_JOB.set(self.uuid) yield self @@ -193,17 +191,15 @@ class JobManager(FileConfiguration, CoreSysAttributes): Must be called from within a job. Raises RuntimeError if there is no current job. """ - try: - return self.get_job(_CURRENT_JOB.get()) - except (LookupError, JobNotFound): - raise RuntimeError( - "No job for the current asyncio task!", _LOGGER.critical - ) from None + if job_id := _CURRENT_JOB.get(): + with suppress(JobNotFound): + return self.get_job(job_id) + raise RuntimeError("No job for the current asyncio task!", _LOGGER.critical) @property def is_job(self) -> bool: """Return true if there is an active job for the current asyncio task.""" - return bool(_CURRENT_JOB.get(None)) + return _CURRENT_JOB.get() is not None def _notify_on_job_change( self, job: SupervisorJob, attribute: Attribute, value: Any @@ -265,7 +261,7 @@ class JobManager(FileConfiguration, CoreSysAttributes): def schedule_job( self, - job_method: Callable[..., Awaitable[Any]], + job_method: Callable[..., Coroutine], options: JobSchedulerOptions, *args, **kwargs, diff --git a/supervisor/jobs/decorator.py b/supervisor/jobs/decorator.py index 07db1eacd..29514dd98 100644 --- a/supervisor/jobs/decorator.py +++ b/supervisor/jobs/decorator.py @@ -1,12 +1,12 @@ """Job decorator.""" import asyncio -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import suppress from datetime import datetime, timedelta from functools import wraps import logging -from typing import Any +from typing import Any, cast from ..const import CoreState from ..coresys import CoreSys, CoreSysAttributes @@ -54,11 +54,10 @@ class Job(CoreSysAttributes): self.on_condition = on_condition self.limit = limit self._throttle_period = throttle_period - self.throttle_max_calls = throttle_max_calls + self._throttle_max_calls = throttle_max_calls self._lock: asyncio.Semaphore | None = None - self._method = None self._last_call: dict[str | None, datetime] = {} - self._rate_limited_calls: dict[str, list[datetime]] | None = None + self._rate_limited_calls: dict[str | None, list[datetime]] | None = None self._internal = internal # Validate Options @@ -82,13 +81,29 @@ class Job(CoreSysAttributes): JobExecutionLimit.THROTTLE_RATE_LIMIT, JobExecutionLimit.GROUP_THROTTLE_RATE_LIMIT, ): - if self.throttle_max_calls is None: + if self._throttle_max_calls is None: raise RuntimeError( f"Job {name} is using execution limit {limit} without throttle max calls!" ) self._rate_limited_calls = {} + @property + def throttle_max_calls(self) -> int: + """Return max calls for throttle.""" + if self._throttle_max_calls is None: + raise RuntimeError("No throttle max calls set for job!") + return self._throttle_max_calls + + @property + def lock(self) -> asyncio.Semaphore: + """Return lock for limits.""" + # asyncio.Semaphore objects must be created in event loop + # Since this is sync code it is not safe to create if missing here + if not self._lock: + raise RuntimeError("Lock has not been created yet!") + return self._lock + def last_call(self, group_name: str | None = None) -> datetime: """Return last call datetime.""" return self._last_call.get(group_name, datetime.min) @@ -97,12 +112,12 @@ class Job(CoreSysAttributes): """Set last call datetime.""" self._last_call[group_name] = value - def rate_limited_calls( - self, group_name: str | None = None - ) -> list[datetime] | None: + def rate_limited_calls(self, group_name: str | None = None) -> list[datetime]: """Return rate limited calls if used.""" if self._rate_limited_calls is None: - return None + raise RuntimeError( + f"Rate limited calls not available for limit type {self.limit}" + ) return self._rate_limited_calls.get(group_name, []) @@ -131,10 +146,10 @@ class Job(CoreSysAttributes): self._rate_limited_calls[group_name] = value - def throttle_period(self, group_name: str | None = None) -> timedelta | None: + def throttle_period(self, group_name: str | None = None) -> timedelta: """Return throttle period.""" if self._throttle_period is None: - return None + raise RuntimeError("No throttle period set for Job!") if isinstance(self._throttle_period, timedelta): return self._throttle_period @@ -142,7 +157,7 @@ class Job(CoreSysAttributes): return self._throttle_period( self.coresys, self.last_call(group_name), - self.rate_limited_calls(group_name), + self.rate_limited_calls(group_name) if self._rate_limited_calls else None, ) def _post_init(self, obj: JobGroup | CoreSysAttributes) -> JobGroup | None: @@ -158,12 +173,12 @@ class Job(CoreSysAttributes): self._lock = asyncio.Semaphore() # Job groups - try: - is_job_group = obj.acquire and obj.release - except AttributeError: - is_job_group = False + job_group: JobGroup | None = None + with suppress(AttributeError): + if obj.acquire and obj.release: # type: ignore + job_group = cast(JobGroup, obj) - if not is_job_group and self.limit in ( + if not job_group and self.limit in ( JobExecutionLimit.GROUP_ONCE, JobExecutionLimit.GROUP_WAIT, JobExecutionLimit.GROUP_THROTTLE, @@ -174,7 +189,7 @@ class Job(CoreSysAttributes): f"Job on {self.name} need to be a JobGroup to use group based limits!" ) from None - return obj if is_job_group else None + return job_group def _handle_job_condition_exception(self, err: JobConditionException) -> None: """Handle a job condition failure.""" @@ -184,9 +199,8 @@ class Job(CoreSysAttributes): return raise self.on_condition(error_msg, _LOGGER.warning) from None - def __call__(self, method): + def __call__(self, method: Callable[..., Awaitable]): """Call the wrapper logic.""" - self._method = method @wraps(method) async def wrapper( @@ -221,7 +235,7 @@ class Job(CoreSysAttributes): if self.conditions: try: await Job.check_conditions( - self, set(self.conditions), self._method.__qualname__ + self, set(self.conditions), method.__qualname__ ) except JobConditionException as err: return self._handle_job_condition_exception(err) @@ -237,7 +251,7 @@ class Job(CoreSysAttributes): JobExecutionLimit.GROUP_WAIT, ): try: - await obj.acquire( + await cast(JobGroup, job_group).acquire( job, self.limit == JobExecutionLimit.GROUP_WAIT ) except JobGroupExecutionLimitExceeded as err: @@ -296,12 +310,12 @@ class Job(CoreSysAttributes): with job.start(): try: self.set_last_call(datetime.now(), group_name) - if self.rate_limited_calls(group_name) is not None: + if self._rate_limited_calls is not None: self.add_rate_limited_call( self.last_call(group_name), group_name ) - return await self._method(obj, *args, **kwargs) + return await method(obj, *args, **kwargs) # If a method has a conditional JobCondition, they must check it in the method # These should be handled like normal JobConditions as much as possible @@ -317,11 +331,11 @@ class Job(CoreSysAttributes): raise JobException() from err finally: self._release_exception_limits() - if self.limit in ( + if job_group and self.limit in ( JobExecutionLimit.GROUP_ONCE, JobExecutionLimit.GROUP_WAIT, ): - obj.release() + job_group.release() # Jobs that weren't started are always cleaned up. Also clean up done jobs if required finally: @@ -473,13 +487,13 @@ class Job(CoreSysAttributes): ): return - if self.limit == JobExecutionLimit.ONCE and self._lock.locked(): + if self.limit == JobExecutionLimit.ONCE and self.lock.locked(): on_condition = ( JobException if self.on_condition is None else self.on_condition ) raise on_condition("Another job is running") - await self._lock.acquire() + await self.lock.acquire() def _release_exception_limits(self) -> None: """Release possible exception limits.""" @@ -490,4 +504,4 @@ class Job(CoreSysAttributes): JobExecutionLimit.GROUP_THROTTLE_WAIT, ): return - self._lock.release() + self.lock.release() diff --git a/supervisor/jobs/job_group.py b/supervisor/jobs/job_group.py index 50f034c7b..4dece17e3 100644 --- a/supervisor/jobs/job_group.py +++ b/supervisor/jobs/job_group.py @@ -41,7 +41,7 @@ class JobGroup(CoreSysAttributes): def has_lock(self) -> bool: """Return true if current task has the lock on this job group.""" return ( - self.active_job + self.active_job is not None and self.sys_jobs.is_job and self.active_job == self.sys_jobs.current )