mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-07-15 13:16:29 +00:00
Add throttle to job execution (#2631)
* Add throttle to job execution * fix unittests * Add tests * address comments * add comment * better on __init__ * New text * Simplify logic
This commit is contained in:
parent
78d9c60be5
commit
31f5033dca
@ -7,7 +7,8 @@ from typing import Dict, Optional, Union
|
|||||||
from ruamel.yaml import YAML, YAMLError
|
from ruamel.yaml import YAML, YAMLError
|
||||||
|
|
||||||
from ..coresys import CoreSys, CoreSysAttributes
|
from ..coresys import CoreSys, CoreSysAttributes
|
||||||
from ..utils import AsyncThrottle
|
from ..jobs.const import JobExecutionLimit
|
||||||
|
from ..jobs.decorator import Job
|
||||||
|
|
||||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ class HomeAssistantSecrets(CoreSysAttributes):
|
|||||||
"""Reload secrets."""
|
"""Reload secrets."""
|
||||||
await self._read_secrets()
|
await self._read_secrets()
|
||||||
|
|
||||||
@AsyncThrottle(timedelta(seconds=60))
|
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=60))
|
||||||
async def _read_secrets(self):
|
async def _read_secrets(self):
|
||||||
"""Read secrets.yaml into memory."""
|
"""Read secrets.yaml into memory."""
|
||||||
if not self.path_secrets.exists():
|
if not self.path_secrets.exists():
|
||||||
|
@ -9,7 +9,8 @@ from pulsectl import Pulse, PulseError, PulseIndexError, PulseOperationFailed
|
|||||||
|
|
||||||
from ..coresys import CoreSys, CoreSysAttributes
|
from ..coresys import CoreSys, CoreSysAttributes
|
||||||
from ..exceptions import PulseAudioError
|
from ..exceptions import PulseAudioError
|
||||||
from ..utils import AsyncThrottle
|
from ..jobs.const import JobExecutionLimit
|
||||||
|
from ..jobs.decorator import Job
|
||||||
|
|
||||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -217,7 +218,7 @@ class SoundControl(CoreSysAttributes):
|
|||||||
await self.sys_run_in_executor(_activate_profile)
|
await self.sys_run_in_executor(_activate_profile)
|
||||||
await self.update()
|
await self.update()
|
||||||
|
|
||||||
@AsyncThrottle(timedelta(seconds=10))
|
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(seconds=10))
|
||||||
async def update(self):
|
async def update(self):
|
||||||
"""Update properties over dbus."""
|
"""Update properties over dbus."""
|
||||||
_LOGGER.info("Updating PulseAudio information")
|
_LOGGER.info("Updating PulseAudio information")
|
||||||
|
@ -23,3 +23,5 @@ class JobExecutionLimit(str, Enum):
|
|||||||
"""Job Execution limits."""
|
"""Job Execution limits."""
|
||||||
|
|
||||||
SINGLE_WAIT = "single_wait"
|
SINGLE_WAIT = "single_wait"
|
||||||
|
THROTTLE = "throttle"
|
||||||
|
THROTTLE_WAIT = "throttle_wait"
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Job decorator."""
|
"""Job decorator."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
@ -24,6 +26,7 @@ class Job(CoreSysAttributes):
|
|||||||
cleanup: bool = True,
|
cleanup: bool = True,
|
||||||
on_condition: Optional[JobException] = None,
|
on_condition: Optional[JobException] = None,
|
||||||
limit: Optional[JobExecutionLimit] = None,
|
limit: Optional[JobExecutionLimit] = None,
|
||||||
|
throttle_period: Optional[timedelta] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Job class."""
|
"""Initialize the Job class."""
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -31,8 +34,17 @@ class Job(CoreSysAttributes):
|
|||||||
self.cleanup = cleanup
|
self.cleanup = cleanup
|
||||||
self.on_condition = on_condition
|
self.on_condition = on_condition
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
self.throttle_period = throttle_period
|
||||||
self._lock: Optional[asyncio.Semaphore] = None
|
self._lock: Optional[asyncio.Semaphore] = None
|
||||||
self._method = None
|
self._method = None
|
||||||
|
self._last_call = datetime.min
|
||||||
|
|
||||||
|
# Validate Options
|
||||||
|
if (
|
||||||
|
self.limit in (JobExecutionLimit.THROTTLE, JobExecutionLimit.THROTTLE_WAIT)
|
||||||
|
and self.throttle_period is None
|
||||||
|
):
|
||||||
|
raise RuntimeError("Using Job without a Throttle period!")
|
||||||
|
|
||||||
def _post_init(self, args: Tuple[Any]) -> None:
|
def _post_init(self, args: Tuple[Any]) -> None:
|
||||||
"""Runtime init."""
|
"""Runtime init."""
|
||||||
@ -45,8 +57,9 @@ class Job(CoreSysAttributes):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
if not self.coresys:
|
if not self.coresys:
|
||||||
raise JobException(f"coresys is missing on {self.name}")
|
raise RuntimeError(f"Job on {self.name} need to be an coresys object!")
|
||||||
|
|
||||||
|
# Others
|
||||||
if self._lock is None:
|
if self._lock is None:
|
||||||
self._lock = asyncio.Semaphore()
|
self._lock = asyncio.Semaphore()
|
||||||
|
|
||||||
@ -54,6 +67,7 @@ class Job(CoreSysAttributes):
|
|||||||
"""Call the wrapper logic."""
|
"""Call the wrapper logic."""
|
||||||
self._method = method
|
self._method = method
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
async def wrapper(*args, **kwargs) -> Any:
|
async def wrapper(*args, **kwargs) -> Any:
|
||||||
"""Wrap the method."""
|
"""Wrap the method."""
|
||||||
self._post_init(args)
|
self._post_init(args)
|
||||||
@ -67,11 +81,22 @@ class Job(CoreSysAttributes):
|
|||||||
raise self.on_condition()
|
raise self.on_condition()
|
||||||
|
|
||||||
# Handle exection limits
|
# Handle exection limits
|
||||||
if self.limit:
|
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
||||||
await self._acquire_exection_limit()
|
await self._acquire_exection_limit()
|
||||||
|
elif self.limit == JobExecutionLimit.THROTTLE:
|
||||||
|
time_since_last_call = datetime.now() - self._last_call
|
||||||
|
if time_since_last_call < self.throttle_period:
|
||||||
|
return
|
||||||
|
elif self.limit == JobExecutionLimit.THROTTLE_WAIT:
|
||||||
|
await self._acquire_exection_limit()
|
||||||
|
time_since_last_call = datetime.now() - self._last_call
|
||||||
|
if time_since_last_call < self.throttle_period:
|
||||||
|
self._release_exception_limits()
|
||||||
|
return
|
||||||
|
|
||||||
# Execute Job
|
# Execute Job
|
||||||
try:
|
try:
|
||||||
|
self._last_call = datetime.now()
|
||||||
return await self._method(*args, **kwargs)
|
return await self._method(*args, **kwargs)
|
||||||
except HassioError as err:
|
except HassioError as err:
|
||||||
raise err
|
raise err
|
||||||
@ -155,12 +180,18 @@ class Job(CoreSysAttributes):
|
|||||||
|
|
||||||
async def _acquire_exection_limit(self) -> None:
|
async def _acquire_exection_limit(self) -> None:
|
||||||
"""Process exection limits."""
|
"""Process exection limits."""
|
||||||
|
if self.limit not in (
|
||||||
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
JobExecutionLimit.SINGLE_WAIT,
|
||||||
|
JobExecutionLimit.THROTTLE_WAIT,
|
||||||
|
):
|
||||||
|
return
|
||||||
await self._lock.acquire()
|
await self._lock.acquire()
|
||||||
|
|
||||||
def _release_exception_limits(self) -> None:
|
def _release_exception_limits(self) -> None:
|
||||||
"""Release possible exception limits."""
|
"""Release possible exception limits."""
|
||||||
|
if self.limit not in (
|
||||||
if self.limit == JobExecutionLimit.SINGLE_WAIT:
|
JobExecutionLimit.SINGLE_WAIT,
|
||||||
|
JobExecutionLimit.THROTTLE_WAIT,
|
||||||
|
):
|
||||||
|
return
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
"""Helpers to check core security."""
|
"""Helpers to check core security."""
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from datetime import timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ...const import AddonState, CoreState
|
from ...const import AddonState, CoreState
|
||||||
from ...exceptions import PwnedError
|
from ...exceptions import PwnedError
|
||||||
from ...jobs.const import JobCondition
|
from ...jobs.const import JobCondition, JobExecutionLimit
|
||||||
from ...jobs.decorator import Job
|
from ...jobs.decorator import Job
|
||||||
from ...utils.pwned import check_pwned_password
|
from ...utils.pwned import check_pwned_password
|
||||||
from ..const import ContextType, IssueType, SuggestionType
|
from ..const import ContextType, IssueType, SuggestionType
|
||||||
@ -14,7 +15,11 @@ from .base import CheckBase
|
|||||||
class CheckAddonPwned(CheckBase):
|
class CheckAddonPwned(CheckBase):
|
||||||
"""CheckAddonPwned class for check."""
|
"""CheckAddonPwned class for check."""
|
||||||
|
|
||||||
@Job(conditions=[JobCondition.INTERNET_SYSTEM])
|
@Job(
|
||||||
|
conditions=[JobCondition.INTERNET_SYSTEM],
|
||||||
|
limit=JobExecutionLimit.THROTTLE,
|
||||||
|
throttle_period=timedelta(hours=24),
|
||||||
|
)
|
||||||
async def run_check(self) -> None:
|
async def run_check(self) -> None:
|
||||||
"""Run check if not affected by issue."""
|
"""Run check if not affected by issue."""
|
||||||
await self.sys_homeassistant.secrets.reload()
|
await self.sys_homeassistant.secrets.reload()
|
||||||
|
@ -9,6 +9,8 @@ from typing import Optional
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from awesomeversion import AwesomeVersion
|
from awesomeversion import AwesomeVersion
|
||||||
|
|
||||||
|
from supervisor.jobs.const import JobExecutionLimit
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_AUDIO,
|
ATTR_AUDIO,
|
||||||
ATTR_CHANNEL,
|
ATTR_CHANNEL,
|
||||||
@ -28,7 +30,6 @@ from .const import (
|
|||||||
from .coresys import CoreSysAttributes
|
from .coresys import CoreSysAttributes
|
||||||
from .exceptions import UpdaterError, UpdaterJobError
|
from .exceptions import UpdaterError, UpdaterJobError
|
||||||
from .jobs.decorator import Job, JobCondition
|
from .jobs.decorator import Job, JobCondition
|
||||||
from .utils import AsyncThrottle
|
|
||||||
from .utils.json import JsonConfig
|
from .utils.json import JsonConfig
|
||||||
from .validate import SCHEMA_UPDATER_CONFIG
|
from .validate import SCHEMA_UPDATER_CONFIG
|
||||||
|
|
||||||
@ -165,10 +166,11 @@ class Updater(JsonConfig, CoreSysAttributes):
|
|||||||
"""Set upstream mode."""
|
"""Set upstream mode."""
|
||||||
self._data[ATTR_CHANNEL] = value
|
self._data[ATTR_CHANNEL] = value
|
||||||
|
|
||||||
@AsyncThrottle(timedelta(seconds=30))
|
|
||||||
@Job(
|
@Job(
|
||||||
conditions=[JobCondition.INTERNET_SYSTEM],
|
conditions=[JobCondition.INTERNET_SYSTEM],
|
||||||
on_condition=UpdaterJobError,
|
on_condition=UpdaterJobError,
|
||||||
|
limit=JobExecutionLimit.THROTTLE_WAIT,
|
||||||
|
throttle_period=timedelta(seconds=30),
|
||||||
)
|
)
|
||||||
async def fetch_data(self):
|
async def fetch_data(self):
|
||||||
"""Fetch current versions from Github.
|
"""Fetch current versions from Github.
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
"""Tools file for Supervisor."""
|
"""Tools file for Supervisor."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -35,65 +34,6 @@ def process_lock(method):
|
|||||||
return wrap_api
|
return wrap_api
|
||||||
|
|
||||||
|
|
||||||
class AsyncThrottle:
|
|
||||||
"""A class for throttling the execution of tasks.
|
|
||||||
|
|
||||||
Decorator that prevents a function from being called more than once every
|
|
||||||
time period with blocking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, delta):
|
|
||||||
"""Initialize async throttle."""
|
|
||||||
self.throttle_period = delta
|
|
||||||
self.time_of_last_call = datetime.min
|
|
||||||
self.synchronize: Optional[asyncio.Lock] = None
|
|
||||||
|
|
||||||
def __call__(self, method):
|
|
||||||
"""Throttle function."""
|
|
||||||
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
"""Throttle function wrapper."""
|
|
||||||
if not self.synchronize:
|
|
||||||
self.synchronize = asyncio.Lock()
|
|
||||||
|
|
||||||
async with self.synchronize:
|
|
||||||
now = datetime.now()
|
|
||||||
time_since_last_call = now - self.time_of_last_call
|
|
||||||
|
|
||||||
if time_since_last_call > self.throttle_period:
|
|
||||||
self.time_of_last_call = now
|
|
||||||
return await method(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCallFilter:
|
|
||||||
"""A class for throttling the execution of tasks, with a filter.
|
|
||||||
|
|
||||||
Decorator that prevents a function from being called more than once every
|
|
||||||
time period.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, delta):
|
|
||||||
"""Initialize async throttle."""
|
|
||||||
self.throttle_period = delta
|
|
||||||
self.time_of_last_call = datetime.min
|
|
||||||
|
|
||||||
def __call__(self, method):
|
|
||||||
"""Throttle function."""
|
|
||||||
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
"""Throttle function wrapper."""
|
|
||||||
now = datetime.now()
|
|
||||||
time_since_last_call = now - self.time_of_last_call
|
|
||||||
|
|
||||||
if time_since_last_call > self.throttle_period:
|
|
||||||
self.time_of_last_call = now
|
|
||||||
return await method(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def check_port(address: IPv4Address, port: int) -> bool:
|
def check_port(address: IPv4Address, port: int) -> bool:
|
||||||
"""Check if port is mapped."""
|
"""Check if port is mapped."""
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test the condition decorators."""
|
"""Test the condition decorators."""
|
||||||
# pylint: disable=protected-access,import-error
|
# pylint: disable=protected-access,import-error
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -284,3 +285,63 @@ async def test_exectution_limit_single_wait(
|
|||||||
test = TestClass(coresys)
|
test = TestClass(coresys)
|
||||||
|
|
||||||
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_exectution_limit_throttle_wait(
|
||||||
|
coresys: CoreSys, loop: asyncio.BaseEventLoop
|
||||||
|
):
|
||||||
|
"""Test the ignore conditions decorator."""
|
||||||
|
|
||||||
|
class TestClass:
|
||||||
|
"""Test class."""
|
||||||
|
|
||||||
|
def __init__(self, coresys: CoreSys):
|
||||||
|
"""Initialize the test class."""
|
||||||
|
self.coresys = coresys
|
||||||
|
self.run = asyncio.Lock()
|
||||||
|
self.call = 0
|
||||||
|
|
||||||
|
@Job(limit=JobExecutionLimit.THROTTLE_WAIT, throttle_period=timedelta(hours=1))
|
||||||
|
async def execute(self, sleep: float):
|
||||||
|
"""Execute the class method."""
|
||||||
|
assert not self.run.locked()
|
||||||
|
async with self.run:
|
||||||
|
await asyncio.sleep(sleep)
|
||||||
|
self.call += 1
|
||||||
|
|
||||||
|
test = TestClass(coresys)
|
||||||
|
|
||||||
|
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||||
|
assert test.call == 1
|
||||||
|
|
||||||
|
await asyncio.gather(*[test.execute(0.1)])
|
||||||
|
assert test.call == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_exectution_limit_throttle(coresys: CoreSys, loop: asyncio.BaseEventLoop):
|
||||||
|
"""Test the ignore conditions decorator."""
|
||||||
|
|
||||||
|
class TestClass:
|
||||||
|
"""Test class."""
|
||||||
|
|
||||||
|
def __init__(self, coresys: CoreSys):
|
||||||
|
"""Initialize the test class."""
|
||||||
|
self.coresys = coresys
|
||||||
|
self.run = asyncio.Lock()
|
||||||
|
self.call = 0
|
||||||
|
|
||||||
|
@Job(limit=JobExecutionLimit.THROTTLE, throttle_period=timedelta(hours=1))
|
||||||
|
async def execute(self, sleep: float):
|
||||||
|
"""Execute the class method."""
|
||||||
|
assert not self.run.locked()
|
||||||
|
async with self.run:
|
||||||
|
await asyncio.sleep(sleep)
|
||||||
|
self.call += 1
|
||||||
|
|
||||||
|
test = TestClass(coresys)
|
||||||
|
|
||||||
|
await asyncio.gather(*[test.execute(0.1), test.execute(0.1), test.execute(0.1)])
|
||||||
|
assert test.call == 1
|
||||||
|
|
||||||
|
await asyncio.gather(*[test.execute(0.1)])
|
||||||
|
assert test.call == 1
|
||||||
|
@ -30,7 +30,7 @@ async def test_check(coresys: CoreSys):
|
|||||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||||
AsyncMock(return_value=True),
|
AsyncMock(return_value=True),
|
||||||
) as mock:
|
) as mock:
|
||||||
await addon_pwned.run_check()
|
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||||
assert not mock.called
|
assert not mock.called
|
||||||
|
|
||||||
addon.pwned.add("123456")
|
addon.pwned.add("123456")
|
||||||
@ -38,7 +38,7 @@ async def test_check(coresys: CoreSys):
|
|||||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||||
AsyncMock(return_value=False),
|
AsyncMock(return_value=False),
|
||||||
) as mock:
|
) as mock:
|
||||||
await addon_pwned.run_check()
|
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||||
assert mock.called
|
assert mock.called
|
||||||
|
|
||||||
assert len(coresys.resolution.issues) == 0
|
assert len(coresys.resolution.issues) == 0
|
||||||
@ -47,7 +47,7 @@ async def test_check(coresys: CoreSys):
|
|||||||
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
"supervisor.resolution.checks.addon_pwned.check_pwned_password",
|
||||||
AsyncMock(return_value=True),
|
AsyncMock(return_value=True),
|
||||||
) as mock:
|
) as mock:
|
||||||
await addon_pwned.run_check()
|
await addon_pwned.run_check.__wrapped__(addon_pwned)
|
||||||
assert mock.called
|
assert mock.called
|
||||||
|
|
||||||
assert len(coresys.resolution.issues) == 1
|
assert len(coresys.resolution.issues) == 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user