mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-07-23 00:56:29 +00:00
Fix mypy issues in host
and jobs
(#5939)
* Fix mypy issues in host * Fix mypy issues in job module * Fix mypy issues introduced in previously fixed modules * Apply suggestions from code review Co-authored-by: Stefan Agner <stefan@agner.ch> --------- Co-authored-by: Stefan Agner <stefan@agner.ch>
This commit is contained in:
parent
fd0b894d6a
commit
9682870c2c
@ -630,7 +630,7 @@ class CoreSys:
|
||||
def call_later(
|
||||
self,
|
||||
delay: float,
|
||||
funct: Callable[..., Coroutine[Any, Any, T]],
|
||||
funct: Callable[..., Any],
|
||||
*args: tuple[Any],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> asyncio.TimerHandle:
|
||||
@ -643,7 +643,7 @@ class CoreSys:
|
||||
def call_at(
|
||||
self,
|
||||
when: datetime,
|
||||
funct: Callable[..., Coroutine[Any, Any, T]],
|
||||
funct: Callable[..., Any],
|
||||
*args: tuple[Any],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> asyncio.TimerHandle:
|
||||
@ -843,7 +843,7 @@ class CoreSysAttributes:
|
||||
def sys_call_later(
|
||||
self,
|
||||
delay: float,
|
||||
funct: Callable[..., Coroutine[Any, Any, T]],
|
||||
funct: Callable[..., Any],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> asyncio.TimerHandle:
|
||||
@ -853,7 +853,7 @@ class CoreSysAttributes:
|
||||
def sys_call_at(
|
||||
self,
|
||||
when: datetime,
|
||||
funct: Callable[..., Coroutine[Any, Any, T]],
|
||||
funct: Callable[..., Any],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> asyncio.TimerHandle:
|
||||
|
@ -135,6 +135,7 @@ DBUS_ATTR_LAST_ERROR = "LastError"
|
||||
DBUS_ATTR_LLMNR = "LLMNR"
|
||||
DBUS_ATTR_LLMNR_HOSTNAME = "LLMNRHostname"
|
||||
DBUS_ATTR_LOADER_TIMESTAMP_MONOTONIC = "LoaderTimestampMonotonic"
|
||||
DBUS_ATTR_LOCAL_RTC = "LocalRTC"
|
||||
DBUS_ATTR_MANAGED = "Managed"
|
||||
DBUS_ATTR_MODE = "Mode"
|
||||
DBUS_ATTR_MODEL = "Model"
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""NetworkConnection objects for Network Manager."""
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
|
||||
@ -29,7 +30,7 @@ class ConnectionProperties:
|
||||
class WirelessProperties:
|
||||
"""Wireless Properties object for Network Manager."""
|
||||
|
||||
ssid: str | None
|
||||
ssid: str
|
||||
assigned_mac: str | None
|
||||
mode: str | None
|
||||
powersave: int | None
|
||||
@ -55,7 +56,7 @@ class EthernetProperties:
|
||||
class VlanProperties:
|
||||
"""Ethernet properties object for Network Manager."""
|
||||
|
||||
id: int | None
|
||||
id: int
|
||||
parent: str | None
|
||||
|
||||
|
||||
@ -67,14 +68,20 @@ class IpAddress:
|
||||
prefix: int
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class IpProperties:
|
||||
@dataclass
|
||||
class IpProperties(ABC):
|
||||
"""IP properties object for Network Manager."""
|
||||
|
||||
method: str | None
|
||||
address_data: list[IpAddress] | None
|
||||
gateway: str | None
|
||||
dns: list[bytes | int] | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Ip4Properties(IpProperties):
|
||||
"""IPv4 properties object."""
|
||||
|
||||
dns: list[int] | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -83,6 +90,7 @@ class Ip6Properties(IpProperties):
|
||||
|
||||
addr_gen_mode: int
|
||||
ip6_privacy: int
|
||||
dns: list[bytes] | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -12,9 +12,9 @@ from ...utils import dbus_connected
|
||||
from ..configuration import (
|
||||
ConnectionProperties,
|
||||
EthernetProperties,
|
||||
Ip4Properties,
|
||||
Ip6Properties,
|
||||
IpAddress,
|
||||
IpProperties,
|
||||
MatchProperties,
|
||||
VlanProperties,
|
||||
WirelessProperties,
|
||||
@ -115,7 +115,7 @@ class NetworkSetting(DBusInterface):
|
||||
self._wireless_security: WirelessSecurityProperties | None = None
|
||||
self._ethernet: EthernetProperties | None = None
|
||||
self._vlan: VlanProperties | None = None
|
||||
self._ipv4: IpProperties | None = None
|
||||
self._ipv4: Ip4Properties | None = None
|
||||
self._ipv6: Ip6Properties | None = None
|
||||
self._match: MatchProperties | None = None
|
||||
super().__init__()
|
||||
@ -151,7 +151,7 @@ class NetworkSetting(DBusInterface):
|
||||
return self._vlan
|
||||
|
||||
@property
|
||||
def ipv4(self) -> IpProperties | None:
|
||||
def ipv4(self) -> Ip4Properties | None:
|
||||
"""Return ipv4 properties if any."""
|
||||
return self._ipv4
|
||||
|
||||
@ -271,16 +271,23 @@ class NetworkSetting(DBusInterface):
|
||||
)
|
||||
|
||||
if CONF_ATTR_VLAN in data:
|
||||
self._vlan = VlanProperties(
|
||||
id=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_ID),
|
||||
parent=data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT),
|
||||
)
|
||||
if CONF_ATTR_VLAN_ID in data[CONF_ATTR_VLAN]:
|
||||
self._vlan = VlanProperties(
|
||||
data[CONF_ATTR_VLAN][CONF_ATTR_VLAN_ID],
|
||||
data[CONF_ATTR_VLAN].get(CONF_ATTR_VLAN_PARENT),
|
||||
)
|
||||
else:
|
||||
self._vlan = None
|
||||
_LOGGER.warning(
|
||||
"Network settings for vlan connection %s missing required vlan id, cannot process it",
|
||||
self.connection.interface_name,
|
||||
)
|
||||
|
||||
if CONF_ATTR_IPV4 in data:
|
||||
address_data = None
|
||||
if ips := data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_ADDRESS_DATA):
|
||||
address_data = [IpAddress(ip["address"], ip["prefix"]) for ip in ips]
|
||||
self._ipv4 = IpProperties(
|
||||
self._ipv4 = Ip4Properties(
|
||||
method=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_METHOD),
|
||||
address_data=address_data,
|
||||
gateway=data[CONF_ATTR_IPV4].get(CONF_ATTR_IPV4_GATEWAY),
|
||||
|
@ -222,8 +222,10 @@ def get_connection_from_interface(
|
||||
}
|
||||
elif interface.type == "vlan":
|
||||
parent = cast(VlanConfig, interface.vlan).interface
|
||||
if parent in network_manager and (
|
||||
parent_connection := network_manager.get(parent).connection
|
||||
if (
|
||||
parent
|
||||
and parent in network_manager
|
||||
and (parent_connection := network_manager.get(parent).connection)
|
||||
):
|
||||
parent = parent_connection.uuid
|
||||
|
||||
|
@ -10,6 +10,7 @@ from dbus_fast.aio.message_bus import MessageBus
|
||||
from ..exceptions import DBusError, DBusInterfaceError, DBusServiceUnkownError
|
||||
from ..utils.dt import get_time_zone, utc_from_timestamp
|
||||
from .const import (
|
||||
DBUS_ATTR_LOCAL_RTC,
|
||||
DBUS_ATTR_NTP,
|
||||
DBUS_ATTR_NTPSYNCHRONIZED,
|
||||
DBUS_ATTR_TIMEUSEC,
|
||||
@ -46,6 +47,12 @@ class TimeDate(DBusInterfaceProxy):
|
||||
"""Return host timezone."""
|
||||
return self.properties[DBUS_ATTR_TIMEZONE]
|
||||
|
||||
@property
|
||||
@dbus_property
|
||||
def local_rtc(self) -> bool:
|
||||
"""Return whether rtc is local time or utc."""
|
||||
return self.properties[DBUS_ATTR_LOCAL_RTC]
|
||||
|
||||
@property
|
||||
@dbus_property
|
||||
def ntp(self) -> bool:
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from ..dbus.const import (
|
||||
@ -23,6 +24,8 @@ from .const import (
|
||||
WifiMode,
|
||||
)
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AccessPoint:
|
||||
@ -79,7 +82,7 @@ class VlanConfig:
|
||||
"""Represent a vlan configuration."""
|
||||
|
||||
id: int
|
||||
interface: str
|
||||
interface: str | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -108,7 +111,10 @@ class Interface:
|
||||
if inet.settings.match and inet.settings.match.path:
|
||||
return inet.settings.match.path == [self.path]
|
||||
|
||||
return inet.settings.connection.interface_name == self.name
|
||||
return (
|
||||
inet.settings.connection is not None
|
||||
and inet.settings.connection.interface_name == self.name
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dbus_interface(inet: NetworkInterface) -> "Interface":
|
||||
@ -160,23 +166,23 @@ class Interface:
|
||||
ipv6_setting = Ip6Setting(InterfaceMethod.DISABLED, [], None, [])
|
||||
|
||||
ipv4_ready = (
|
||||
bool(inet.connection)
|
||||
inet.connection is not None
|
||||
and ConnectionStateFlags.IP4_READY in inet.connection.state_flags
|
||||
)
|
||||
ipv6_ready = (
|
||||
bool(inet.connection)
|
||||
inet.connection is not None
|
||||
and ConnectionStateFlags.IP6_READY in inet.connection.state_flags
|
||||
)
|
||||
|
||||
return Interface(
|
||||
inet.name,
|
||||
inet.hw_address,
|
||||
inet.path,
|
||||
inet.settings is not None,
|
||||
Interface._map_nm_connected(inet.connection),
|
||||
inet.primary,
|
||||
Interface._map_nm_type(inet.type),
|
||||
IpConfig(
|
||||
name=inet.name,
|
||||
mac=inet.hw_address,
|
||||
path=inet.path,
|
||||
enabled=inet.settings is not None,
|
||||
connected=Interface._map_nm_connected(inet.connection),
|
||||
primary=inet.primary,
|
||||
type=Interface._map_nm_type(inet.type),
|
||||
ipv4=IpConfig(
|
||||
address=inet.connection.ipv4.address
|
||||
if inet.connection.ipv4.address
|
||||
else [],
|
||||
@ -188,8 +194,8 @@ class Interface:
|
||||
)
|
||||
if inet.connection and inet.connection.ipv4
|
||||
else IpConfig([], None, [], ipv4_ready),
|
||||
ipv4_setting,
|
||||
IpConfig(
|
||||
ipv4setting=ipv4_setting,
|
||||
ipv6=IpConfig(
|
||||
address=inet.connection.ipv6.address
|
||||
if inet.connection.ipv6.address
|
||||
else [],
|
||||
@ -201,30 +207,28 @@ class Interface:
|
||||
)
|
||||
if inet.connection and inet.connection.ipv6
|
||||
else IpConfig([], None, [], ipv6_ready),
|
||||
ipv6_setting,
|
||||
Interface._map_nm_wifi(inet),
|
||||
Interface._map_nm_vlan(inet),
|
||||
ipv6setting=ipv6_setting,
|
||||
wifi=Interface._map_nm_wifi(inet),
|
||||
vlan=Interface._map_nm_vlan(inet),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_nm_method(method: str) -> InterfaceMethod:
|
||||
def _map_nm_method(method: str | None) -> InterfaceMethod:
|
||||
"""Map IP interface method."""
|
||||
mapping = {
|
||||
NMInterfaceMethod.AUTO: InterfaceMethod.AUTO,
|
||||
NMInterfaceMethod.DISABLED: InterfaceMethod.DISABLED,
|
||||
NMInterfaceMethod.MANUAL: InterfaceMethod.STATIC,
|
||||
NMInterfaceMethod.LINK_LOCAL: InterfaceMethod.DISABLED,
|
||||
}
|
||||
|
||||
return mapping.get(method, InterfaceMethod.DISABLED)
|
||||
match method:
|
||||
case NMInterfaceMethod.AUTO.value:
|
||||
return InterfaceMethod.AUTO
|
||||
case NMInterfaceMethod.MANUAL:
|
||||
return InterfaceMethod.STATIC
|
||||
return InterfaceMethod.DISABLED
|
||||
|
||||
@staticmethod
|
||||
def _map_nm_addr_gen_mode(addr_gen_mode: int) -> InterfaceAddrGenMode:
|
||||
"""Map IPv6 interface addr_gen_mode."""
|
||||
mapping = {
|
||||
NMInterfaceAddrGenMode.EUI64: InterfaceAddrGenMode.EUI64,
|
||||
NMInterfaceAddrGenMode.STABLE_PRIVACY: InterfaceAddrGenMode.STABLE_PRIVACY,
|
||||
NMInterfaceAddrGenMode.DEFAULT_OR_EUI64: InterfaceAddrGenMode.DEFAULT_OR_EUI64,
|
||||
NMInterfaceAddrGenMode.EUI64.value: InterfaceAddrGenMode.EUI64,
|
||||
NMInterfaceAddrGenMode.STABLE_PRIVACY.value: InterfaceAddrGenMode.STABLE_PRIVACY,
|
||||
NMInterfaceAddrGenMode.DEFAULT_OR_EUI64.value: InterfaceAddrGenMode.DEFAULT_OR_EUI64,
|
||||
}
|
||||
|
||||
return mapping.get(addr_gen_mode, InterfaceAddrGenMode.DEFAULT)
|
||||
@ -233,9 +237,9 @@ class Interface:
|
||||
def _map_nm_ip6_privacy(ip6_privacy: int) -> InterfaceIp6Privacy:
|
||||
"""Map IPv6 interface ip6_privacy."""
|
||||
mapping = {
|
||||
NMInterfaceIp6Privacy.DISABLED: InterfaceIp6Privacy.DISABLED,
|
||||
NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC,
|
||||
NMInterfaceIp6Privacy.ENABLED: InterfaceIp6Privacy.ENABLED,
|
||||
NMInterfaceIp6Privacy.DISABLED.value: InterfaceIp6Privacy.DISABLED,
|
||||
NMInterfaceIp6Privacy.ENABLED_PREFER_PUBLIC.value: InterfaceIp6Privacy.ENABLED_PREFER_PUBLIC,
|
||||
NMInterfaceIp6Privacy.ENABLED.value: InterfaceIp6Privacy.ENABLED,
|
||||
}
|
||||
|
||||
return mapping.get(ip6_privacy, InterfaceIp6Privacy.DEFAULT)
|
||||
@ -253,12 +257,14 @@ class Interface:
|
||||
|
||||
@staticmethod
|
||||
def _map_nm_type(device_type: int) -> InterfaceType:
|
||||
mapping = {
|
||||
DeviceType.ETHERNET: InterfaceType.ETHERNET,
|
||||
DeviceType.WIRELESS: InterfaceType.WIRELESS,
|
||||
DeviceType.VLAN: InterfaceType.VLAN,
|
||||
}
|
||||
return mapping[device_type]
|
||||
match device_type:
|
||||
case DeviceType.ETHERNET.value:
|
||||
return InterfaceType.ETHERNET
|
||||
case DeviceType.WIRELESS.value:
|
||||
return InterfaceType.WIRELESS
|
||||
case DeviceType.VLAN.value:
|
||||
return InterfaceType.VLAN
|
||||
raise ValueError(f"Invalid device type: {device_type}")
|
||||
|
||||
@staticmethod
|
||||
def _map_nm_wifi(inet: NetworkInterface) -> WifiConfig | None:
|
||||
@ -267,15 +273,22 @@ class Interface:
|
||||
return None
|
||||
|
||||
# Authentication and PSK
|
||||
auth = None
|
||||
auth = AuthMethod.OPEN
|
||||
psk = None
|
||||
if not inet.settings.wireless_security:
|
||||
auth = AuthMethod.OPEN
|
||||
elif inet.settings.wireless_security.key_mgmt == "none":
|
||||
auth = AuthMethod.WEP
|
||||
elif inet.settings.wireless_security.key_mgmt == "wpa-psk":
|
||||
auth = AuthMethod.WPA_PSK
|
||||
psk = inet.settings.wireless_security.psk
|
||||
if inet.settings.wireless_security:
|
||||
match inet.settings.wireless_security.key_mgmt:
|
||||
case "none":
|
||||
auth = AuthMethod.WEP
|
||||
case "wpa-psk":
|
||||
auth = AuthMethod.WPA_PSK
|
||||
psk = inet.settings.wireless_security.psk
|
||||
case _:
|
||||
_LOGGER.warning(
|
||||
"Auth method %s for network interface %s unsupported, skipping",
|
||||
inet.settings.wireless_security.key_mgmt,
|
||||
inet.name,
|
||||
)
|
||||
return None
|
||||
|
||||
# WifiMode
|
||||
mode = WifiMode.INFRASTRUCTURE
|
||||
@ -289,17 +302,17 @@ class Interface:
|
||||
signal = None
|
||||
|
||||
return WifiConfig(
|
||||
mode,
|
||||
inet.settings.wireless.ssid,
|
||||
auth,
|
||||
psk,
|
||||
signal,
|
||||
mode=mode,
|
||||
ssid=inet.settings.wireless.ssid if inet.settings.wireless else "",
|
||||
auth=auth,
|
||||
psk=psk,
|
||||
signal=signal,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_nm_vlan(inet: NetworkInterface) -> WifiConfig | None:
|
||||
def _map_nm_vlan(inet: NetworkInterface) -> VlanConfig | None:
|
||||
"""Create mapping to nm vlan property."""
|
||||
if inet.type != DeviceType.VLAN or not inet.settings:
|
||||
if inet.type != DeviceType.VLAN or not inet.settings or not inet.settings.vlan:
|
||||
return None
|
||||
|
||||
return VlanConfig(inet.settings.vlan.id, inet.settings.vlan.parent)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
@ -205,7 +205,7 @@ class LogsControl(CoreSysAttributes):
|
||||
async def journald_logs(
|
||||
self,
|
||||
path: str = "/entries",
|
||||
params: dict[str, str | list[str]] | None = None,
|
||||
params: Mapping[str, str | list[str]] | None = None,
|
||||
range_header: str | None = None,
|
||||
accept: LogFormat = LogFormat.TEXT,
|
||||
timeout: ClientTimeout | None = None,
|
||||
@ -226,7 +226,7 @@ class LogsControl(CoreSysAttributes):
|
||||
base_url = "http://localhost/"
|
||||
connector = UnixConnector(path=str(SYSTEMD_JOURNAL_GATEWAYD_SOCKET))
|
||||
async with ClientSession(base_url=base_url, connector=connector) as session:
|
||||
headers = {ACCEPT: accept}
|
||||
headers = {ACCEPT: accept.value}
|
||||
if range_header:
|
||||
if range_header.endswith(":"):
|
||||
# Make sure that num_entries is always set - before Systemd v256 it was
|
||||
|
@ -87,7 +87,7 @@ class NetworkManager(CoreSysAttributes):
|
||||
for config in self.sys_dbus.network.dns.configuration:
|
||||
if config.vpn or not config.nameservers:
|
||||
continue
|
||||
servers.extend(config.nameservers)
|
||||
servers.extend([str(ns) for ns in config.nameservers])
|
||||
|
||||
return list(dict.fromkeys(servers))
|
||||
|
||||
@ -197,10 +197,16 @@ class NetworkManager(CoreSysAttributes):
|
||||
with suppress(NetworkInterfaceNotFound):
|
||||
inet = self.sys_dbus.network.get(interface.name)
|
||||
|
||||
con: NetworkConnection = None
|
||||
con: NetworkConnection | None = None
|
||||
|
||||
# Update exist configuration
|
||||
if inet and interface.equals_dbus_interface(inet) and interface.enabled:
|
||||
if (
|
||||
inet
|
||||
and inet.settings
|
||||
and inet.settings.connection
|
||||
and interface.equals_dbus_interface(inet)
|
||||
and interface.enabled
|
||||
):
|
||||
_LOGGER.debug("Updating existing configuration for %s", interface.name)
|
||||
settings = get_connection_from_interface(
|
||||
interface,
|
||||
@ -211,12 +217,12 @@ class NetworkManager(CoreSysAttributes):
|
||||
|
||||
try:
|
||||
await inet.settings.update(settings)
|
||||
con = await self.sys_dbus.network.activate_connection(
|
||||
con = activated = await self.sys_dbus.network.activate_connection(
|
||||
inet.settings.object_path, inet.object_path
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"activate_connection returns %s",
|
||||
con.object_path,
|
||||
activated.object_path,
|
||||
)
|
||||
except DBusError as err:
|
||||
raise HostNetworkError(
|
||||
@ -236,12 +242,16 @@ class NetworkManager(CoreSysAttributes):
|
||||
settings = get_connection_from_interface(interface, self.sys_dbus.network)
|
||||
|
||||
try:
|
||||
settings, con = await self.sys_dbus.network.add_and_activate_connection(
|
||||
(
|
||||
settings,
|
||||
activated,
|
||||
) = await self.sys_dbus.network.add_and_activate_connection(
|
||||
settings, inet.object_path
|
||||
)
|
||||
con = activated
|
||||
_LOGGER.debug(
|
||||
"add_and_activate_connection returns %s",
|
||||
con.object_path,
|
||||
activated.object_path,
|
||||
)
|
||||
except DBusError as err:
|
||||
raise HostNetworkError(
|
||||
@ -277,7 +287,7 @@ class NetworkManager(CoreSysAttributes):
|
||||
)
|
||||
|
||||
if con:
|
||||
async with con.dbus.signal(
|
||||
async with con.connected_dbus.signal(
|
||||
DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED
|
||||
) as signal:
|
||||
# From this point we monitor signals. However, it might be that
|
||||
@ -303,7 +313,7 @@ class NetworkManager(CoreSysAttributes):
|
||||
"""Scan on Interface for AccessPoint."""
|
||||
inet = self.sys_dbus.network.get(interface.name)
|
||||
|
||||
if inet.type != DeviceType.WIRELESS:
|
||||
if inet.type != DeviceType.WIRELESS or not inet.wireless:
|
||||
raise HostNotSupportedError(
|
||||
f"Can only scan with wireless card - {interface.name}", _LOGGER.error
|
||||
)
|
||||
|
@ -1,13 +1,13 @@
|
||||
"""Supervisor job manager."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import Callable, Coroutine, Generator
|
||||
from contextlib import contextmanager, suppress
|
||||
from contextvars import Context, ContextVar, Token
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
from uuid import uuid4
|
||||
|
||||
from attrs import Attribute, define, field
|
||||
@ -27,7 +27,7 @@ from .validate import SCHEMA_JOBS_CONFIG
|
||||
# When a new asyncio task is started the current context is copied over.
|
||||
# Modifications to it in one task are not visible to others though.
|
||||
# This allows us to track what job is currently in progress in each task.
|
||||
_CURRENT_JOB: ContextVar[str] = ContextVar("current_job")
|
||||
_CURRENT_JOB: ContextVar[str | None] = ContextVar("current_job", default=None)
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@ -75,7 +75,7 @@ class SupervisorJobError:
|
||||
message: str = "Unknown error, see supervisor logs"
|
||||
stage: str | None = None
|
||||
|
||||
def as_dict(self) -> dict[str, str]:
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Return dictionary representation."""
|
||||
return {
|
||||
"type": self.type_.__name__,
|
||||
@ -101,9 +101,7 @@ class SupervisorJob:
|
||||
stage: str | None = field(
|
||||
default=None, validator=[_invalid_if_done], on_setattr=_on_change
|
||||
)
|
||||
parent_id: str | None = field(
|
||||
factory=lambda: _CURRENT_JOB.get(None), on_setattr=frozen
|
||||
)
|
||||
parent_id: str | None = field(factory=_CURRENT_JOB.get, on_setattr=frozen)
|
||||
done: bool | None = field(init=False, default=None, on_setattr=_on_change)
|
||||
on_change: Callable[["SupervisorJob", Attribute, Any], None] | None = field(
|
||||
default=None, on_setattr=frozen
|
||||
@ -137,7 +135,7 @@ class SupervisorJob:
|
||||
self.errors += [new_error]
|
||||
|
||||
@contextmanager
|
||||
def start(self):
|
||||
def start(self) -> Generator[Self]:
|
||||
"""Start the job in the current task.
|
||||
|
||||
This can only be called if the parent ID matches the job running in the current task.
|
||||
@ -146,11 +144,11 @@ class SupervisorJob:
|
||||
"""
|
||||
if self.done is not None:
|
||||
raise JobStartException("Job has already been started")
|
||||
if _CURRENT_JOB.get(None) != self.parent_id:
|
||||
if _CURRENT_JOB.get() != self.parent_id:
|
||||
raise JobStartException("Job has a different parent from current job")
|
||||
|
||||
self.done = False
|
||||
token: Token[str] | None = None
|
||||
token: Token[str | None] | None = None
|
||||
try:
|
||||
token = _CURRENT_JOB.set(self.uuid)
|
||||
yield self
|
||||
@ -193,17 +191,15 @@ class JobManager(FileConfiguration, CoreSysAttributes):
|
||||
|
||||
Must be called from within a job. Raises RuntimeError if there is no current job.
|
||||
"""
|
||||
try:
|
||||
return self.get_job(_CURRENT_JOB.get())
|
||||
except (LookupError, JobNotFound):
|
||||
raise RuntimeError(
|
||||
"No job for the current asyncio task!", _LOGGER.critical
|
||||
) from None
|
||||
if job_id := _CURRENT_JOB.get():
|
||||
with suppress(JobNotFound):
|
||||
return self.get_job(job_id)
|
||||
raise RuntimeError("No job for the current asyncio task!", _LOGGER.critical)
|
||||
|
||||
@property
|
||||
def is_job(self) -> bool:
|
||||
"""Return true if there is an active job for the current asyncio task."""
|
||||
return bool(_CURRENT_JOB.get(None))
|
||||
return _CURRENT_JOB.get() is not None
|
||||
|
||||
def _notify_on_job_change(
|
||||
self, job: SupervisorJob, attribute: Attribute, value: Any
|
||||
@ -265,7 +261,7 @@ class JobManager(FileConfiguration, CoreSysAttributes):
|
||||
|
||||
def schedule_job(
|
||||
self,
|
||||
job_method: Callable[..., Awaitable[Any]],
|
||||
job_method: Callable[..., Coroutine],
|
||||
options: JobSchedulerOptions,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
@ -1,12 +1,12 @@
|
||||
"""Job decorator."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from ..const import CoreState
|
||||
from ..coresys import CoreSys, CoreSysAttributes
|
||||
@ -54,11 +54,10 @@ class Job(CoreSysAttributes):
|
||||
self.on_condition = on_condition
|
||||
self.limit = limit
|
||||
self._throttle_period = throttle_period
|
||||
self.throttle_max_calls = throttle_max_calls
|
||||
self._throttle_max_calls = throttle_max_calls
|
||||
self._lock: asyncio.Semaphore | None = None
|
||||
self._method = None
|
||||
self._last_call: dict[str | None, datetime] = {}
|
||||
self._rate_limited_calls: dict[str, list[datetime]] | None = None
|
||||
self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
|
||||
self._internal = internal
|
||||
|
||||
# Validate Options
|
||||
@ -82,13 +81,29 @@ class Job(CoreSysAttributes):
|
||||
JobExecutionLimit.THROTTLE_RATE_LIMIT,
|
||||
JobExecutionLimit.GROUP_THROTTLE_RATE_LIMIT,
|
||||
):
|
||||
if self.throttle_max_calls is None:
|
||||
if self._throttle_max_calls is None:
|
||||
raise RuntimeError(
|
||||
f"Job {name} is using execution limit {limit} without throttle max calls!"
|
||||
)
|
||||
|
||||
self._rate_limited_calls = {}
|
||||
|
||||
@property
|
||||
def throttle_max_calls(self) -> int:
|
||||
"""Return max calls for throttle."""
|
||||
if self._throttle_max_calls is None:
|
||||
raise RuntimeError("No throttle max calls set for job!")
|
||||
return self._throttle_max_calls
|
||||
|
||||
@property
|
||||
def lock(self) -> asyncio.Semaphore:
|
||||
"""Return lock for limits."""
|
||||
# asyncio.Semaphore objects must be created in event loop
|
||||
# Since this is sync code it is not safe to create if missing here
|
||||
if not self._lock:
|
||||
raise RuntimeError("Lock has not been created yet!")
|
||||
return self._lock
|
||||
|
||||
def last_call(self, group_name: str | None = None) -> datetime:
|
||||
"""Return last call datetime."""
|
||||
return self._last_call.get(group_name, datetime.min)
|
||||
@ -97,12 +112,12 @@ class Job(CoreSysAttributes):
|
||||
"""Set last call datetime."""
|
||||
self._last_call[group_name] = value
|
||||
|
||||
def rate_limited_calls(
|
||||
self, group_name: str | None = None
|
||||
) -> list[datetime] | None:
|
||||
def rate_limited_calls(self, group_name: str | None = None) -> list[datetime]:
|
||||
"""Return rate limited calls if used."""
|
||||
if self._rate_limited_calls is None:
|
||||
return None
|
||||
raise RuntimeError(
|
||||
f"Rate limited calls not available for limit type {self.limit}"
|
||||
)
|
||||
|
||||
return self._rate_limited_calls.get(group_name, [])
|
||||
|
||||
@ -131,10 +146,10 @@ class Job(CoreSysAttributes):
|
||||
|
||||
self._rate_limited_calls[group_name] = value
|
||||
|
||||
def throttle_period(self, group_name: str | None = None) -> timedelta | None:
|
||||
def throttle_period(self, group_name: str | None = None) -> timedelta:
|
||||
"""Return throttle period."""
|
||||
if self._throttle_period is None:
|
||||
return None
|
||||
raise RuntimeError("No throttle period set for Job!")
|
||||
|
||||
if isinstance(self._throttle_period, timedelta):
|
||||
return self._throttle_period
|
||||
@ -142,7 +157,7 @@ class Job(CoreSysAttributes):
|
||||
return self._throttle_period(
|
||||
self.coresys,
|
||||
self.last_call(group_name),
|
||||
self.rate_limited_calls(group_name),
|
||||
self.rate_limited_calls(group_name) if self._rate_limited_calls else None,
|
||||
)
|
||||
|
||||
def _post_init(self, obj: JobGroup | CoreSysAttributes) -> JobGroup | None:
|
||||
@ -158,12 +173,12 @@ class Job(CoreSysAttributes):
|
||||
self._lock = asyncio.Semaphore()
|
||||
|
||||
# Job groups
|
||||
try:
|
||||
is_job_group = obj.acquire and obj.release
|
||||
except AttributeError:
|
||||
is_job_group = False
|
||||
job_group: JobGroup | None = None
|
||||
with suppress(AttributeError):
|
||||
if obj.acquire and obj.release: # type: ignore
|
||||
job_group = cast(JobGroup, obj)
|
||||
|
||||
if not is_job_group and self.limit in (
|
||||
if not job_group and self.limit in (
|
||||
JobExecutionLimit.GROUP_ONCE,
|
||||
JobExecutionLimit.GROUP_WAIT,
|
||||
JobExecutionLimit.GROUP_THROTTLE,
|
||||
@ -174,7 +189,7 @@ class Job(CoreSysAttributes):
|
||||
f"Job on {self.name} need to be a JobGroup to use group based limits!"
|
||||
) from None
|
||||
|
||||
return obj if is_job_group else None
|
||||
return job_group
|
||||
|
||||
def _handle_job_condition_exception(self, err: JobConditionException) -> None:
|
||||
"""Handle a job condition failure."""
|
||||
@ -184,9 +199,8 @@ class Job(CoreSysAttributes):
|
||||
return
|
||||
raise self.on_condition(error_msg, _LOGGER.warning) from None
|
||||
|
||||
def __call__(self, method):
|
||||
def __call__(self, method: Callable[..., Awaitable]):
|
||||
"""Call the wrapper logic."""
|
||||
self._method = method
|
||||
|
||||
@wraps(method)
|
||||
async def wrapper(
|
||||
@ -221,7 +235,7 @@ class Job(CoreSysAttributes):
|
||||
if self.conditions:
|
||||
try:
|
||||
await Job.check_conditions(
|
||||
self, set(self.conditions), self._method.__qualname__
|
||||
self, set(self.conditions), method.__qualname__
|
||||
)
|
||||
except JobConditionException as err:
|
||||
return self._handle_job_condition_exception(err)
|
||||
@ -237,7 +251,7 @@ class Job(CoreSysAttributes):
|
||||
JobExecutionLimit.GROUP_WAIT,
|
||||
):
|
||||
try:
|
||||
await obj.acquire(
|
||||
await cast(JobGroup, job_group).acquire(
|
||||
job, self.limit == JobExecutionLimit.GROUP_WAIT
|
||||
)
|
||||
except JobGroupExecutionLimitExceeded as err:
|
||||
@ -296,12 +310,12 @@ class Job(CoreSysAttributes):
|
||||
with job.start():
|
||||
try:
|
||||
self.set_last_call(datetime.now(), group_name)
|
||||
if self.rate_limited_calls(group_name) is not None:
|
||||
if self._rate_limited_calls is not None:
|
||||
self.add_rate_limited_call(
|
||||
self.last_call(group_name), group_name
|
||||
)
|
||||
|
||||
return await self._method(obj, *args, **kwargs)
|
||||
return await method(obj, *args, **kwargs)
|
||||
|
||||
# If a method has a conditional JobCondition, they must check it in the method
|
||||
# These should be handled like normal JobConditions as much as possible
|
||||
@ -317,11 +331,11 @@ class Job(CoreSysAttributes):
|
||||
raise JobException() from err
|
||||
finally:
|
||||
self._release_exception_limits()
|
||||
if self.limit in (
|
||||
if job_group and self.limit in (
|
||||
JobExecutionLimit.GROUP_ONCE,
|
||||
JobExecutionLimit.GROUP_WAIT,
|
||||
):
|
||||
obj.release()
|
||||
job_group.release()
|
||||
|
||||
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required
|
||||
finally:
|
||||
@ -473,13 +487,13 @@ class Job(CoreSysAttributes):
|
||||
):
|
||||
return
|
||||
|
||||
if self.limit == JobExecutionLimit.ONCE and self._lock.locked():
|
||||
if self.limit == JobExecutionLimit.ONCE and self.lock.locked():
|
||||
on_condition = (
|
||||
JobException if self.on_condition is None else self.on_condition
|
||||
)
|
||||
raise on_condition("Another job is running")
|
||||
|
||||
await self._lock.acquire()
|
||||
await self.lock.acquire()
|
||||
|
||||
def _release_exception_limits(self) -> None:
|
||||
"""Release possible exception limits."""
|
||||
@ -490,4 +504,4 @@ class Job(CoreSysAttributes):
|
||||
JobExecutionLimit.GROUP_THROTTLE_WAIT,
|
||||
):
|
||||
return
|
||||
self._lock.release()
|
||||
self.lock.release()
|
||||
|
@ -41,7 +41,7 @@ class JobGroup(CoreSysAttributes):
|
||||
def has_lock(self) -> bool:
|
||||
"""Return true if current task has the lock on this job group."""
|
||||
return (
|
||||
self.active_job
|
||||
self.active_job is not None
|
||||
and self.sys_jobs.is_job
|
||||
and self.active_job == self.sys_jobs.current
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user