From 3b093200e3550bbba97a8f45c6357257c3757cbc Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 7 Aug 2025 00:14:58 +0200 Subject: [PATCH] Improve JobGroup locking with external ownership tracking (#6074) * Use context manager for Job concurrency control * Allow to release lock outside of Job running context * Improve JobGroup locking with external ownership tracking Track lock ownership by job UUID instead of execution context. This allows external lock release via job parameter. * Fix acquire lock in nested Jobs * Simplify nested lock tracking * Simplify Job group lock acquisition logic * Simplify by using helper methods * Allow throttling with group concurrency * Use Lock instead of Semaphore for job concurrency control Use the same synchronization primitive (Lock) for job concurrency control as used in job groups. * Go back to lock ownership tracking with references * Drop unused property `active_job_id` * Drop unused property `can_acquire` * Replace assert with cast --- supervisor/jobs/decorator.py | 182 +++++++++++++++---------------- supervisor/jobs/job_group.py | 54 +++++---- tests/jobs/test_job_decorator.py | 56 ++++++++++ 3 files changed, 176 insertions(+), 116 deletions(-) diff --git a/supervisor/jobs/decorator.py b/supervisor/jobs/decorator.py index 13f0ed4bb..f66e0c648 100644 --- a/supervisor/jobs/decorator.py +++ b/supervisor/jobs/decorator.py @@ -1,8 +1,8 @@ """Job decorator.""" import asyncio -from collections.abc import Awaitable, Callable -from contextlib import suppress +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager, suppress from datetime import datetime, timedelta from functools import wraps import logging @@ -71,7 +71,7 @@ class Job(CoreSysAttributes): self.on_condition = on_condition self._throttle_period = throttle_period self._throttle_max_calls = throttle_max_calls - self._lock: asyncio.Semaphore | None = None + self._lock: asyncio.Lock | None = None self._last_call: dict[str | None, datetime] = {} self._rate_limited_calls: dict[str | None, list[datetime]] | None = None self._internal = internal @@ -82,43 +82,42 @@ class Job(CoreSysAttributes): # Validate Options self._validate_parameters() + def _is_group_concurrency(self) -> bool: + """Check if this job uses group-level concurrency.""" + return self.concurrency in ( + JobConcurrency.GROUP_REJECT, + JobConcurrency.GROUP_QUEUE, + ) + + def _is_group_throttle(self) -> bool: + """Check if this job uses group-level throttling.""" + return self.throttle in ( + JobThrottle.GROUP_THROTTLE, + JobThrottle.GROUP_RATE_LIMIT, + ) + + def _is_rate_limit_throttle(self) -> bool: + """Check if this job uses rate limiting (job or group level).""" + return self.throttle in ( + JobThrottle.RATE_LIMIT, + JobThrottle.GROUP_RATE_LIMIT, + ) + def _validate_parameters(self) -> None: """Validate job parameters.""" # Validate throttle parameters - if ( - self.throttle - in ( - JobThrottle.THROTTLE, - JobThrottle.GROUP_THROTTLE, - JobThrottle.RATE_LIMIT, - JobThrottle.GROUP_RATE_LIMIT, - ) - and self._throttle_period is None - ): + if self.throttle is not None and self._throttle_period is None: raise RuntimeError( f"Job {self.name} is using throttle {self.throttle} without a throttle period!" ) - if self.throttle in ( - JobThrottle.RATE_LIMIT, - JobThrottle.GROUP_RATE_LIMIT, - ): + if self._is_rate_limit_throttle(): if self._throttle_max_calls is None: raise RuntimeError( f"Job {self.name} is using throttle {self.throttle} without throttle max calls!" ) self._rate_limited_calls = {} - if self.throttle is not None and self.concurrency in ( - JobConcurrency.GROUP_REJECT, - JobConcurrency.GROUP_QUEUE, - ): - # We cannot release group locks when Job is not running (e.g. throttled) - # which makes these combinations impossible to use currently. - raise RuntimeError( - f"Job {self.name} is using throttling ({self.throttle}) with group concurrency ({self.concurrency}), which is not allowed!" - ) - @property def throttle_max_calls(self) -> int: """Return max calls for throttle.""" @@ -127,9 +126,9 @@ class Job(CoreSysAttributes): return self._throttle_max_calls @property - def lock(self) -> asyncio.Semaphore: + def lock(self) -> asyncio.Lock: """Return lock for limits.""" - # asyncio.Semaphore objects must be created in event loop + # asyncio.Lock objects must be created in event loop # Since this is sync code it is not safe to create if missing here if not self._lock: raise RuntimeError("Lock has not been created yet!") @@ -201,7 +200,7 @@ class Job(CoreSysAttributes): # Setup lock for limits if self._lock is None: - self._lock = asyncio.Semaphore() + self._lock = asyncio.Lock() # Job groups job_group: JobGroup | None = None @@ -211,18 +210,12 @@ class Job(CoreSysAttributes): # Check for group-based parameters if not job_group: - if self.concurrency in ( - JobConcurrency.GROUP_REJECT, - JobConcurrency.GROUP_QUEUE, - ): + if self._is_group_concurrency(): raise RuntimeError( f"Job {self.name} uses group concurrency ({self.concurrency}) but is not on a JobGroup! " f"The class must inherit from JobGroup to use GROUP_REJECT or GROUP_QUEUE." ) from None - if self.throttle in ( - JobThrottle.GROUP_THROTTLE, - JobThrottle.GROUP_RATE_LIMIT, - ): + if self._is_group_throttle(): raise RuntimeError( f"Job {self.name} uses group throttling ({self.throttle}) but is not on a JobGroup! " f"The class must inherit from JobGroup to use GROUP_THROTTLE or GROUP_RATE_LIMIT." @@ -279,41 +272,34 @@ class Job(CoreSysAttributes): except JobConditionException as err: return self._handle_job_condition_exception(err) - # Handle execution limits - await self._handle_concurrency_control(job_group, job) - try: + # Handle execution limits using context manager + async with self._concurrency_control(job_group, job): if not await self._handle_throttling(group_name): - self._release_concurrency_control(job_group) return # Job was throttled, exit early - except Exception: - self._release_concurrency_control(job_group) - raise - # Execute Job - with job.start(): - try: - self.set_last_call(datetime.now(), group_name) - if self._rate_limited_calls is not None: - self.add_rate_limited_call( - self.last_call(group_name), group_name - ) + # Execute Job + with job.start(): + try: + self.set_last_call(datetime.now(), group_name) + if self._rate_limited_calls is not None: + self.add_rate_limited_call( + self.last_call(group_name), group_name + ) - return await method(obj, *args, **kwargs) + return await method(obj, *args, **kwargs) - # If a method has a conditional JobCondition, they must check it in the method - # These should be handled like normal JobConditions as much as possible - except JobConditionException as err: - return self._handle_job_condition_exception(err) - except HassioError as err: - job.capture_error(err) - raise err - except Exception as err: - _LOGGER.exception("Unhandled exception: %s", err) - job.capture_error() - await async_capture_exception(err) - raise JobException() from err - finally: - self._release_concurrency_control(job_group) + # If a method has a conditional JobCondition, they must check it in the method + # These should be handled like normal JobConditions as much as possible + except JobConditionException as err: + return self._handle_job_condition_exception(err) + except HassioError as err: + job.capture_error(err) + raise err + except Exception as err: + _LOGGER.exception("Unhandled exception: %s", err) + job.capture_error() + await async_capture_exception(err) + raise JobException() from err # Jobs that weren't started are always cleaned up. Also clean up done jobs if required finally: @@ -455,26 +441,34 @@ class Job(CoreSysAttributes): f"'{method_name}' blocked from execution, mounting not supported on system" ) - def _release_concurrency_control(self, job_group: JobGroup | None) -> None: + def _release_concurrency_control( + self, job_group: JobGroup | None, job: SupervisorJob + ) -> None: """Release concurrency control locks.""" - if self.concurrency == JobConcurrency.REJECT: + if self._is_group_concurrency(): + # Group-level concurrency: delegate to job group + cast(JobGroup, job_group).release(job) + elif self.concurrency in (JobConcurrency.REJECT, JobConcurrency.QUEUE): + # Job-level concurrency: use semaphore if self.lock.locked(): self.lock.release() - elif self.concurrency == JobConcurrency.QUEUE: - if self.lock.locked(): - self.lock.release() - elif self.concurrency in ( - JobConcurrency.GROUP_REJECT, - JobConcurrency.GROUP_QUEUE, - ): - if job_group and job_group.has_lock: - job_group.release() async def _handle_concurrency_control( self, job_group: JobGroup | None, job: SupervisorJob ) -> None: """Handle concurrency control limits.""" - if self.concurrency == JobConcurrency.REJECT: + if self._is_group_concurrency(): + # Group-level concurrency: delegate to job group + try: + await cast(JobGroup, job_group).acquire( + job, wait=self.concurrency == JobConcurrency.GROUP_QUEUE + ) + except JobGroupExecutionLimitExceeded as err: + if self.on_condition: + raise self.on_condition(str(err)) from err + raise err + elif self.concurrency == JobConcurrency.REJECT: + # Job-level reject: fail if lock is taken if self.lock.locked(): on_condition = ( JobException if self.on_condition is None else self.on_condition @@ -482,21 +476,19 @@ class Job(CoreSysAttributes): raise on_condition("Another job is running") await self.lock.acquire() elif self.concurrency == JobConcurrency.QUEUE: + # Job-level queue: wait for lock await self.lock.acquire() - elif self.concurrency == JobConcurrency.GROUP_REJECT: - try: - await cast(JobGroup, job_group).acquire(job, wait=False) - except JobGroupExecutionLimitExceeded as err: - if self.on_condition: - raise self.on_condition(str(err)) from err - raise err - elif self.concurrency == JobConcurrency.GROUP_QUEUE: - try: - await cast(JobGroup, job_group).acquire(job, wait=True) - except JobGroupExecutionLimitExceeded as err: - if self.on_condition: - raise self.on_condition(str(err)) from err - raise err + + @asynccontextmanager + async def _concurrency_control( + self, job_group: JobGroup | None, job: SupervisorJob + ) -> AsyncIterator[None]: + """Context manager for concurrency control that ensures locks are always released.""" + await self._handle_concurrency_control(job_group, job) + try: + yield + finally: + self._release_concurrency_control(job_group, job) async def _handle_throttling(self, group_name: str | None) -> bool: """Handle throttling limits. Returns True if job should continue, False if throttled.""" @@ -506,7 +498,7 @@ class Job(CoreSysAttributes): if time_since_last_call < throttle_period: # Always return False when throttled (skip execution) return False - elif self.throttle in (JobThrottle.RATE_LIMIT, JobThrottle.GROUP_RATE_LIMIT): + elif self._is_rate_limit_throttle(): # Only reprocess array when necessary (at limit) if len(self.rate_limited_calls(group_name)) >= self.throttle_max_calls: self.set_rate_limited_calls( diff --git a/supervisor/jobs/job_group.py b/supervisor/jobs/job_group.py index 4dece17e3..bd870483f 100644 --- a/supervisor/jobs/job_group.py +++ b/supervisor/jobs/job_group.py @@ -23,14 +23,14 @@ class JobGroup(CoreSysAttributes): self.coresys: CoreSys = coresys self._group_name: str = group_name self._lock: Lock = Lock() - self._active_job: SupervisorJob | None = None + self._lock_owner: SupervisorJob | None = None self._parent_jobs: list[SupervisorJob] = [] self._job_reference: str | None = job_reference @property def active_job(self) -> SupervisorJob | None: """Get active job ID.""" - return self._active_job + return self._lock_owner @property def group_name(self) -> str: @@ -40,42 +40,54 @@ class JobGroup(CoreSysAttributes): @property def has_lock(self) -> bool: """Return true if current task has the lock on this job group.""" - return ( - self.active_job is not None - and self.sys_jobs.is_job - and self.active_job == self.sys_jobs.current - ) + if not self._lock_owner: + return False + + if not self.sys_jobs.is_job: + return False + + current_job = self.sys_jobs.current + # Check if current job owns lock directly + return current_job == self._lock_owner @property def job_reference(self) -> str | None: """Return value to use as reference for all jobs created for this job group.""" return self._job_reference + def is_locked_by(self, job: SupervisorJob) -> bool: + """Check if this specific job owns the lock.""" + return self._lock_owner == job + async def acquire(self, job: SupervisorJob, wait: bool = False) -> None: """Acquire the lock for the group for the specified job.""" + # If we already own the lock (nested call or same job chain), just update parent stack + if self.has_lock: + if self._lock_owner: + self._parent_jobs.append(self._lock_owner) + self._lock_owner = job + return + # If there's another job running and we're not waiting, raise - if self.active_job and not self.has_lock and not wait: + if self._lock_owner and not wait: raise JobGroupExecutionLimitExceeded( f"Another job is running for job group {self.group_name}" ) - # Else if we don't have the lock, acquire it - if not self.has_lock: - await self._lock.acquire() + # Acquire the actual asyncio lock + await self._lock.acquire() - # Store the job ID we acquired the lock for - if self.active_job: - self._parent_jobs.append(self.active_job) + # Set ownership + self._lock_owner = job - self._active_job = job - - def release(self) -> None: + def release(self, job: SupervisorJob) -> None: """Release the lock for the group or return it to parent.""" - if not self.has_lock: - raise JobException("Cannot release as caller does not own lock") + if not self.is_locked_by(job): + raise JobException(f"Job {job.uuid} does not own the lock") + # Return to parent job if exists if self._parent_jobs: - self._active_job = self._parent_jobs.pop() + self._lock_owner = self._parent_jobs.pop() else: - self._active_job = None + self._lock_owner = None self._lock.release() diff --git a/tests/jobs/test_job_decorator.py b/tests/jobs/test_job_decorator.py index eacc8a893..48f59d3da 100644 --- a/tests/jobs/test_job_decorator.py +++ b/tests/jobs/test_job_decorator.py @@ -1300,3 +1300,59 @@ async def test_concurency_reject_and_rate_limit( await test.execute() assert test.call == 2 + + +async def test_group_concurrency_with_group_throttling(coresys: CoreSys): + """Test that group concurrency works with group throttling.""" + + class TestClass(JobGroup): + """Test class.""" + + def __init__(self, coresys: CoreSys): + """Initialize the test class.""" + super().__init__(coresys, "TestGroupConcurrencyThrottle") + self.call_count = 0 + self.nested_call_count = 0 + + @Job( + name="test_group_concurrency_throttle_main", + concurrency=JobConcurrency.GROUP_QUEUE, + throttle=JobThrottle.GROUP_THROTTLE, + throttle_period=timedelta(milliseconds=50), + on_condition=JobException, + ) + async def main_method(self) -> None: + """Make nested call with group concurrency and throttling.""" + self.call_count += 1 + # Test nested call to ensure lock handling works + await self.nested_method() + + @Job( + name="test_group_concurrency_throttle_nested", + concurrency=JobConcurrency.GROUP_QUEUE, + throttle=JobThrottle.GROUP_THROTTLE, + throttle_period=timedelta(milliseconds=50), + on_condition=JobException, + ) + async def nested_method(self) -> None: + """Nested method with group concurrency and throttling.""" + self.nested_call_count += 1 + + test = TestClass(coresys) + + # First call should work + await test.main_method() + assert test.call_count == 1 + assert test.nested_call_count == 1 + + # Second call should be throttled (not execute due to throttle period) + await test.main_method() + assert test.call_count == 1 # Still 1, throttled + assert test.nested_call_count == 1 # Still 1, throttled + + # Wait for throttle period to pass and try again + with time_machine.travel(utcnow() + timedelta(milliseconds=60)): + await test.main_method() + + assert test.call_count == 2 # Should execute now + assert test.nested_call_count == 2 # Nested call should also execute