Improve JobGroup locking with external ownership tracking (#6074)

* Use context manager for Job concurrency control

* Allow releasing the lock outside of the Job running context

* Improve JobGroup locking with external ownership tracking

Track lock ownership by the job (its UUID) instead of by execution context. This
allows external lock release via a job parameter; a minimal sketch of this pattern
follows the change summary below.

* Fix lock acquisition in nested Jobs

* Simplify nested lock tracking

* Simplify Job group lock acquisition logic

* Simplify by using helper methods

* Allow throttling with group concurrency

* Use Lock instead of Semaphore for job concurrency control

Use the same synchronization primitive (Lock) for job concurrency
control as used in job groups.

* Go back to lock ownership tracking with references

* Drop unused property `active_job_id`

* Drop unused property `can_acquire`

* Replace assert with cast
Stefan Agner, 2025-08-07 00:14:58 +02:00, committed by GitHub
commit 3b093200e3 (parent 15ba1a3c94)
3 changed files with 176 additions and 116 deletions
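
For reference, a minimal standalone sketch of the ownership-tracking idea behind
these changes. The OwnedLock class and the plain string job IDs here are
illustrative stand-ins, not Supervisor's SupervisorJob/JobGroup API; the point is
only that keying lock ownership to a job value (rather than to the running task or
execution context) lets another caller release the lock on that job's behalf,
with asyncio.Lock as the underlying primitive.

import asyncio


class OwnedLock:
    """asyncio.Lock wrapper whose ownership is tracked by an explicit job ID (illustrative only)."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._owner: str | None = None

    def is_locked_by(self, job_id: str) -> bool:
        """Return True if the given job currently owns the lock."""
        return self._owner == job_id

    async def acquire(self, job_id: str, wait: bool = True) -> None:
        """Acquire the lock on behalf of job_id, rejecting instead of waiting if requested."""
        if self._owner is not None and not wait:
            raise RuntimeError(f"Lock is already held by {self._owner}")
        await self._lock.acquire()
        self._owner = job_id

    def release(self, job_id: str) -> None:
        """Release by naming the owning job; the caller's own context does not matter."""
        if not self.is_locked_by(job_id):
            raise RuntimeError(f"Job {job_id} does not own the lock")
        self._owner = None
        self._lock.release()


async def main() -> None:
    lock = OwnedLock()
    await lock.acquire("job-1")
    # Ownership is keyed by the job, not the running task, so any code that
    # knows the job can release the lock on its behalf.
    lock.release("job-1")


asyncio.run(main())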


@@ -1,8 +1,8 @@
"""Job decorator."""
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import suppress
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timedelta
from functools import wraps
import logging
@@ -71,7 +71,7 @@ class Job(CoreSysAttributes):
self.on_condition = on_condition
self._throttle_period = throttle_period
self._throttle_max_calls = throttle_max_calls
self._lock: asyncio.Semaphore | None = None
self._lock: asyncio.Lock | None = None
self._last_call: dict[str | None, datetime] = {}
self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
self._internal = internal
@@ -82,43 +82,42 @@ class Job(CoreSysAttributes):
# Validate Options
self._validate_parameters()
def _is_group_concurrency(self) -> bool:
"""Check if this job uses group-level concurrency."""
return self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
)
def _is_group_throttle(self) -> bool:
"""Check if this job uses group-level throttling."""
return self.throttle in (
JobThrottle.GROUP_THROTTLE,
JobThrottle.GROUP_RATE_LIMIT,
)
def _is_rate_limit_throttle(self) -> bool:
"""Check if this job uses rate limiting (job or group level)."""
return self.throttle in (
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
)
def _validate_parameters(self) -> None:
"""Validate job parameters."""
# Validate throttle parameters
if (
self.throttle
in (
JobThrottle.THROTTLE,
JobThrottle.GROUP_THROTTLE,
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
)
and self._throttle_period is None
):
if self.throttle is not None and self._throttle_period is None:
raise RuntimeError(
f"Job {self.name} is using throttle {self.throttle} without a throttle period!"
)
if self.throttle in (
JobThrottle.RATE_LIMIT,
JobThrottle.GROUP_RATE_LIMIT,
):
if self._is_rate_limit_throttle():
if self._throttle_max_calls is None:
raise RuntimeError(
f"Job {self.name} is using throttle {self.throttle} without throttle max calls!"
)
self._rate_limited_calls = {}
if self.throttle is not None and self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
# We cannot release group locks when Job is not running (e.g. throttled)
# which makes these combinations impossible to use currently.
raise RuntimeError(
f"Job {self.name} is using throttling ({self.throttle}) with group concurrency ({self.concurrency}), which is not allowed!"
)
@property
def throttle_max_calls(self) -> int:
"""Return max calls for throttle."""
@@ -127,9 +126,9 @@ class Job(CoreSysAttributes):
return self._throttle_max_calls
@property
def lock(self) -> asyncio.Semaphore:
def lock(self) -> asyncio.Lock:
"""Return lock for limits."""
# asyncio.Semaphore objects must be created in event loop
# asyncio.Lock objects must be created in event loop
# Since this is sync code it is not safe to create if missing here
if not self._lock:
raise RuntimeError("Lock has not been created yet!")
@@ -201,7 +200,7 @@
# Setup lock for limits
if self._lock is None:
self._lock = asyncio.Semaphore()
self._lock = asyncio.Lock()
# Job groups
job_group: JobGroup | None = None
@@ -211,18 +210,12 @@
# Check for group-based parameters
if not job_group:
if self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
if self._is_group_concurrency():
raise RuntimeError(
f"Job {self.name} uses group concurrency ({self.concurrency}) but is not on a JobGroup! "
f"The class must inherit from JobGroup to use GROUP_REJECT or GROUP_QUEUE."
) from None
if self.throttle in (
JobThrottle.GROUP_THROTTLE,
JobThrottle.GROUP_RATE_LIMIT,
):
if self._is_group_throttle():
raise RuntimeError(
f"Job {self.name} uses group throttling ({self.throttle}) but is not on a JobGroup! "
f"The class must inherit from JobGroup to use GROUP_THROTTLE or GROUP_RATE_LIMIT."
@@ -279,41 +272,34 @@
except JobConditionException as err:
return self._handle_job_condition_exception(err)
# Handle execution limits
await self._handle_concurrency_control(job_group, job)
try:
# Handle execution limits using context manager
async with self._concurrency_control(job_group, job):
if not await self._handle_throttling(group_name):
self._release_concurrency_control(job_group)
return # Job was throttled, exit early
except Exception:
self._release_concurrency_control(job_group)
raise
# Execute Job
with job.start():
try:
self.set_last_call(datetime.now(), group_name)
if self._rate_limited_calls is not None:
self.add_rate_limited_call(
self.last_call(group_name), group_name
)
# Execute Job
with job.start():
try:
self.set_last_call(datetime.now(), group_name)
if self._rate_limited_calls is not None:
self.add_rate_limited_call(
self.last_call(group_name), group_name
)
return await method(obj, *args, **kwargs)
return await method(obj, *args, **kwargs)
# If a method has a conditional JobCondition, they must check it in the method
# These should be handled like normal JobConditions as much as possible
except JobConditionException as err:
return self._handle_job_condition_exception(err)
except HassioError as err:
job.capture_error(err)
raise err
except Exception as err:
_LOGGER.exception("Unhandled exception: %s", err)
job.capture_error()
await async_capture_exception(err)
raise JobException() from err
finally:
self._release_concurrency_control(job_group)
# If a method has a conditional JobCondition, they must check it in the method
# These should be handled like normal JobConditions as much as possible
except JobConditionException as err:
return self._handle_job_condition_exception(err)
except HassioError as err:
job.capture_error(err)
raise err
except Exception as err:
_LOGGER.exception("Unhandled exception: %s", err)
job.capture_error()
await async_capture_exception(err)
raise JobException() from err
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required
finally:
@@ -455,26 +441,34 @@ class Job(CoreSysAttributes):
f"'{method_name}' blocked from execution, mounting not supported on system"
)
def _release_concurrency_control(self, job_group: JobGroup | None) -> None:
def _release_concurrency_control(
self, job_group: JobGroup | None, job: SupervisorJob
) -> None:
"""Release concurrency control locks."""
if self.concurrency == JobConcurrency.REJECT:
if self._is_group_concurrency():
# Group-level concurrency: delegate to job group
cast(JobGroup, job_group).release(job)
elif self.concurrency in (JobConcurrency.REJECT, JobConcurrency.QUEUE):
# Job-level concurrency: use semaphore
if self.lock.locked():
self.lock.release()
elif self.concurrency == JobConcurrency.QUEUE:
if self.lock.locked():
self.lock.release()
elif self.concurrency in (
JobConcurrency.GROUP_REJECT,
JobConcurrency.GROUP_QUEUE,
):
if job_group and job_group.has_lock:
job_group.release()
async def _handle_concurrency_control(
self, job_group: JobGroup | None, job: SupervisorJob
) -> None:
"""Handle concurrency control limits."""
if self.concurrency == JobConcurrency.REJECT:
if self._is_group_concurrency():
# Group-level concurrency: delegate to job group
try:
await cast(JobGroup, job_group).acquire(
job, wait=self.concurrency == JobConcurrency.GROUP_QUEUE
)
except JobGroupExecutionLimitExceeded as err:
if self.on_condition:
raise self.on_condition(str(err)) from err
raise err
elif self.concurrency == JobConcurrency.REJECT:
# Job-level reject: fail if lock is taken
if self.lock.locked():
on_condition = (
JobException if self.on_condition is None else self.on_condition
@@ -482,21 +476,19 @@
raise on_condition("Another job is running")
await self.lock.acquire()
elif self.concurrency == JobConcurrency.QUEUE:
# Job-level queue: wait for lock
await self.lock.acquire()
elif self.concurrency == JobConcurrency.GROUP_REJECT:
try:
await cast(JobGroup, job_group).acquire(job, wait=False)
except JobGroupExecutionLimitExceeded as err:
if self.on_condition:
raise self.on_condition(str(err)) from err
raise err
elif self.concurrency == JobConcurrency.GROUP_QUEUE:
try:
await cast(JobGroup, job_group).acquire(job, wait=True)
except JobGroupExecutionLimitExceeded as err:
if self.on_condition:
raise self.on_condition(str(err)) from err
raise err
@asynccontextmanager
async def _concurrency_control(
self, job_group: JobGroup | None, job: SupervisorJob
) -> AsyncIterator[None]:
"""Context manager for concurrency control that ensures locks are always released."""
await self._handle_concurrency_control(job_group, job)
try:
yield
finally:
self._release_concurrency_control(job_group, job)
async def _handle_throttling(self, group_name: str | None) -> bool:
"""Handle throttling limits. Returns True if job should continue, False if throttled."""
@@ -506,7 +498,7 @@
if time_since_last_call < throttle_period:
# Always return False when throttled (skip execution)
return False
elif self.throttle in (JobThrottle.RATE_LIMIT, JobThrottle.GROUP_RATE_LIMIT):
elif self._is_rate_limit_throttle():
# Only reprocess array when necessary (at limit)
if len(self.rate_limited_calls(group_name)) >= self.throttle_max_calls:
self.set_rate_limited_calls(
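
The decorator changes above funnel lock acquisition and release through an
asynccontextmanager helper (_concurrency_control) so the lock is released on every
exit path, including the early return taken when a call is throttled. A minimal
sketch of that pattern in isolation; the Guarded class and its run method are
hypothetical names used only for illustration, not the Supervisor API:

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class Guarded:
    """Toy class showing acquire/release wrapped in an async context manager."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def _concurrency_control(self) -> AsyncIterator[None]:
        """Acquire the lock and guarantee release on every exit path."""
        await self._lock.acquire()
        try:
            yield
        finally:
            self._lock.release()

    async def run(self, throttled: bool = False) -> str:
        async with self._concurrency_control():
            if throttled:
                # Early exit (like a throttled job): the finally block still releases.
                return "throttled"
            return "ran"


async def main() -> None:
    guarded = Guarded()
    print(await guarded.run())
    print(await guarded.run(throttled=True))
    assert not guarded._lock.locked()  # lock released on both paths


asyncio.run(main())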


@@ -23,14 +23,14 @@ class JobGroup(CoreSysAttributes):
self.coresys: CoreSys = coresys
self._group_name: str = group_name
self._lock: Lock = Lock()
self._active_job: SupervisorJob | None = None
self._lock_owner: SupervisorJob | None = None
self._parent_jobs: list[SupervisorJob] = []
self._job_reference: str | None = job_reference
@property
def active_job(self) -> SupervisorJob | None:
"""Get active job ID."""
return self._active_job
return self._lock_owner
@property
def group_name(self) -> str:
@@ -40,42 +40,54 @@
@property
def has_lock(self) -> bool:
"""Return true if current task has the lock on this job group."""
return (
self.active_job is not None
and self.sys_jobs.is_job
and self.active_job == self.sys_jobs.current
)
if not self._lock_owner:
return False
if not self.sys_jobs.is_job:
return False
current_job = self.sys_jobs.current
# Check if current job owns lock directly
return current_job == self._lock_owner
@property
def job_reference(self) -> str | None:
"""Return value to use as reference for all jobs created for this job group."""
return self._job_reference
def is_locked_by(self, job: SupervisorJob) -> bool:
"""Check if this specific job owns the lock."""
return self._lock_owner == job
async def acquire(self, job: SupervisorJob, wait: bool = False) -> None:
"""Acquire the lock for the group for the specified job."""
# If we already own the lock (nested call or same job chain), just update parent stack
if self.has_lock:
if self._lock_owner:
self._parent_jobs.append(self._lock_owner)
self._lock_owner = job
return
# If there's another job running and we're not waiting, raise
if self.active_job and not self.has_lock and not wait:
if self._lock_owner and not wait:
raise JobGroupExecutionLimitExceeded(
f"Another job is running for job group {self.group_name}"
)
# Else if we don't have the lock, acquire it
if not self.has_lock:
await self._lock.acquire()
# Acquire the actual asyncio lock
await self._lock.acquire()
# Store the job ID we acquired the lock for
if self.active_job:
self._parent_jobs.append(self.active_job)
# Set ownership
self._lock_owner = job
self._active_job = job
def release(self) -> None:
def release(self, job: SupervisorJob) -> None:
"""Release the lock for the group or return it to parent."""
if not self.has_lock:
raise JobException("Cannot release as caller does not own lock")
if not self.is_locked_by(job):
raise JobException(f"Job {job.uuid} does not own the lock")
# Return to parent job if exists
if self._parent_jobs:
self._active_job = self._parent_jobs.pop()
self._lock_owner = self._parent_jobs.pop()
else:
self._active_job = None
self._lock_owner = None
self._lock.release()
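
The JobGroup changes above track the lock owner explicitly and keep a parent-job
stack, so a nested job in the same group takes over ownership from its parent
instead of re-acquiring the underlying asyncio.Lock, and release() hands ownership
back to the parent before finally releasing the lock. A simplified sketch of that
handoff; GroupLock and the string job IDs are illustrative, not the Supervisor API:

import asyncio


class GroupLock:
    """Toy group lock where nested jobs take over ownership from their parent."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._owner: str | None = None
        self._parents: list[str] = []

    async def acquire(self, job_id: str, parent_id: str | None = None) -> None:
        """Acquire for job_id; a nested call under the current owner skips re-acquiring."""
        # In Supervisor the nested case is detected from the current job context;
        # here an explicit parent_id stands in for that check.
        if parent_id is not None and self._owner == parent_id:
            self._parents.append(self._owner)
            self._owner = job_id
            return
        await self._lock.acquire()
        self._owner = job_id

    def release(self, job_id: str) -> None:
        """Release for job_id, returning ownership to the parent if one is stacked."""
        if self._owner != job_id:
            raise RuntimeError(f"Job {job_id} does not own the lock")
        if self._parents:
            self._owner = self._parents.pop()
        else:
            self._owner = None
            self._lock.release()


async def main() -> None:
    lock = GroupLock()
    await lock.acquire("parent")
    await lock.acquire("child", parent_id="parent")  # nested: no second acquire
    lock.release("child")   # ownership returns to "parent"
    lock.release("parent")  # underlying asyncio.Lock released here
    assert not lock._lock.locked()


asyncio.run(main())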


@@ -1300,3 +1300,59 @@ async def test_concurency_reject_and_rate_limit(
await test.execute()
assert test.call == 2
async def test_group_concurrency_with_group_throttling(coresys: CoreSys):
"""Test that group concurrency works with group throttling."""
class TestClass(JobGroup):
"""Test class."""
def __init__(self, coresys: CoreSys):
"""Initialize the test class."""
super().__init__(coresys, "TestGroupConcurrencyThrottle")
self.call_count = 0
self.nested_call_count = 0
@Job(
name="test_group_concurrency_throttle_main",
concurrency=JobConcurrency.GROUP_QUEUE,
throttle=JobThrottle.GROUP_THROTTLE,
throttle_period=timedelta(milliseconds=50),
on_condition=JobException,
)
async def main_method(self) -> None:
"""Make nested call with group concurrency and throttling."""
self.call_count += 1
# Test nested call to ensure lock handling works
await self.nested_method()
@Job(
name="test_group_concurrency_throttle_nested",
concurrency=JobConcurrency.GROUP_QUEUE,
throttle=JobThrottle.GROUP_THROTTLE,
throttle_period=timedelta(milliseconds=50),
on_condition=JobException,
)
async def nested_method(self) -> None:
"""Nested method with group concurrency and throttling."""
self.nested_call_count += 1
test = TestClass(coresys)
# First call should work
await test.main_method()
assert test.call_count == 1
assert test.nested_call_count == 1
# Second call should be throttled (not execute due to throttle period)
await test.main_method()
assert test.call_count == 1 # Still 1, throttled
assert test.nested_call_count == 1 # Still 1, throttled
# Wait for throttle period to pass and try again
with time_machine.travel(utcnow() + timedelta(milliseconds=60)):
await test.main_method()
assert test.call_count == 2 # Should execute now
assert test.nested_call_count == 2 # Nested call should also execute