mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-07-15 05:06:30 +00:00
Add throttle to job execution (#2631)
* Add throttle to job execution * fix unittests * Add tests * address comments * add comment * better on __init__ * New text * Simplify logic
This commit is contained in:
parent
78d9c60be5
commit
31f5033dca
@ -7,7 +7,8 @@ from typing import Dict, Optional, Union
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
|
||||
from ..coresys import CoreSys, CoreSysAttributes
|
||||
from ..utils import AsyncThrottle
|
||||
from ..jobs.const import JobExecutionLimit
|
||||
from ..jobs.decorator import Job
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,7 +41,7 @@ class HomeAssistantSecrets(CoreSysAttributes):
|
||||
"""Reload secrets."""
|
||||
await self._read_secrets()
|
||||
|
||||
@AsyncThrottle(timedelta(seconds=60))
|
||||
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=60))
|
||||
async def _read_secrets(self):
|
||||
"""Read secrets.yaml into memory."""
|
||||
if not self.path_secrets.exists():
|
||||
|
@ -9,7 +9,8 @@ from pulsectl import Pulse, PulseError, PulseIndexError, PulseOperationFailed
|
||||
|
||||
from ..coresys import CoreSys, CoreSysAttributes
|
||||
from ..exceptions import PulseAudioError
|
||||
from ..utils import AsyncThrottle
|
||||
from ..jobs.const import JobExecutionLimit
|
||||
from ..jobs.decorator import Job
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@ -217,7 +218,7 @@ class SoundControl(CoreSysAttributes):
|
||||
await self.sys_run_in_executor(_activate_profile)
|
||||
await self.update()
|
||||
|
||||
@AsyncThrottle(timedelta(seconds=10))
|
||||
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=10))
|
||||
async def update(self):
|
||||
"""Update properties over dbus."""
|
||||
_LOGGER.info("Updating PulseAudio information")
|
||||
|
@ -23,3 +23,5 @@ class JobExecutionLimit(str, Enum):
|
||||
"""Job Execution limits."""
|
||||
|
||||
SINGLE_WAIT = "single_wait"
|
||||
THROTTLE = "throttle"
|
||||
THROTTLE_WAIT = "throttle_wait"
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Job decorator."""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
@ -24,6 +26,7 @@ class Job(CoreSysAttributes):
|
||||
cleanup: bool = True,
|
||||
on_condition: Optional[JobException] = None,
|
||||
limit: Optional[JobExecutionLimit] = None,
|
||||
throttle_period: Optional[timedelta] = None,
|
||||
):
|
||||
"""Initialize the Job class."""
|
||||
self.name = name
|
||||
@ -31,8 +34,17 @@ class Job(CoreSysAttributes):
|
||||
self.cleanup = cleanup
|
||||
self.on_condition = on_condition
|
||||
self.limit = limit
|
||||
self.throttle_period = throttle_period
|
||||
self._lock: Optional[asyncio.Semaphore] = None
|
||||
self._method = None
|
||||
self._last_call = datetime.min
|
||||
|
||||
# Validate Options
|
||||
if (
|
||||
self.limit in (JobExecutionLimit.THROTTLE, JobExecutionLimit.THROTTLE_WAIT)
|
||||
and self.throttle_period is None
|
||||
):
|
||||
raise RuntimeError("Using Job without a Throttle period!")
|
||||
|
||||
def _post_init(self, args: Tuple[Any]) -> None:
|
||||
"""Runtime init."""
|
||||
@ -45,8 +57,9 @@ class Job(CoreSysAttributes):
|
||||
except AttributeError:
|
||||
pass
|
||||
if not self.coresys:
|
||||
raise JobException(f"coresys is missing on {self.name}")
|
||||
raise RuntimeError(f"Job on {self.name} need to be an coresys object!")
|
||||
|
||||
# Others
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Semaphore()
|
||||
|
||||
@ -54,6 +67,7 @@ class Job(CoreSysAttributes):
|
||||
"""Call the wrapper logic."""
|
||||
self._method = method
|
||||
|
||||
@wraps(method)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
"""Wrap the method."""
|
||||
self._post_init(args)
|
||||
@ -67,11 +81,22 @@ class Job(CoreSysAttributes):
|
||||
raise self.on_condition()
|
||||
|
||||
# Handle exection limits
|
||||
if self.limit:
|
||||
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
||||
await self._acquire_exection_limit()
|
||||
elif self.limit == JobExecutionLimit.THROTTLE:
|
||||
time_since_last_call = datetime.now() - self._last_call
|
||||
if time_since_last_call < self.throttle_period:
|
||||
return
|
||||
elif self.limit == JobExecutionLimit.THROTTLE_WAIT:
|
||||
await self._acquire_exection_limit()
|
||||
time_since_last_call = datetime.now() - self._last_call
|
||||
if time_since_last_call < self.throttle_period:
|
||||
self._release_exception_limits()
|
||||
return
|
||||
|
||||
# Execute Job
|
||||
try:
|
||||
self._last_call = datetime.now()
|
||||
return await self._method(*args, **kwargs)
|
||||
except HassioError as err:
|
||||
raise err
|
||||
@ -155,12 +180,18 @@ class Job(CoreSysAttributes):
|
||||
|
||||
async def _acquire_exection_limit(self) -> None:
|
||||
"""Process exection limits."""
|
||||
|
||||
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
||||
if self.limit not in (
|
||||
JobExecutionLimit.SINGLE_WAIT,
|
||||
JobExecutionLimit.THROTTLE_WAIT,
|
||||
):
|
||||
return
|
||||
await self._lock.acquire()
|
||||
|
||||
def _release_exception_limits(self) -> None:
|
||||
"""Release possible exception limits."""
|
||||
|
||||
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
||||
if self.limit not in (
|
||||
JobExecutionLimit.SINGLE_WAIT,
|
||||
JobExecutionLimit.THROTTLE_WAIT,
|
||||
):
|
||||
return
|
||||
self._lock.release()
|
||||
|
@ -1,10 +1,11 @@
|
||||
"""Helpers to check core security."""
|
||||
from contextlib import suppress
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from ...const import AddonState, CoreState
|
||||
from ...exceptions import PwnedError
|
||||
from ...jobs.const import JobCondition
|
||||
from ...jobs.const import JobCondition, JobExecutionLimit
|
||||
from ...jobs.decorator import Job
|
||||
from ...utils.pwned import check_pwned_password
|
||||
from ..const import ContextType, IssueType, SuggestionType
|
||||
@ -14,7 +15,11 @@ from .base import CheckBase
|
||||
class CheckAddonPwned(CheckBase):
|
||||
"""CheckAddonPwned class for check."""
|
||||
|
||||
@Job(conditions=[JobCondition.INTERNET_SYSTEM])
|
||||
@Job(
|
||||
conditions=[JobCondition.INTERNET_SYSTEM],
|
||||
limit=JobExecutionLimit.THROTTLE,
|
||||
throttle_period=timedelta(hours=24),
|
||||
)
|
||||
async def run_check(self) -> None:
|
||||
"""Run check if not affected by issue."""
|
||||
await self.sys_homeassistant.secrets.reload()
|
||||
|
@ -9,6 +9,8 @@ from typing import Optional
|
||||
import aiohttp
|
||||
from awesomeversion import AwesomeVersion
|
||||
|
||||
from supervisor.jobs.const import JobExecutionLimit
|
||||
|
||||
from .const import (
|
||||
ATTR_AUDIO,
|
||||
ATTR_CHANNEL,
|
||||
@ -28,7 +30,6 @@ from .const import (
|
||||
from .coresys import CoreSysAttributes
|
||||
from .exceptions import UpdaterError, UpdaterJobError
|
||||
from .jobs.decorator import Job, JobCondition
|
||||
from .utils import AsyncThrottle
|
||||
from .utils.json import JsonConfig
|
||||
from .validate import SCHEMA_UPDATER_CONFIG
|
||||
|
||||
@ -165,10 +166,11 @@ class Updater(JsonConfig, CoreSysAttributes):
|
||||
"""Set upstream mode."""
|
||||
self._data[ATTR_CHANNEL] = value
|
||||
|
||||
@AsyncThrottle(timedelta(seconds=30))
|
||||
@Job(
|
||||
conditions=[JobCondition.INTERNET_SYSTEM],
|
||||
on_condition=UpdaterJobError,
|
||||
limit=JobExecutionLimit.THROTTLE_WAIT,
|
||||
throttle_period=timedelta(seconds=30),
|
||||
)
|
||||
async def fetch_data(self):
|
||||
"""Fetch current versions from Github.
|
||||
|
@ -1,12 +1,11 @@
|
||||
"""Tools file for Supervisor."""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
import socket
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@ -35,65 +34,6 @@ def process_lock(method):
|
||||
return wrap_api
|
||||
|
||||
|
||||
class AsyncThrottle:
|
||||
"""A class for throttling the execution of tasks.
|
||||
|
||||
Decorator that prevents a function from being called more than once every
|
||||
time period with blocking.
|
||||
"""
|
||||
|
||||
def __init__(self, delta):
|
||||
"""Initialize async throttle."""
|
||||
self.throttle_period = delta
|
||||
self.time_of_last_call = datetime.min
|
||||
self.synchronize: Optional[asyncio.Lock] = None
|
||||
|
||||
def __call__(self, method):
|
||||
"""Throttle function."""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""Throttle function wrapper."""
|
||||
if not self.synchronize:
|
||||
self.synchronize = asyncio.Lock()
|
||||
|
||||
async with self.synchronize:
|
||||
now = datetime.now()
|
||||
time_since_last_call = now - self.time_of_last_call
|
||||
|
||||
if time_since_last_call > self.throttle_period:
|
||||
self.time_of_last_call = now
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AsyncCallFilter:
|
||||
"""A class for throttling the execution of tasks, with a filter.
|
||||
|
||||
Decorator that prevents a function from being called more than once every
|
||||
time period.
|
||||
"""
|
||||
|
||||
def __init__(self, delta):
|
||||
"""Initialize async throttle."""
|
||||
self.throttle_period = delta
|
||||
self.time_of_last_call = datetime.min
|
||||
|
||||
def __call__(self, method):
|
||||
"""Throttle function."""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""Throttle function wrapper."""
|
||||
now = datetime.now()
|
||||
time_since_last_call = now - self.time_of_last_call
|
||||
|
||||
if time_since_last_call > self.throttle_period:
|
||||
self.time_of_last_call = now
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def check_port(address: IPv4Address, port: int) -> bool:
|
||||
"""Check if port is mapped."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test the condition decorators."""
|
||||
# pylint: disable=protected-access,import-error
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -284,3 +285,63 @@ async def test_exectution_limit_single_wait(
|
||||
test = TestClass(coresys)
|
||||
|
||||
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||
|
||||
|
||||
async def test_exectution_limit_throttle_wait(
|
||||
coresys: CoreSys, loop: asyncio.BaseEventLoop
|
||||
):
|
||||
"""Test the ignore conditions decorator."""
|
||||
|
||||
class TestClass:
|
||||
"""Test class."""
|
||||
|
||||
def __init__(self, coresys: CoreSys):
|
||||
"""Initialize the test class."""
|
||||
self.coresys = coresys
|
||||
self.run = asyncio.Lock()
|
||||
self.call = 0
|
||||
|
||||
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(hours=1))
|
||||
async def execute(self, sleep: float):
|
||||
"""Execute the class method."""
|
||||
assert not self.run.locked()
|
||||
async with self.run:
|
||||
await asyncio.sleep(sleep)
|
||||
self.call += 1
|
||||
|
||||
test = TestClass(coresys)
|
||||
|
||||
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||
assert test.call == 1
|
||||
|
||||
await asyncio.gather(*[test.execute(0.1)])
|
||||
assert test.call == 1
|
||||
|
||||
|
||||
async def test_exectution_limit_throttle(coresys: CoreSys, loop: asyncio.BaseEventLoop):
|
||||
"""Test the ignore conditions decorator."""
|
||||
|
||||
class TestClass:
|
||||
"""Test class."""
|
||||
|
||||
def __init__(self, coresys: CoreSys):
|
||||
"""Initialize the test class."""
|
||||
self.coresys = coresys
|
||||
self.run = asyncio.Lock()
|
||||
self.call = 0
|
||||
|
||||
@Job(limit=JobExecutionLimit.THROTTLE, throttle_period=timedelta(hours=1))
|
||||
async def execute(self, sleep: float):
|
||||
"""Execute the class method."""
|
||||
assert not self.run.locked()
|
||||
async with self.run:
|
||||
await asyncio.sleep(sleep)
|
||||
self.call += 1
|
||||
|
||||
test = TestClass(coresys)
|
||||
|
||||
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||
assert test.call == 1
|
||||
|
||||
await asyncio.gather(*[test.execute(0.1)])
|
||||
assert test.call == 1
|
||||
|
@ -30,7 +30,7 @@ async def test_check(coresys: CoreSys):
|
||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||
AsyncMock(return_value=True),
|
||||
) as mock:
|
||||
await addon_pwned.run_check()
|
||||
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||
assert not mock.called
|
||||
|
||||
addon.pwned.add("123456")
|
||||
@ -38,7 +38,7 @@ async def test_check(coresys: CoreSys):
|
||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||
AsyncMock(return_value=False),
|
||||
) as mock:
|
||||
await addon_pwned.run_check()
|
||||
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||
assert mock.called
|
||||
|
||||
assert len(coresys.resolution.issues) == 0
|
||||
@ -47,7 +47,7 @@ async def test_check(coresys: CoreSys):
|
||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||
AsyncMock(return_value=True),
|
||||
) as mock:
|
||||
await addon_pwned.run_check()
|
||||
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||
assert mock.called
|
||||
|
||||
assert len(coresys.resolution.issues) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user