Add throttle to job execution (#2631)

* Add throttle to job execution

* fix unittests

* Add tests

* address comments

* add comment

* better on __init__

* New text

* Simplify logic
This commit is contained in:
Pascal Vizeli 2021-02-25 23:29:03 +01:00 committed by GitHub
parent 78d9c60be5
commit 31f5033dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 123 additions and 80 deletions

View File

@ -7,7 +7,8 @@ from typing import Dict, Optional, Union
from ruamel.yaml import YAML, YAMLError
from ..coresys import CoreSys, CoreSysAttributes
from ..utils import AsyncThrottle
from ..jobs.const import JobExecutionLimit
from ..jobs.decorator import Job
_LOGGER: logging.Logger = logging.getLogger(__name__)
@ -40,7 +41,7 @@ class HomeAssistantSecrets(CoreSysAttributes):
"""Reload secrets."""
await self._read_secrets()
@AsyncThrottle(timedelta(seconds=60))
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=60))
async def _read_secrets(self):
"""Read secrets.yaml into memory."""
if not self.path_secrets.exists():

View File

@ -9,7 +9,8 @@ from pulsectl import Pulse, PulseError, PulseIndexError, PulseOperationFailed
from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import PulseAudioError
from ..utils import AsyncThrottle
from ..jobs.const import JobExecutionLimit
from ..jobs.decorator import Job
_LOGGER: logging.Logger = logging.getLogger(__name__)
@ -217,7 +218,7 @@ class SoundControl(CoreSysAttributes):
await self.sys_run_in_executor(_activate_profile)
await self.update()
@AsyncThrottle(timedelta(seconds=10))
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=10))
async def update(self):
"""Update properties over dbus."""
_LOGGER.info("Updating PulseAudio information")

View File

@ -23,3 +23,5 @@ class JobExecutionLimit(str, Enum):
"""Job Execution limits."""
SINGLE_WAIT = "single_wait"
THROTTLE = "throttle"
THROTTLE_WAIT = "throttle_wait"

View File

@ -1,5 +1,7 @@
"""Job decorator."""
import asyncio
from datetime import datetime, timedelta
from functools import wraps
import logging
from typing import Any, List, Optional, Tuple
@ -24,6 +26,7 @@ class Job(CoreSysAttributes):
cleanup: bool = True,
on_condition: Optional[JobException] = None,
limit: Optional[JobExecutionLimit] = None,
throttle_period: Optional[timedelta] = None,
):
"""Initialize the Job class."""
self.name = name
@ -31,8 +34,17 @@ class Job(CoreSysAttributes):
self.cleanup = cleanup
self.on_condition = on_condition
self.limit = limit
self.throttle_period = throttle_period
self._lock: Optional[asyncio.Semaphore] = None
self._method = None
self._last_call = datetime.min
# Validate Options
if (
self.limit in (JobExecutionLimit.THROTTLE, JobExecutionLimit.THROTTLE_WAIT)
and self.throttle_period is None
):
raise RuntimeError("Using Job without a Throttle period!")
def _post_init(self, args: Tuple[Any]) -> None:
"""Runtime init."""
@ -45,8 +57,9 @@ class Job(CoreSysAttributes):
except AttributeError:
pass
if not self.coresys:
raise JobException(f"coresys is missing on {self.name}")
raise RuntimeError(f"Job on {self.name} need to be an coresys object!")
# Others
if self._lock is None:
self._lock = asyncio.Semaphore()
@ -54,6 +67,7 @@ class Job(CoreSysAttributes):
"""Call the wrapper logic."""
self._method = method
@wraps(method)
async def wrapper(*args, **kwargs) -> Any:
"""Wrap the method."""
self._post_init(args)
@ -67,11 +81,22 @@ class Job(CoreSysAttributes):
raise self.on_condition()
# Handle exection limits
if self.limit:
if self.limit == JobExecutionLimit.SINGLE_WAIT:
await self._acquire_exection_limit()
elif self.limit == JobExecutionLimit.THROTTLE:
time_since_last_call = datetime.now() - self._last_call
if time_since_last_call < self.throttle_period:
return
elif self.limit == JobExecutionLimit.THROTTLE_WAIT:
await self._acquire_exection_limit()
time_since_last_call = datetime.now() - self._last_call
if time_since_last_call < self.throttle_period:
self._release_exception_limits()
return
# Execute Job
try:
self._last_call = datetime.now()
return await self._method(*args, **kwargs)
except HassioError as err:
raise err
@ -155,12 +180,18 @@ class Job(CoreSysAttributes):
async def _acquire_exection_limit(self) -> None:
"""Process exection limits."""
if self.limit == JobExecutionLimit.SINGLE_WAIT:
await self._lock.acquire()
if self.limit not in (
JobExecutionLimit.SINGLE_WAIT,
JobExecutionLimit.THROTTLE_WAIT,
):
return
await self._lock.acquire()
def _release_exception_limits(self) -> None:
"""Release possible exception limits."""
if self.limit == JobExecutionLimit.SINGLE_WAIT:
self._lock.release()
if self.limit not in (
JobExecutionLimit.SINGLE_WAIT,
JobExecutionLimit.THROTTLE_WAIT,
):
return
self._lock.release()

View File

@ -1,10 +1,11 @@
"""Helpers to check core security."""
from contextlib import suppress
from datetime import timedelta
from typing import List, Optional
from ...const import AddonState, CoreState
from ...exceptions import PwnedError
from ...jobs.const import JobCondition
from ...jobs.const import JobCondition, JobExecutionLimit
from ...jobs.decorator import Job
from ...utils.pwned import check_pwned_password
from ..const import ContextType, IssueType, SuggestionType
@ -14,7 +15,11 @@ from .base import CheckBase
class CheckAddonPwned(CheckBase):
"""CheckAddonPwned class for check."""
@Job(conditions=[JobCondition.INTERNET_SYSTEM])
@Job(
conditions=[JobCondition.INTERNET_SYSTEM],
limit=JobExecutionLimit.THROTTLE,
throttle_period=timedelta(hours=24),
)
async def run_check(self) -> None:
"""Run check if not affected by issue."""
await self.sys_homeassistant.secrets.reload()

View File

@ -9,6 +9,8 @@ from typing import Optional
import aiohttp
from awesomeversion import AwesomeVersion
from supervisor.jobs.const import JobExecutionLimit
from .const import (
ATTR_AUDIO,
ATTR_CHANNEL,
@ -28,7 +30,6 @@ from .const import (
from .coresys import CoreSysAttributes
from .exceptions import UpdaterError, UpdaterJobError
from .jobs.decorator import Job, JobCondition
from .utils import AsyncThrottle
from .utils.json import JsonConfig
from .validate import SCHEMA_UPDATER_CONFIG
@ -165,10 +166,11 @@ class Updater(JsonConfig, CoreSysAttributes):
"""Set upstream mode."""
self._data[ATTR_CHANNEL] = value
@AsyncThrottle(timedelta(seconds=30))
@Job(
conditions=[JobCondition.INTERNET_SYSTEM],
on_condition=UpdaterJobError,
limit=JobExecutionLimit.THROTTLE_WAIT,
throttle_period=timedelta(seconds=30),
)
async def fetch_data(self):
"""Fetch current versions from Github.

View File

@ -1,12 +1,11 @@
"""Tools file for Supervisor."""
import asyncio
from datetime import datetime
from ipaddress import IPv4Address
import logging
from pathlib import Path
import re
import socket
from typing import Any, Optional
from typing import Any
_LOGGER: logging.Logger = logging.getLogger(__name__)
@ -35,65 +34,6 @@ def process_lock(method):
return wrap_api
class AsyncThrottle:
"""A class for throttling the execution of tasks.
Decorator that prevents a function from being called more than once every
time period with blocking.
"""
def __init__(self, delta):
"""Initialize async throttle."""
self.throttle_period = delta
self.time_of_last_call = datetime.min
self.synchronize: Optional[asyncio.Lock] = None
def __call__(self, method):
"""Throttle function."""
async def wrapper(*args, **kwargs):
"""Throttle function wrapper."""
if not self.synchronize:
self.synchronize = asyncio.Lock()
async with self.synchronize:
now = datetime.now()
time_since_last_call = now - self.time_of_last_call
if time_since_last_call > self.throttle_period:
self.time_of_last_call = now
return await method(*args, **kwargs)
return wrapper
class AsyncCallFilter:
"""A class for throttling the execution of tasks, with a filter.
Decorator that prevents a function from being called more than once every
time period.
"""
def __init__(self, delta):
"""Initialize async throttle."""
self.throttle_period = delta
self.time_of_last_call = datetime.min
def __call__(self, method):
"""Throttle function."""
async def wrapper(*args, **kwargs):
"""Throttle function wrapper."""
now = datetime.now()
time_since_last_call = now - self.time_of_last_call
if time_since_last_call > self.throttle_period:
self.time_of_last_call = now
return await method(*args, **kwargs)
return wrapper
def check_port(address: IPv4Address, port: int) -> bool:
"""Check if port is mapped."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

View File

@ -1,6 +1,7 @@
"""Test the condition decorators."""
# pylint: disable=protected-access,import-error
import asyncio
from datetime import timedelta
from unittest.mock import patch
import pytest
@ -284,3 +285,63 @@ async def test_exectution_limit_single_wait(
test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
async def test_exectution_limit_throttle_wait(
coresys: CoreSys, loop: asyncio.BaseEventLoop
):
"""Test the ignore conditions decorator."""
class TestClass:
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
self.coresys = coresys
self.run = asyncio.Lock()
self.call = 0
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(hours=1))
async def execute(self, sleep: float):
"""Execute the class method."""
assert not self.run.locked()
async with self.run:
await asyncio.sleep(sleep)
self.call += 1
test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
assert test.call == 1
await asyncio.gather(*[test.execute(0.1)])
assert test.call == 1
async def test_exectution_limit_throttle(coresys: CoreSys, loop: asyncio.BaseEventLoop):
"""Test the ignore conditions decorator."""
class TestClass:
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
self.coresys = coresys
self.run = asyncio.Lock()
self.call = 0
@Job(limit=JobExecutionLimit.THROTTLE, throttle_period=timedelta(hours=1))
async def execute(self, sleep: float):
"""Execute the class method."""
assert not self.run.locked()
async with self.run:
await asyncio.sleep(sleep)
self.call += 1
test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
assert test.call == 1
await asyncio.gather(*[test.execute(0.1)])
assert test.call == 1

View File

@ -30,7 +30,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=True),
) as mock:
await addon_pwned.run_check()
await addon_pwned.run_check.__wrapped__(addon_pwned)
assert not mock.called
addon.pwned.add("123456")
@ -38,7 +38,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=False),
) as mock:
await addon_pwned.run_check()
await addon_pwned.run_check.__wrapped__(addon_pwned)
assert mock.called
assert len(coresys.resolution.issues) == 0
@ -47,7 +47,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=True),
) as mock:
await addon_pwned.run_check()
await addon_pwned.run_check.__wrapped__(addon_pwned)
assert mock.called
assert len(coresys.resolution.issues) == 1