Improve JobGroup locking with external ownership tracking (#6074)

* Use context manager for Job concurrency control

* Allow to release lock outside of Job running context

* Improve JobGroup locking with external ownership tracking

Track lock ownership by job UUID instead of execution context. This
allows external lock release via job parameter.

* Fix acquire lock in nested Jobs

* Simplify nested lock tracking

* Simplify Job group lock acquisition logic

* Simplify by using helper methods

* Allow throttling with group concurrency

* Use Lock instead of Semaphore for job concurrency control

Use the same synchronization primitive (Lock) for job concurrency
control as used in job groups.

* Go back to lock ownership tracking with references

* Drop unused property `active_job_id`

* Drop unused property `can_acquire`

* Replace assert with cast
This commit is contained in:
Stefan Agner 2025-08-07 00:14:58 +02:00 committed by GitHub
parent 15ba1a3c94
commit 3b093200e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 176 additions and 116 deletions

View File

@ -1,8 +1,8 @@
"""Job decorator.""" """Job decorator."""
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import suppress from contextlib import asynccontextmanager, suppress
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
import logging import logging
@ -71,7 +71,7 @@ class Job(CoreSysAttributes):
self.on_condition = on_condition self.on_condition = on_condition
self._throttle_period = throttle_period self._throttle_period = throttle_period
self._throttle_max_calls = throttle_max_calls self._throttle_max_calls = throttle_max_calls
self._lock: asyncio.Semaphore | None = None self._lock: asyncio.Lock | None = None
self._last_call: dict[str | None, datetime] = {} self._last_call: dict[str | None, datetime] = {}
self._rate_limited_calls: dict[str | None, list[datetime]] | None = None self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
self._internal = internal self._internal = internal
@ -82,43 +82,42 @@ class Job(CoreSysAttributes):
# Validate Options # Validate Options
self._validate_parameters() self._validate_parameters()
def _is_group_concurrency(self) -> bool:
"""Check if this job uses group-level concurrency."""
return self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
)
def _is_group_throttle(self) -> bool:
"""Check if this job uses group-level throttling."""
return self.throttle in (
JobThrottle.GROUP_THROTTLE,
JobThrottle.GROUP_RATE_LIMIT,
)
def _is_rate_limit_throttle(self) -> bool:
"""Check if this job uses rate limiting (job or group level)."""
return self.throttle in (
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
)
def _validate_parameters(self) -> None: def _validate_parameters(self) -> None:
"""Validate job parameters.""" """Validate job parameters."""
# Validate throttle parameters # Validate throttle parameters
if ( if self.throttle is not None and self._throttle_period is None:
self.throttle
in (
JobThrottle.THROTTLE,
JobThrottle.GROUP_THROTTLE,
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
)
and self._throttle_period is None
):
raise RuntimeError( raise RuntimeError(
f"Job {self.name} is using throttle {self.throttle} without a throttle period!" f"Job {self.name} is using throttle {self.throttle} without a throttle period!"
) )
if self.throttle in ( if self._is_rate_limit_throttle():
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
):
if self._throttle_max_calls is None: if self._throttle_max_calls is None:
raise RuntimeError( raise RuntimeError(
f"Job {self.name} is using throttle {self.throttle} without throttle max calls!" f"Job {self.name} is using throttle {self.throttle} without throttle max calls!"
) )
self._rate_limited_calls = {} self._rate_limited_calls = {}
if self.throttle is not None and self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
# We cannot release group locks when Job is not running (e.g. throttled)
# which makes these combinations impossible to use currently.
raise RuntimeError(
f"Job {self.name} is using throttling ({self.throttle}) with group concurrency ({self.concurrency}), which is not allowed!"
)
@property @property
def throttle_max_calls(self) -> int: def throttle_max_calls(self) -> int:
"""Return max calls for throttle.""" """Return max calls for throttle."""
@ -127,9 +126,9 @@ class Job(CoreSysAttributes):
return self._throttle_max_calls return self._throttle_max_calls
@property @property
def lock(self) -> asyncio.Semaphore: def lock(self) -> asyncio.Lock:
"""Return lock for limits.""" """Return lock for limits."""
# asyncio.Semaphore objects must be created in event loop # asyncio.Lock objects must be created in event loop
# Since this is sync code it is not safe to create if missing here # Since this is sync code it is not safe to create if missing here
if not self._lock: if not self._lock:
raise RuntimeError("Lock has not been created yet!") raise RuntimeError("Lock has not been created yet!")
@ -201,7 +200,7 @@ class Job(CoreSysAttributes):
# Setup lock for limits # Setup lock for limits
if self._lock is None: if self._lock is None:
self._lock = asyncio.Semaphore() self._lock = asyncio.Lock()
# Job groups # Job groups
job_group: JobGroup | None = None job_group: JobGroup | None = None
@ -211,18 +210,12 @@ class Job(CoreSysAttributes):
# Check for group-based parameters # Check for group-based parameters
if not job_group: if not job_group:
if self.concurrency in ( if self._is_group_concurrency():
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
raise RuntimeError( raise RuntimeError(
f"Job {self.name} uses group concurrency ({self.concurrency}) but is not on a JobGroup! " f"Job {self.name} uses group concurrency ({self.concurrency}) but is not on a JobGroup! "
f"The class must inherit from JobGroup to use GROUP_REJECT or GROUP_QUEUE." f"The class must inherit from JobGroup to use GROUP_REJECT or GROUP_QUEUE."
) from None ) from None
if self.throttle in ( if self._is_group_throttle():
JobThrottle.GROUP_THROTTLE,
JobThrottle.GROUP_RATE_LIMIT,
):
raise RuntimeError( raise RuntimeError(
f"Job {self.name} uses group throttling ({self.throttle}) but is not on a JobGroup! " f"Job {self.name} uses group throttling ({self.throttle}) but is not on a JobGroup! "
f"The class must inherit from JobGroup to use GROUP_THROTTLE or GROUP_RATE_LIMIT." f"The class must inherit from JobGroup to use GROUP_THROTTLE or GROUP_RATE_LIMIT."
@ -279,41 +272,34 @@ class Job(CoreSysAttributes):
except JobConditionException as err: except JobConditionException as err:
return self._handle_job_condition_exception(err) return self._handle_job_condition_exception(err)
# Handle execution limits # Handle execution limits using context manager
await self._handle_concurrency_control(job_group, job) async with self._concurrency_control(job_group, job):
try:
if not await self._handle_throttling(group_name): if not await self._handle_throttling(group_name):
self._release_concurrency_control(job_group)
return # Job was throttled, exit early return # Job was throttled, exit early
except Exception:
self._release_concurrency_control(job_group)
raise
# Execute Job # Execute Job
with job.start(): with job.start():
try: try:
self.set_last_call(datetime.now(), group_name) self.set_last_call(datetime.now(), group_name)
if self._rate_limited_calls is not None: if self._rate_limited_calls is not None:
self.add_rate_limited_call( self.add_rate_limited_call(
self.last_call(group_name), group_name self.last_call(group_name), group_name
) )
return await method(obj, *args, **kwargs) return await method(obj, *args, **kwargs)
# If a method has a conditional JobCondition, they must check it in the method # If a method has a conditional JobCondition, they must check it in the method
# These should be handled like normal JobConditions as much as possible # These should be handled like normal JobConditions as much as possible
except JobConditionException as err: except JobConditionException as err:
return self._handle_job_condition_exception(err) return self._handle_job_condition_exception(err)
except HassioError as err: except HassioError as err:
job.capture_error(err) job.capture_error(err)
raise err raise err
except Exception as err: except Exception as err:
_LOGGER.exception("Unhandled exception: %s", err) _LOGGER.exception("Unhandled exception: %s", err)
job.capture_error() job.capture_error()
await async_capture_exception(err) await async_capture_exception(err)
raise JobException() from err raise JobException() from err
finally:
self._release_concurrency_control(job_group)
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required # Jobs that weren't started are always cleaned up. Also clean up done jobs if required
finally: finally:
@ -455,26 +441,34 @@ class Job(CoreSysAttributes):
f"'{method_name}' blocked from execution, mounting not supported on system" f"'{method_name}' blocked from execution, mounting not supported on system"
) )
def _release_concurrency_control(self, job_group: JobGroup | None) -> None: def _release_concurrency_control(
self, job_group: JobGroup | None, job: SupervisorJob
) -> None:
"""Release concurrency control locks.""" """Release concurrency control locks."""
if self.concurrency == JobConcurrency.REJECT: if self._is_group_concurrency():
# Group-level concurrency: delegate to job group
cast(JobGroup, job_group).release(job)
elif self.concurrency in (JobConcurrency.REJECT, JobConcurrency.QUEUE):
# Job-level concurrency: use semaphore
if self.lock.locked(): if self.lock.locked():
self.lock.release() self.lock.release()
elif self.concurrency == JobConcurrency.QUEUE:
if self.lock.locked():
self.lock.release()
elif self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
if job_group and job_group.has_lock:
job_group.release()
async def _handle_concurrency_control( async def _handle_concurrency_control(
self, job_group: JobGroup | None, job: SupervisorJob self, job_group: JobGroup | None, job: SupervisorJob
) -> None: ) -> None:
"""Handle concurrency control limits.""" """Handle concurrency control limits."""
if self.concurrency == JobConcurrency.REJECT: if self._is_group_concurrency():
# Group-level concurrency: delegate to job group
try:
await cast(JobGroup, job_group).acquire(
job, wait=self.concurrency == JobConcurrency.GROUP_QUEUE
)
except JobGroupExecutionLimitExceeded as err:
if self.on_condition:
raise self.on_condition(str(err)) from err
raise err
elif self.concurrency == JobConcurrency.REJECT:
# Job-level reject: fail if lock is taken
if self.lock.locked(): if self.lock.locked():
on_condition = ( on_condition = (
JobException if self.on_condition is None else self.on_condition JobException if self.on_condition is None else self.on_condition
@ -482,21 +476,19 @@ class Job(CoreSysAttributes):
raise on_condition("Another job is running") raise on_condition("Another job is running")
await self.lock.acquire() await self.lock.acquire()
elif self.concurrency == JobConcurrency.QUEUE: elif self.concurrency == JobConcurrency.QUEUE:
# Job-level queue: wait for lock
await self.lock.acquire() await self.lock.acquire()
elif self.concurrency == JobConcurrency.GROUP_REJECT:
try: @asynccontextmanager
await cast(JobGroup, job_group).acquire(job, wait=False) async def _concurrency_control(
except JobGroupExecutionLimitExceeded as err: self, job_group: JobGroup | None, job: SupervisorJob
if self.on_condition: ) -> AsyncIterator[None]:
raise self.on_condition(str(err)) from err """Context manager for concurrency control that ensures locks are always released."""
raise err await self._handle_concurrency_control(job_group, job)
elif self.concurrency == JobConcurrency.GROUP_QUEUE: try:
try: yield
await cast(JobGroup, job_group).acquire(job, wait=True) finally:
except JobGroupExecutionLimitExceeded as err: self._release_concurrency_control(job_group, job)
if self.on_condition:
raise self.on_condition(str(err)) from err
raise err
async def _handle_throttling(self, group_name: str | None) -> bool: async def _handle_throttling(self, group_name: str | None) -> bool:
"""Handle throttling limits. Returns True if job should continue, False if throttled.""" """Handle throttling limits. Returns True if job should continue, False if throttled."""
@ -506,7 +498,7 @@ class Job(CoreSysAttributes):
if time_since_last_call < throttle_period: if time_since_last_call < throttle_period:
# Always return False when throttled (skip execution) # Always return False when throttled (skip execution)
return False return False
elif self.throttle in (JobThrottle.RATE_LIMIT, JobThrottle.GROUP_RATE_LIMIT): elif self._is_rate_limit_throttle():
# Only reprocess array when necessary (at limit) # Only reprocess array when necessary (at limit)
if len(self.rate_limited_calls(group_name)) >= self.throttle_max_calls: if len(self.rate_limited_calls(group_name)) >= self.throttle_max_calls:
self.set_rate_limited_calls( self.set_rate_limited_calls(

View File

@ -23,14 +23,14 @@ class JobGroup(CoreSysAttributes):
self.coresys: CoreSys = coresys self.coresys: CoreSys = coresys
self._group_name: str = group_name self._group_name: str = group_name
self._lock: Lock = Lock() self._lock: Lock = Lock()
self._active_job: SupervisorJob | None = None self._lock_owner: SupervisorJob | None = None
self._parent_jobs: list[SupervisorJob] = [] self._parent_jobs: list[SupervisorJob] = []
self._job_reference: str | None = job_reference self._job_reference: str | None = job_reference
@property @property
def active_job(self) -> SupervisorJob | None: def active_job(self) -> SupervisorJob | None:
"""Get active job ID.""" """Get active job ID."""
return self._active_job return self._lock_owner
@property @property
def group_name(self) -> str: def group_name(self) -> str:
@ -40,42 +40,54 @@ class JobGroup(CoreSysAttributes):
@property @property
def has_lock(self) -> bool: def has_lock(self) -> bool:
"""Return true if current task has the lock on this job group.""" """Return true if current task has the lock on this job group."""
return ( if not self._lock_owner:
self.active_job is not None return False
and self.sys_jobs.is_job
and self.active_job == self.sys_jobs.current if not self.sys_jobs.is_job:
) return False
current_job = self.sys_jobs.current
# Check if current job owns lock directly
return current_job == self._lock_owner
@property @property
def job_reference(self) -> str | None: def job_reference(self) -> str | None:
"""Return value to use as reference for all jobs created for this job group.""" """Return value to use as reference for all jobs created for this job group."""
return self._job_reference return self._job_reference
def is_locked_by(self, job: SupervisorJob) -> bool:
"""Check if this specific job owns the lock."""
return self._lock_owner == job
async def acquire(self, job: SupervisorJob, wait: bool = False) -> None: async def acquire(self, job: SupervisorJob, wait: bool = False) -> None:
"""Acquire the lock for the group for the specified job.""" """Acquire the lock for the group for the specified job."""
# If we already own the lock (nested call or same job chain), just update parent stack
if self.has_lock:
if self._lock_owner:
self._parent_jobs.append(self._lock_owner)
self._lock_owner = job
return
# If there's another job running and we're not waiting, raise # If there's another job running and we're not waiting, raise
if self.active_job and not self.has_lock and not wait: if self._lock_owner and not wait:
raise JobGroupExecutionLimitExceeded( raise JobGroupExecutionLimitExceeded(
f"Another job is running for job group {self.group_name}" f"Another job is running for job group {self.group_name}"
) )
# Else if we don't have the lock, acquire it # Acquire the actual asyncio lock
if not self.has_lock: await self._lock.acquire()
await self._lock.acquire()
# Store the job ID we acquired the lock for # Set ownership
if self.active_job: self._lock_owner = job
self._parent_jobs.append(self.active_job)
self._active_job = job def release(self, job: SupervisorJob) -> None:
def release(self) -> None:
"""Release the lock for the group or return it to parent.""" """Release the lock for the group or return it to parent."""
if not self.has_lock: if not self.is_locked_by(job):
raise JobException("Cannot release as caller does not own lock") raise JobException(f"Job {job.uuid} does not own the lock")
# Return to parent job if exists
if self._parent_jobs: if self._parent_jobs:
self._active_job = self._parent_jobs.pop() self._lock_owner = self._parent_jobs.pop()
else: else:
self._active_job = None self._lock_owner = None
self._lock.release() self._lock.release()

View File

@ -1300,3 +1300,59 @@ async def test_concurency_reject_and_rate_limit(
await test.execute() await test.execute()
assert test.call == 2 assert test.call == 2
async def test_group_concurrency_with_group_throttling(coresys: CoreSys):
"""Test that group concurrency works with group throttling."""
class TestClass(JobGroup):
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
super().__init__(coresys, "TestGroupConcurrencyThrottle")
self.call_count = 0
self.nested_call_count = 0
@Job(
name="test_group_concurrency_throttle_main",
concurrency=JobConcurrency.GROUP_QUEUE,
throttle=JobThrottle.GROUP_THROTTLE,
throttle_period=timedelta(milliseconds=50),
on_condition=JobException,
)
async def main_method(self) -> None:
"""Make nested call with group concurrency and throttling."""
self.call_count += 1
# Test nested call to ensure lock handling works
await self.nested_method()
@Job(
name="test_group_concurrency_throttle_nested",
concurrency=JobConcurrency.GROUP_QUEUE,
throttle=JobThrottle.GROUP_THROTTLE,
throttle_period=timedelta(milliseconds=50),
on_condition=JobException,
)
async def nested_method(self) -> None:
"""Nested method with group concurrency and throttling."""
self.nested_call_count += 1
test = TestClass(coresys)
# First call should work
await test.main_method()
assert test.call_count == 1
assert test.nested_call_count == 1
# Second call should be throttled (not execute due to throttle period)
await test.main_method()
assert test.call_count == 1 # Still 1, throttled
assert test.nested_call_count == 1 # Still 1, throttled
# Wait for throttle period to pass and try again
with time_machine.travel(utcnow() + timedelta(milliseconds=60)):
await test.main_method()
assert test.call_count == 2 # Should execute now
assert test.nested_call_count == 2 # Nested call should also execute