From 31f5033dca254096bd5d2c594ccdf4359bca5246 Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Thu, 25 Feb 2021 23:29:03 +0100 Subject: [PATCH] Add throttle to job execution (#2631) * Add throttle to job execution * fix unittests * Add tests * address comments * add comment * better on __init__ * New text * Simplify logic --- supervisor/homeassistant/secrets.py | 5 +- supervisor/host/sound.py | 5 +- supervisor/jobs/const.py | 2 + supervisor/jobs/decorator.py | 47 +++++++++++--- supervisor/resolution/checks/addon_pwned.py | 9 ++- supervisor/updater.py | 6 +- supervisor/utils/__init__.py | 62 +------------------ tests/jobs/test_job_decorator.py | 61 ++++++++++++++++++ .../check/test_check_addon_pwned.py | 6 +- 9 files changed, 123 insertions(+), 80 deletions(-) diff --git a/supervisor/homeassistant/secrets.py b/supervisor/homeassistant/secrets.py index 5214233a8..3f2cfd0a0 100644 --- a/supervisor/homeassistant/secrets.py +++ b/supervisor/homeassistant/secrets.py @@ -7,7 +7,8 @@ from typing import Dict, Optional, Union from ruamel.yaml import YAML, YAMLError from ..coresys import CoreSys, CoreSysAttributes -from ..utils import AsyncThrottle +from ..jobs.const import JobExecutionLimit +from ..jobs.decorator import Job _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -40,7 +41,7 @@ class HomeAssistantSecrets(CoreSysAttributes): """Reload secrets.""" await self._read_secrets() - @AsyncThrottle(timedelta(seconds=60)) + @Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=60)) async def _read_secrets(self): """Read secrets.yaml into memory.""" if not self.path_secrets.exists(): diff --git a/supervisor/host/sound.py b/supervisor/host/sound.py index 71cb2806c..a534765b9 100644 --- a/supervisor/host/sound.py +++ b/supervisor/host/sound.py @@ -9,7 +9,8 @@ from pulsectl import Pulse, PulseError, PulseIndexError, PulseOperationFailed from ..coresys import CoreSys, CoreSysAttributes from ..exceptions import PulseAudioError -from ..utils import AsyncThrottle +from ..jobs.const import JobExecutionLimit +from ..jobs.decorator import Job _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -217,7 +218,7 @@ class SoundControl(CoreSysAttributes): await self.sys_run_in_executor(_activate_profile) await self.update() - @AsyncThrottle(timedelta(seconds=10)) + @Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=10)) async def update(self): """Update properties over dbus.""" _LOGGER.info("Updating PulseAudio information") diff --git a/supervisor/jobs/const.py b/supervisor/jobs/const.py index 946d9d538..ae7d41d65 100644 --- a/supervisor/jobs/const.py +++ b/supervisor/jobs/const.py @@ -23,3 +23,5 @@ class JobExecutionLimit(str, Enum): """Job Execution limits.""" SINGLE_WAIT = "single_wait" + THROTTLE = "throttle" + THROTTLE_WAIT = "throttle_wait" diff --git a/supervisor/jobs/decorator.py b/supervisor/jobs/decorator.py index d697c927a..4c17592f0 100644 --- a/supervisor/jobs/decorator.py +++ b/supervisor/jobs/decorator.py @@ -1,5 +1,7 @@ """Job decorator.""" import asyncio +from datetime import datetime, timedelta +from functools import wraps import logging from typing import Any, List, Optional, Tuple @@ -24,6 +26,7 @@ class Job(CoreSysAttributes): cleanup: bool = True, on_condition: Optional[JobException] = None, limit: Optional[JobExecutionLimit] = None, + throttle_period: Optional[timedelta] = None, ): """Initialize the Job class.""" self.name = name @@ -31,8 +34,17 @@ class Job(CoreSysAttributes): self.cleanup = cleanup self.on_condition = on_condition self.limit = limit + self.throttle_period = throttle_period self._lock: Optional[asyncio.Semaphore] = None self._method = None + self._last_call = datetime.min + + # Validate Options + if ( + self.limit in (JobExecutionLimit.THROTTLE, JobExecutionLimit.THROTTLE_WAIT) + and self.throttle_period is None + ): + raise RuntimeError("Using Job without a Throttle period!") def _post_init(self, args: Tuple[Any]) -> None: """Runtime init.""" @@ -45,8 +57,9 @@ class Job(CoreSysAttributes): except AttributeError: pass if not self.coresys: - raise JobException(f"coresys is missing on {self.name}") + raise RuntimeError(f"Job on {self.name} need to be an coresys object!") + # Others if self._lock is None: self._lock = asyncio.Semaphore() @@ -54,6 +67,7 @@ class Job(CoreSysAttributes): """Call the wrapper logic.""" self._method = method + @wraps(method) async def wrapper(*args, **kwargs) -> Any: """Wrap the method.""" self._post_init(args) @@ -67,11 +81,22 @@ class Job(CoreSysAttributes): raise self.on_condition() # Handle exection limits - if self.limit: + if self.limit == JobExecutionLimit.SINGLE_WAIT: await self._acquire_exection_limit() + elif self.limit == JobExecutionLimit.THROTTLE: + time_since_last_call = datetime.now() - self._last_call + if time_since_last_call < self.throttle_period: + return + elif self.limit == JobExecutionLimit.THROTTLE_WAIT: + await self._acquire_exection_limit() + time_since_last_call = datetime.now() - self._last_call + if time_since_last_call < self.throttle_period: + self._release_exception_limits() + return # Execute Job try: + self._last_call = datetime.now() return await self._method(*args, **kwargs) except HassioError as err: raise err @@ -155,12 +180,18 @@ class Job(CoreSysAttributes): async def _acquire_exection_limit(self) -> None: """Process exection limits.""" - - if self.limit == JobExecutionLimit.SINGLE_WAIT: - await self._lock.acquire() + if self.limit not in ( + JobExecutionLimit.SINGLE_WAIT, + JobExecutionLimit.THROTTLE_WAIT, + ): + return + await self._lock.acquire() def _release_exception_limits(self) -> None: """Release possible exception limits.""" - - if self.limit == JobExecutionLimit.SINGLE_WAIT: - self._lock.release() + if self.limit not in ( + JobExecutionLimit.SINGLE_WAIT, + JobExecutionLimit.THROTTLE_WAIT, + ): + return + self._lock.release() diff --git a/supervisor/resolution/checks/addon_pwned.py b/supervisor/resolution/checks/addon_pwned.py index b46bd3085..3c323edf0 100644 --- a/supervisor/resolution/checks/addon_pwned.py +++ b/supervisor/resolution/checks/addon_pwned.py @@ -1,10 +1,11 @@ """Helpers to check core security.""" from contextlib import suppress +from datetime import timedelta from typing import List, Optional from ...const import AddonState, CoreState from ...exceptions import PwnedError -from ...jobs.const import JobCondition +from ...jobs.const import JobCondition, JobExecutionLimit from ...jobs.decorator import Job from ...utils.pwned import check_pwned_password from ..const import ContextType, IssueType, SuggestionType @@ -14,7 +15,11 @@ from .base import CheckBase class CheckAddonPwned(CheckBase): """CheckAddonPwned class for check.""" - @Job(conditions=[JobCondition.INTERNET_SYSTEM]) + @Job( + conditions=[JobCondition.INTERNET_SYSTEM], + limit=JobExecutionLimit.THROTTLE, + throttle_period=timedelta(hours=24), + ) async def run_check(self) -> None: """Run check if not affected by issue.""" await self.sys_homeassistant.secrets.reload() diff --git a/supervisor/updater.py b/supervisor/updater.py index bea8267df..950ac3206 100644 --- a/supervisor/updater.py +++ b/supervisor/updater.py @@ -9,6 +9,8 @@ from typing import Optional import aiohttp from awesomeversion import AwesomeVersion +from supervisor.jobs.const import JobExecutionLimit + from .const import ( ATTR_AUDIO, ATTR_CHANNEL, @@ -28,7 +30,6 @@ from .const import ( from .coresys import CoreSysAttributes from .exceptions import UpdaterError, UpdaterJobError from .jobs.decorator import Job, JobCondition -from .utils import AsyncThrottle from .utils.json import JsonConfig from .validate import SCHEMA_UPDATER_CONFIG @@ -165,10 +166,11 @@ class Updater(JsonConfig, CoreSysAttributes): """Set upstream mode.""" self._data[ATTR_CHANNEL] = value - @AsyncThrottle(timedelta(seconds=30)) @Job( conditions=[JobCondition.INTERNET_SYSTEM], on_condition=UpdaterJobError, + limit=JobExecutionLimit.THROTTLE_WAIT, + throttle_period=timedelta(seconds=30), ) async def fetch_data(self): """Fetch current versions from Github. diff --git a/supervisor/utils/__init__.py b/supervisor/utils/__init__.py index c27e0feb9..6b5907a6e 100644 --- a/supervisor/utils/__init__.py +++ b/supervisor/utils/__init__.py @@ -1,12 +1,11 @@ """Tools file for Supervisor.""" import asyncio -from datetime import datetime from ipaddress import IPv4Address import logging from pathlib import Path import re import socket -from typing import Any, Optional +from typing import Any _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -35,65 +34,6 @@ def process_lock(method): return wrap_api -class AsyncThrottle: - """A class for throttling the execution of tasks. - - Decorator that prevents a function from being called more than once every - time period with blocking. - """ - - def __init__(self, delta): - """Initialize async throttle.""" - self.throttle_period = delta - self.time_of_last_call = datetime.min - self.synchronize: Optional[asyncio.Lock] = None - - def __call__(self, method): - """Throttle function.""" - - async def wrapper(*args, **kwargs): - """Throttle function wrapper.""" - if not self.synchronize: - self.synchronize = asyncio.Lock() - - async with self.synchronize: - now = datetime.now() - time_since_last_call = now - self.time_of_last_call - - if time_since_last_call > self.throttle_period: - self.time_of_last_call = now - return await method(*args, **kwargs) - - return wrapper - - -class AsyncCallFilter: - """A class for throttling the execution of tasks, with a filter. - - Decorator that prevents a function from being called more than once every - time period. - """ - - def __init__(self, delta): - """Initialize async throttle.""" - self.throttle_period = delta - self.time_of_last_call = datetime.min - - def __call__(self, method): - """Throttle function.""" - - async def wrapper(*args, **kwargs): - """Throttle function wrapper.""" - now = datetime.now() - time_since_last_call = now - self.time_of_last_call - - if time_since_last_call > self.throttle_period: - self.time_of_last_call = now - return await method(*args, **kwargs) - - return wrapper - - def check_port(address: IPv4Address, port: int) -> bool: """Check if port is mapped.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/tests/jobs/test_job_decorator.py b/tests/jobs/test_job_decorator.py index 597c25377..d5f5e99ad 100644 --- a/tests/jobs/test_job_decorator.py +++ b/tests/jobs/test_job_decorator.py @@ -1,6 +1,7 @@ """Test the condition decorators.""" # pylint: disable=protected-access,import-error import asyncio +from datetime import timedelta from unittest.mock import patch import pytest @@ -284,3 +285,63 @@ async def test_exectution_limit_single_wait( test = TestClass(coresys) await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)]) + + +async def test_exectution_limit_throttle_wait( + coresys: CoreSys, loop: asyncio.BaseEventLoop +): + """Test the ignore conditions decorator.""" + + class TestClass: + """Test class.""" + + def __init__(self, coresys: CoreSys): + """Initialize the test class.""" + self.coresys = coresys + self.run = asyncio.Lock() + self.call = 0 + + @Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(hours=1)) + async def execute(self, sleep: float): + """Execute the class method.""" + assert not self.run.locked() + async with self.run: + await asyncio.sleep(sleep) + self.call += 1 + + test = TestClass(coresys) + + await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)]) + assert test.call == 1 + + await asyncio.gather(*[test.execute(0.1)]) + assert test.call == 1 + + +async def test_exectution_limit_throttle(coresys: CoreSys, loop: asyncio.BaseEventLoop): + """Test the ignore conditions decorator.""" + + class TestClass: + """Test class.""" + + def __init__(self, coresys: CoreSys): + """Initialize the test class.""" + self.coresys = coresys + self.run = asyncio.Lock() + self.call = 0 + + @Job(limit=JobExecutionLimit.THROTTLE, throttle_period=timedelta(hours=1)) + async def execute(self, sleep: float): + """Execute the class method.""" + assert not self.run.locked() + async with self.run: + await asyncio.sleep(sleep) + self.call += 1 + + test = TestClass(coresys) + + await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)]) + assert test.call == 1 + + await asyncio.gather(*[test.execute(0.1)]) + assert test.call == 1 diff --git a/tests/resolution/check/test_check_addon_pwned.py b/tests/resolution/check/test_check_addon_pwned.py index a76b5ba9e..3c6844d93 100644 --- a/tests/resolution/check/test_check_addon_pwned.py +++ b/tests/resolution/check/test_check_addon_pwned.py @@ -30,7 +30,7 @@ async def test_check(coresys: CoreSys): "supervisor.resolution.checks.addon_pwned.check_pwned_password", AsyncMock(return_value=True), ) as mock: - await addon_pwned.run_check() + await addon_pwned.run_check.__wrapped__(addon_pwned) assert not mock.called addon.pwned.add("123456") @@ -38,7 +38,7 @@ async def test_check(coresys: CoreSys): "supervisor.resolution.checks.addon_pwned.check_pwned_password", AsyncMock(return_value=False), ) as mock: - await addon_pwned.run_check() + await addon_pwned.run_check.__wrapped__(addon_pwned) assert mock.called assert len(coresys.resolution.issues) == 0 @@ -47,7 +47,7 @@ async def test_check(coresys: CoreSys): "supervisor.resolution.checks.addon_pwned.check_pwned_password", AsyncMock(return_value=True), ) as mock: - await addon_pwned.run_check() + await addon_pwned.run_check.__wrapped__(addon_pwned) assert mock.called assert len(coresys.resolution.issues) == 1