Add throttle to job execution (#2631)

* Add throttle to job execution

* fix unittests

* Add tests

* address comments

* add comment

* better on __init__

* New text

* Simplify logic
This commit is contained in:
Pascal Vizeli 2021-02-25 23:29:03 +01:00 committed by GitHub
parent 78d9c60be5
commit 31f5033dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 123 additions and 80 deletions

View File

@ -7,7 +7,8 @@ from typing import Dict, Optional, Union
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
from ..utils import AsyncThrottle from ..jobs.const import JobExecutionLimit
from ..jobs.decorator import Job
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -40,7 +41,7 @@ class HomeAssistantSecrets(CoreSysAttributes):
"""Reload secrets.""" """Reload secrets."""
await self._read_secrets() await self._read_secrets()
@AsyncThrottle(timedelta(seconds=60)) @Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=60))
async def _read_secrets(self): async def _read_secrets(self):
"""Read secrets.yaml into memory.""" """Read secrets.yaml into memory."""
if not self.path_secrets.exists(): if not self.path_secrets.exists():

View File

@ -9,7 +9,8 @@ from pulsectl import Pulse, PulseError, PulseIndexError, PulseOperationFailed
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import PulseAudioError from ..exceptions import PulseAudioError
from ..utils import AsyncThrottle from ..jobs.const import JobExecutionLimit
from ..jobs.decorator import Job
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -217,7 +218,7 @@ class SoundControl(CoreSysAttributes):
await self.sys_run_in_executor(_activate_profile) await self.sys_run_in_executor(_activate_profile)
await self.update() await self.update()
@AsyncThrottle(timedelta(seconds=10)) @Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=10))
async def update(self): async def update(self):
"""Update properties over dbus.""" """Update properties over dbus."""
_LOGGER.info("Updating PulseAudio information") _LOGGER.info("Updating PulseAudio information")

View File

@ -23,3 +23,5 @@ class JobExecutionLimit(str, Enum):
"""Job Execution limits.""" """Job Execution limits."""
SINGLE_WAIT = "single_wait" SINGLE_WAIT = "single_wait"
THROTTLE = "throttle"
THROTTLE_WAIT = "throttle_wait"

View File

@ -1,5 +1,7 @@
"""Job decorator.""" """Job decorator."""
import asyncio import asyncio
from datetime import datetime, timedelta
from functools import wraps
import logging import logging
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
@ -24,6 +26,7 @@ class Job(CoreSysAttributes):
cleanup: bool = True, cleanup: bool = True,
on_condition: Optional[JobException] = None, on_condition: Optional[JobException] = None,
limit: Optional[JobExecutionLimit] = None, limit: Optional[JobExecutionLimit] = None,
throttle_period: Optional[timedelta] = None,
): ):
"""Initialize the Job class.""" """Initialize the Job class."""
self.name = name self.name = name
@ -31,8 +34,17 @@ class Job(CoreSysAttributes):
self.cleanup = cleanup self.cleanup = cleanup
self.on_condition = on_condition self.on_condition = on_condition
self.limit = limit self.limit = limit
self.throttle_period = throttle_period
self._lock: Optional[asyncio.Semaphore] = None self._lock: Optional[asyncio.Semaphore] = None
self._method = None self._method = None
self._last_call = datetime.min
# Validate Options
if (
self.limit in (JobExecutionLimit.THROTTLE, JobExecutionLimit.THROTTLE_WAIT)
and self.throttle_period is None
):
raise RuntimeError("Using Job without a Throttle period!")
def _post_init(self, args: Tuple[Any]) -> None: def _post_init(self, args: Tuple[Any]) -> None:
"""Runtime init.""" """Runtime init."""
@ -45,8 +57,9 @@ class Job(CoreSysAttributes):
except AttributeError: except AttributeError:
pass pass
if not self.coresys: if not self.coresys:
raise JobException(f"coresys is missing on {self.name}") raise RuntimeError(f"Job on {self.name} need to be an coresys object!")
# Others
if self._lock is None: if self._lock is None:
self._lock = asyncio.Semaphore() self._lock = asyncio.Semaphore()
@ -54,6 +67,7 @@ class Job(CoreSysAttributes):
"""Call the wrapper logic.""" """Call the wrapper logic."""
self._method = method self._method = method
@wraps(method)
async def wrapper(*args, **kwargs) -> Any: async def wrapper(*args, **kwargs) -> Any:
"""Wrap the method.""" """Wrap the method."""
self._post_init(args) self._post_init(args)
@ -67,11 +81,22 @@ class Job(CoreSysAttributes):
raise self.on_condition() raise self.on_condition()
# Handle exection limits # Handle exection limits
if self.limit: if self.limit == JobExecutionLimit.SINGLE_WAIT:
await self._acquire_exection_limit() await self._acquire_exection_limit()
elif self.limit == JobExecutionLimit.THROTTLE:
time_since_last_call = datetime.now() - self._last_call
if time_since_last_call < self.throttle_period:
return
elif self.limit == JobExecutionLimit.THROTTLE_WAIT:
await self._acquire_exection_limit()
time_since_last_call = datetime.now() - self._last_call
if time_since_last_call < self.throttle_period:
self._release_exception_limits()
return
# Execute Job # Execute Job
try: try:
self._last_call = datetime.now()
return await self._method(*args, **kwargs) return await self._method(*args, **kwargs)
except HassioError as err: except HassioError as err:
raise err raise err
@ -155,12 +180,18 @@ class Job(CoreSysAttributes):
async def _acquire_exection_limit(self) -> None: async def _acquire_exection_limit(self) -> None:
"""Process exection limits.""" """Process exection limits."""
if self.limit not in (
if self.limit == JobExecutionLimit.SINGLE_WAIT: JobExecutionLimit.SINGLE_WAIT,
JobExecutionLimit.THROTTLE_WAIT,
):
return
await self._lock.acquire() await self._lock.acquire()
def _release_exception_limits(self) -> None: def _release_exception_limits(self) -> None:
"""Release possible exception limits.""" """Release possible exception limits."""
if self.limit not in (
if self.limit == JobExecutionLimit.SINGLE_WAIT: JobExecutionLimit.SINGLE_WAIT,
JobExecutionLimit.THROTTLE_WAIT,
):
return
self._lock.release() self._lock.release()

View File

@ -1,10 +1,11 @@
"""Helpers to check core security.""" """Helpers to check core security."""
from contextlib import suppress from contextlib import suppress
from datetime import timedelta
from typing import List, Optional from typing import List, Optional
from ...const import AddonState, CoreState from ...const import AddonState, CoreState
from ...exceptions import PwnedError from ...exceptions import PwnedError
from ...jobs.const import JobCondition from ...jobs.const import JobCondition, JobExecutionLimit
from ...jobs.decorator import Job from ...jobs.decorator import Job
from ...utils.pwned import check_pwned_password from ...utils.pwned import check_pwned_password
from ..const import ContextType, IssueType, SuggestionType from ..const import ContextType, IssueType, SuggestionType
@ -14,7 +15,11 @@ from .base import CheckBase
class CheckAddonPwned(CheckBase): class CheckAddonPwned(CheckBase):
"""CheckAddonPwned class for check.""" """CheckAddonPwned class for check."""
@Job(conditions=[JobCondition.INTERNET_SYSTEM]) @Job(
conditions=[JobCondition.INTERNET_SYSTEM],
limit=JobExecutionLimit.THROTTLE,
throttle_period=timedelta(hours=24),
)
async def run_check(self) -> None: async def run_check(self) -> None:
"""Run check if not affected by issue.""" """Run check if not affected by issue."""
await self.sys_homeassistant.secrets.reload() await self.sys_homeassistant.secrets.reload()

View File

@ -9,6 +9,8 @@ from typing import Optional
import aiohttp import aiohttp
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
from supervisor.jobs.const import JobExecutionLimit
from .const import ( from .const import (
ATTR_AUDIO, ATTR_AUDIO,
ATTR_CHANNEL, ATTR_CHANNEL,
@ -28,7 +30,6 @@ from .const import (
from .coresys import CoreSysAttributes from .coresys import CoreSysAttributes
from .exceptions import UpdaterError, UpdaterJobError from .exceptions import UpdaterError, UpdaterJobError
from .jobs.decorator import Job, JobCondition from .jobs.decorator import Job, JobCondition
from .utils import AsyncThrottle
from .utils.json import JsonConfig from .utils.json import JsonConfig
from .validate import SCHEMA_UPDATER_CONFIG from .validate import SCHEMA_UPDATER_CONFIG
@ -165,10 +166,11 @@ class Updater(JsonConfig, CoreSysAttributes):
"""Set upstream mode.""" """Set upstream mode."""
self._data[ATTR_CHANNEL] = value self._data[ATTR_CHANNEL] = value
@AsyncThrottle(timedelta(seconds=30))
@Job( @Job(
conditions=[JobCondition.INTERNET_SYSTEM], conditions=[JobCondition.INTERNET_SYSTEM],
on_condition=UpdaterJobError, on_condition=UpdaterJobError,
limit=JobExecutionLimit.THROTTLE_WAIT,
throttle_period=timedelta(seconds=30),
) )
async def fetch_data(self): async def fetch_data(self):
"""Fetch current versions from Github. """Fetch current versions from Github.

View File

@ -1,12 +1,11 @@
"""Tools file for Supervisor.""" """Tools file for Supervisor."""
import asyncio import asyncio
from datetime import datetime
from ipaddress import IPv4Address from ipaddress import IPv4Address
import logging import logging
from pathlib import Path from pathlib import Path
import re import re
import socket import socket
from typing import Any, Optional from typing import Any
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -35,65 +34,6 @@ def process_lock(method):
return wrap_api return wrap_api
class AsyncThrottle:
"""A class for throttling the execution of tasks.
Decorator that prevents a function from being called more than once every
time period with blocking.
"""
def __init__(self, delta):
"""Initialize async throttle."""
self.throttle_period = delta
self.time_of_last_call = datetime.min
self.synchronize: Optional[asyncio.Lock] = None
def __call__(self, method):
"""Throttle function."""
async def wrapper(*args, **kwargs):
"""Throttle function wrapper."""
if not self.synchronize:
self.synchronize = asyncio.Lock()
async with self.synchronize:
now = datetime.now()
time_since_last_call = now - self.time_of_last_call
if time_since_last_call > self.throttle_period:
self.time_of_last_call = now
return await method(*args, **kwargs)
return wrapper
class AsyncCallFilter:
"""A class for throttling the execution of tasks, with a filter.
Decorator that prevents a function from being called more than once every
time period.
"""
def __init__(self, delta):
"""Initialize async throttle."""
self.throttle_period = delta
self.time_of_last_call = datetime.min
def __call__(self, method):
"""Throttle function."""
async def wrapper(*args, **kwargs):
"""Throttle function wrapper."""
now = datetime.now()
time_since_last_call = now - self.time_of_last_call
if time_since_last_call > self.throttle_period:
self.time_of_last_call = now
return await method(*args, **kwargs)
return wrapper
def check_port(address: IPv4Address, port: int) -> bool: def check_port(address: IPv4Address, port: int) -> bool:
"""Check if port is mapped.""" """Check if port is mapped."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

View File

@ -1,6 +1,7 @@
"""Test the condition decorators.""" """Test the condition decorators."""
# pylint: disable=protected-access,import-error # pylint: disable=protected-access,import-error
import asyncio import asyncio
from datetime import timedelta
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -284,3 +285,63 @@ async def test_exectution_limit_single_wait(
test = TestClass(coresys) test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)]) await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
async def test_exectution_limit_throttle_wait(
coresys: CoreSys, loop: asyncio.BaseEventLoop
):
"""Test the ignore conditions decorator."""
class TestClass:
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
self.coresys = coresys
self.run = asyncio.Lock()
self.call = 0
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(hours=1))
async def execute(self, sleep: float):
"""Execute the class method."""
assert not self.run.locked()
async with self.run:
await asyncio.sleep(sleep)
self.call += 1
test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
assert test.call == 1
await asyncio.gather(*[test.execute(0.1)])
assert test.call == 1
async def test_exectution_limit_throttle(coresys: CoreSys, loop: asyncio.BaseEventLoop):
"""Test the ignore conditions decorator."""
class TestClass:
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
self.coresys = coresys
self.run = asyncio.Lock()
self.call = 0
@Job(limit=JobExecutionLimit.THROTTLE, throttle_period=timedelta(hours=1))
async def execute(self, sleep: float):
"""Execute the class method."""
assert not self.run.locked()
async with self.run:
await asyncio.sleep(sleep)
self.call += 1
test = TestClass(coresys)
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
assert test.call == 1
await asyncio.gather(*[test.execute(0.1)])
assert test.call == 1

View File

@ -30,7 +30,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password", "supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=True), AsyncMock(return_value=True),
) as mock: ) as mock:
await addon_pwned.run_check() await addon_pwned.run_check.__wrapped__(addon_pwned)
assert not mock.called assert not mock.called
addon.pwned.add("123456") addon.pwned.add("123456")
@ -38,7 +38,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password", "supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=False), AsyncMock(return_value=False),
) as mock: ) as mock:
await addon_pwned.run_check() await addon_pwned.run_check.__wrapped__(addon_pwned)
assert mock.called assert mock.called
assert len(coresys.resolution.issues) == 0 assert len(coresys.resolution.issues) == 0
@ -47,7 +47,7 @@ async def test_check(coresys: CoreSys):
"supervisor.resolution.checks.addon_pwned.check_pwned_password", "supervisor.resolution.checks.addon_pwned.check_pwned_password",
AsyncMock(return_value=True), AsyncMock(return_value=True),
) as mock: ) as mock:
await addon_pwned.run_check() await addon_pwned.run_check.__wrapped__(addon_pwned)
assert mock.called assert mock.called
assert len(coresys.resolution.issues) == 1 assert len(coresys.resolution.issues) == 1