mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-08-07 08:17:41 +00:00
Improve JobGroup locking with external ownership tracking (#6074)
* Use context manager for Job concurrency control * Allow to release lock outside of Job running context * Improve JobGroup locking with external ownership tracking Track lock ownership by job UUID instead of execution context. This allows external lock release via job parameter. * Fix acquire lock in nested Jobs * Simplify nested lock tracking * Simplify Job group lock acquisition logic * Simplify by using helper methods * Allow throttling with group concurrency * Use Lock instead of Semaphore for job concurrency control Use the same synchronization primitive (Lock) for job concurrency control as used in job groups. * Go back to lock ownership tracking with references * Drop unused property `active_job_id` * Drop unused property `can_acquire` * Replace assert with cast
This commit is contained in:
parent
15ba1a3c94
commit
3b093200e3
@ -1,8 +1,8 @@
|
||||
"""Job decorator."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
import logging
|
||||
@ -71,7 +71,7 @@ class Job(CoreSysAttributes):
|
||||
self.on_condition = on_condition
|
||||
self._throttle_period = throttle_period
|
||||
self._throttle_max_calls = throttle_max_calls
|
||||
self._lock: asyncio.Semaphore | None = None
|
||||
self._lock: asyncio.Lock | None = None
|
||||
self._last_call: dict[str | None, datetime] = {}
|
||||
self._rate_limited_calls: dict[str | None, list[datetime]] | None = None
|
||||
self._internal = internal
|
||||
@ -82,43 +82,42 @@ class Job(CoreSysAttributes):
|
||||
# Validate Options
|
||||
self._validate_parameters()
|
||||
|
||||
def _validate_parameters(self) -> None:
|
||||
"""Validate job parameters."""
|
||||
# Validate throttle parameters
|
||||
if (
|
||||
self.throttle
|
||||
in (
|
||||
JobThrottle.THROTTLE,
|
||||
def _is_group_concurrency(self) -> bool:
|
||||
"""Check if this job uses group-level concurrency."""
|
||||
return self.concurrency in (
|
||||
JobConcurrency.GROUP_REJECT,
|
||||
JobConcurrency.GROUP_QUEUE,
|
||||
)
|
||||
|
||||
def _is_group_throttle(self) -> bool:
|
||||
"""Check if this job uses group-level throttling."""
|
||||
return self.throttle in (
|
||||
JobThrottle.GROUP_THROTTLE,
|
||||
JobThrottle.GROUP_RATE_LIMIT,
|
||||
)
|
||||
|
||||
def _is_rate_limit_throttle(self) -> bool:
|
||||
"""Check if this job uses rate limiting (job or group level)."""
|
||||
return self.throttle in (
|
||||
JobThrottle.RATE_LIMIT,
|
||||
JobThrottle.GROUP_RATE_LIMIT,
|
||||
)
|
||||
and self._throttle_period is None
|
||||
):
|
||||
|
||||
def _validate_parameters(self) -> None:
|
||||
"""Validate job parameters."""
|
||||
# Validate throttle parameters
|
||||
if self.throttle is not None and self._throttle_period is None:
|
||||
raise RuntimeError(
|
||||
f"Job {self.name} is using throttle {self.throttle} without a throttle period!"
|
||||
)
|
||||
|
||||
if self.throttle in (
|
||||
JobThrottle.RATE_LIMIT,
|
||||
JobThrottle.GROUP_RATE_LIMIT,
|
||||
):
|
||||
if self._is_rate_limit_throttle():
|
||||
if self._throttle_max_calls is None:
|
||||
raise RuntimeError(
|
||||
f"Job {self.name} is using throttle {self.throttle} without throttle max calls!"
|
||||
)
|
||||
self._rate_limited_calls = {}
|
||||
|
||||
if self.throttle is not None and self.concurrency in (
|
||||
JobConcurrency.GROUP_REJECT,
|
||||
JobConcurrency.GROUP_QUEUE,
|
||||
):
|
||||
# We cannot release group locks when Job is not running (e.g. throttled)
|
||||
# which makes these combinations impossible to use currently.
|
||||
raise RuntimeError(
|
||||
f"Job {self.name} is using throttling ({self.throttle}) with group concurrency ({self.concurrency}), which is not allowed!"
|
||||
)
|
||||
|
||||
@property
|
||||
def throttle_max_calls(self) -> int:
|
||||
"""Return max calls for throttle."""
|
||||
@ -127,9 +126,9 @@ class Job(CoreSysAttributes):
|
||||
return self._throttle_max_calls
|
||||
|
||||
@property
|
||||
def lock(self) -> asyncio.Semaphore:
|
||||
def lock(self) -> asyncio.Lock:
|
||||
"""Return lock for limits."""
|
||||
# asyncio.Semaphore objects must be created in event loop
|
||||
# asyncio.Lock objects must be created in event loop
|
||||
# Since this is sync code it is not safe to create if missing here
|
||||
if not self._lock:
|
||||
raise RuntimeError("Lock has not been created yet!")
|
||||
@ -201,7 +200,7 @@ class Job(CoreSysAttributes):
|
||||
|
||||
# Setup lock for limits
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Semaphore()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Job groups
|
||||
job_group: JobGroup | None = None
|
||||
@ -211,18 +210,12 @@ class Job(CoreSysAttributes):
|
||||
|
||||
# Check for group-based parameters
|
||||
if not job_group:
|
||||
if self.concurrency in (
|
||||
JobConcurrency.GROUP_REJECT,
|
||||
JobConcurrency.GROUP_QUEUE,
|
||||
):
|
||||
if self._is_group_concurrency():
|
||||
raise RuntimeError(
|
||||
f"Job {self.name} uses group concurrency ({self.concurrency}) but is not on a JobGroup! "
|
||||
f"The class must inherit from JobGroup to use GROUP_REJECT or GROUP_QUEUE."
|
||||
) from None
|
||||
if self.throttle in (
|
||||
JobThrottle.GROUP_THROTTLE,
|
||||
JobThrottle.GROUP_RATE_LIMIT,
|
||||
):
|
||||
if self._is_group_throttle():
|
||||
raise RuntimeError(
|
||||
f"Job {self.name} uses group throttling ({self.throttle}) but is not on a JobGroup! "
|
||||
f"The class must inherit from JobGroup to use GROUP_THROTTLE or GROUP_RATE_LIMIT."
|
||||
@ -279,15 +272,10 @@ class Job(CoreSysAttributes):
|
||||
except JobConditionException as err:
|
||||
return self._handle_job_condition_exception(err)
|
||||
|
||||
# Handle execution limits
|
||||
await self._handle_concurrency_control(job_group, job)
|
||||
try:
|
||||
# Handle execution limits using context manager
|
||||
async with self._concurrency_control(job_group, job):
|
||||
if not await self._handle_throttling(group_name):
|
||||
self._release_concurrency_control(job_group)
|
||||
return # Job was throttled, exit early
|
||||
except Exception:
|
||||
self._release_concurrency_control(job_group)
|
||||
raise
|
||||
|
||||
# Execute Job
|
||||
with job.start():
|
||||
@ -312,8 +300,6 @@ class Job(CoreSysAttributes):
|
||||
job.capture_error()
|
||||
await async_capture_exception(err)
|
||||
raise JobException() from err
|
||||
finally:
|
||||
self._release_concurrency_control(job_group)
|
||||
|
||||
# Jobs that weren't started are always cleaned up. Also clean up done jobs if required
|
||||
finally:
|
||||
@ -455,26 +441,34 @@ class Job(CoreSysAttributes):
|
||||
f"'{method_name}' blocked from execution, mounting not supported on system"
|
||||
)
|
||||
|
||||
def _release_concurrency_control(self, job_group: JobGroup | None) -> None:
|
||||
def _release_concurrency_control(
|
||||
self, job_group: JobGroup | None, job: SupervisorJob
|
||||
) -> None:
|
||||
"""Release concurrency control locks."""
|
||||
if self.concurrency == JobConcurrency.REJECT:
|
||||
if self._is_group_concurrency():
|
||||
# Group-level concurrency: delegate to job group
|
||||
cast(JobGroup, job_group).release(job)
|
||||
elif self.concurrency in (JobConcurrency.REJECT, JobConcurrency.QUEUE):
|
||||
# Job-level concurrency: use semaphore
|
||||
if self.lock.locked():
|
||||
self.lock.release()
|
||||
elif self.concurrency == JobConcurrency.QUEUE:
|
||||
if self.lock.locked():
|
||||
self.lock.release()
|
||||
elif self.concurrency in (
|
||||
JobConcurrency.GROUP_REJECT,
|
||||
JobConcurrency.GROUP_QUEUE,
|
||||
):
|
||||
if job_group and job_group.has_lock:
|
||||
job_group.release()
|
||||
|
||||
async def _handle_concurrency_control(
|
||||
self, job_group: JobGroup | None, job: SupervisorJob
|
||||
) -> None:
|
||||
"""Handle concurrency control limits."""
|
||||
if self.concurrency == JobConcurrency.REJECT:
|
||||
if self._is_group_concurrency():
|
||||
# Group-level concurrency: delegate to job group
|
||||
try:
|
||||
await cast(JobGroup, job_group).acquire(
|
||||
job, wait=self.concurrency == JobConcurrency.GROUP_QUEUE
|
||||
)
|
||||
except JobGroupExecutionLimitExceeded as err:
|
||||
if self.on_condition:
|
||||
raise self.on_condition(str(err)) from err
|
||||
raise err
|
||||
elif self.concurrency == JobConcurrency.REJECT:
|
||||
# Job-level reject: fail if lock is taken
|
||||
if self.lock.locked():
|
||||
on_condition = (
|
||||
JobException if self.on_condition is None else self.on_condition
|
||||
@ -482,21 +476,19 @@ class Job(CoreSysAttributes):
|
||||
raise on_condition("Another job is running")
|
||||
await self.lock.acquire()
|
||||
elif self.concurrency == JobConcurrency.QUEUE:
|
||||
# Job-level queue: wait for lock
|
||||
await self.lock.acquire()
|
||||
elif self.concurrency == JobConcurrency.GROUP_REJECT:
|
||||
|
||||
@asynccontextmanager
|
||||
async def _concurrency_control(
|
||||
self, job_group: JobGroup | None, job: SupervisorJob
|
||||
) -> AsyncIterator[None]:
|
||||
"""Context manager for concurrency control that ensures locks are always released."""
|
||||
await self._handle_concurrency_control(job_group, job)
|
||||
try:
|
||||
await cast(JobGroup, job_group).acquire(job, wait=False)
|
||||
except JobGroupExecutionLimitExceeded as err:
|
||||
if self.on_condition:
|
||||
raise self.on_condition(str(err)) from err
|
||||
raise err
|
||||
elif self.concurrency == JobConcurrency.GROUP_QUEUE:
|
||||
try:
|
||||
await cast(JobGroup, job_group).acquire(job, wait=True)
|
||||
except JobGroupExecutionLimitExceeded as err:
|
||||
if self.on_condition:
|
||||
raise self.on_condition(str(err)) from err
|
||||
raise err
|
||||
yield
|
||||
finally:
|
||||
self._release_concurrency_control(job_group, job)
|
||||
|
||||
async def _handle_throttling(self, group_name: str | None) -> bool:
|
||||
"""Handle throttling limits. Returns True if job should continue, False if throttled."""
|
||||
@ -506,7 +498,7 @@ class Job(CoreSysAttributes):
|
||||
if time_since_last_call < throttle_period:
|
||||
# Always return False when throttled (skip execution)
|
||||
return False
|
||||
elif self.throttle in (JobThrottle.RATE_LIMIT, JobThrottle.GROUP_RATE_LIMIT):
|
||||
elif self._is_rate_limit_throttle():
|
||||
# Only reprocess array when necessary (at limit)
|
||||
if len(self.rate_limited_calls(group_name)) >= self.throttle_max_calls:
|
||||
self.set_rate_limited_calls(
|
||||
|
@ -23,14 +23,14 @@ class JobGroup(CoreSysAttributes):
|
||||
self.coresys: CoreSys = coresys
|
||||
self._group_name: str = group_name
|
||||
self._lock: Lock = Lock()
|
||||
self._active_job: SupervisorJob | None = None
|
||||
self._lock_owner: SupervisorJob | None = None
|
||||
self._parent_jobs: list[SupervisorJob] = []
|
||||
self._job_reference: str | None = job_reference
|
||||
|
||||
@property
|
||||
def active_job(self) -> SupervisorJob | None:
|
||||
"""Get active job ID."""
|
||||
return self._active_job
|
||||
return self._lock_owner
|
||||
|
||||
@property
|
||||
def group_name(self) -> str:
|
||||
@ -40,42 +40,54 @@ class JobGroup(CoreSysAttributes):
|
||||
@property
|
||||
def has_lock(self) -> bool:
|
||||
"""Return true if current task has the lock on this job group."""
|
||||
return (
|
||||
self.active_job is not None
|
||||
and self.sys_jobs.is_job
|
||||
and self.active_job == self.sys_jobs.current
|
||||
)
|
||||
if not self._lock_owner:
|
||||
return False
|
||||
|
||||
if not self.sys_jobs.is_job:
|
||||
return False
|
||||
|
||||
current_job = self.sys_jobs.current
|
||||
# Check if current job owns lock directly
|
||||
return current_job == self._lock_owner
|
||||
|
||||
@property
|
||||
def job_reference(self) -> str | None:
|
||||
"""Return value to use as reference for all jobs created for this job group."""
|
||||
return self._job_reference
|
||||
|
||||
def is_locked_by(self, job: SupervisorJob) -> bool:
|
||||
"""Check if this specific job owns the lock."""
|
||||
return self._lock_owner == job
|
||||
|
||||
async def acquire(self, job: SupervisorJob, wait: bool = False) -> None:
|
||||
"""Acquire the lock for the group for the specified job."""
|
||||
# If we already own the lock (nested call or same job chain), just update parent stack
|
||||
if self.has_lock:
|
||||
if self._lock_owner:
|
||||
self._parent_jobs.append(self._lock_owner)
|
||||
self._lock_owner = job
|
||||
return
|
||||
|
||||
# If there's another job running and we're not waiting, raise
|
||||
if self.active_job and not self.has_lock and not wait:
|
||||
if self._lock_owner and not wait:
|
||||
raise JobGroupExecutionLimitExceeded(
|
||||
f"Another job is running for job group {self.group_name}"
|
||||
)
|
||||
|
||||
# Else if we don't have the lock, acquire it
|
||||
if not self.has_lock:
|
||||
# Acquire the actual asyncio lock
|
||||
await self._lock.acquire()
|
||||
|
||||
# Store the job ID we acquired the lock for
|
||||
if self.active_job:
|
||||
self._parent_jobs.append(self.active_job)
|
||||
# Set ownership
|
||||
self._lock_owner = job
|
||||
|
||||
self._active_job = job
|
||||
|
||||
def release(self) -> None:
|
||||
def release(self, job: SupervisorJob) -> None:
|
||||
"""Release the lock for the group or return it to parent."""
|
||||
if not self.has_lock:
|
||||
raise JobException("Cannot release as caller does not own lock")
|
||||
if not self.is_locked_by(job):
|
||||
raise JobException(f"Job {job.uuid} does not own the lock")
|
||||
|
||||
# Return to parent job if exists
|
||||
if self._parent_jobs:
|
||||
self._active_job = self._parent_jobs.pop()
|
||||
self._lock_owner = self._parent_jobs.pop()
|
||||
else:
|
||||
self._active_job = None
|
||||
self._lock_owner = None
|
||||
self._lock.release()
|
||||
|
@ -1300,3 +1300,59 @@ async def test_concurency_reject_and_rate_limit(
|
||||
await test.execute()
|
||||
|
||||
assert test.call == 2
|
||||
|
||||
|
||||
async def test_group_concurrency_with_group_throttling(coresys: CoreSys):
|
||||
"""Test that group concurrency works with group throttling."""
|
||||
|
||||
class TestClass(JobGroup):
|
||||
"""Test class."""
|
||||
|
||||
def __init__(self, coresys: CoreSys):
|
||||
"""Initialize the test class."""
|
||||
super().__init__(coresys, "TestGroupConcurrencyThrottle")
|
||||
self.call_count = 0
|
||||
self.nested_call_count = 0
|
||||
|
||||
@Job(
|
||||
name="test_group_concurrency_throttle_main",
|
||||
concurrency=JobConcurrency.GROUP_QUEUE,
|
||||
throttle=JobThrottle.GROUP_THROTTLE,
|
||||
throttle_period=timedelta(milliseconds=50),
|
||||
on_condition=JobException,
|
||||
)
|
||||
async def main_method(self) -> None:
|
||||
"""Make nested call with group concurrency and throttling."""
|
||||
self.call_count += 1
|
||||
# Test nested call to ensure lock handling works
|
||||
await self.nested_method()
|
||||
|
||||
@Job(
|
||||
name="test_group_concurrency_throttle_nested",
|
||||
concurrency=JobConcurrency.GROUP_QUEUE,
|
||||
throttle=JobThrottle.GROUP_THROTTLE,
|
||||
throttle_period=timedelta(milliseconds=50),
|
||||
on_condition=JobException,
|
||||
)
|
||||
async def nested_method(self) -> None:
|
||||
"""Nested method with group concurrency and throttling."""
|
||||
self.nested_call_count += 1
|
||||
|
||||
test = TestClass(coresys)
|
||||
|
||||
# First call should work
|
||||
await test.main_method()
|
||||
assert test.call_count == 1
|
||||
assert test.nested_call_count == 1
|
||||
|
||||
# Second call should be throttled (not execute due to throttle period)
|
||||
await test.main_method()
|
||||
assert test.call_count == 1 # Still 1, throttled
|
||||
assert test.nested_call_count == 1 # Still 1, throttled
|
||||
|
||||
# Wait for throttle period to pass and try again
|
||||
with time_machine.travel(utcnow() + timedelta(milliseconds=60)):
|
||||
await test.main_method()
|
||||
|
||||
assert test.call_count == 2 # Should execute now
|
||||
assert test.nested_call_count == 2 # Nested call should also execute
|
||||
|
Loading…
x
Reference in New Issue
Block a user