mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Fix cancellation leaking upward from the timeout util (#129003)
This commit is contained in:
parent
7e2b72fa5e
commit
c460e1bbbe
@ -16,7 +16,7 @@ from .async_ import run_callback_threadsafe
|
|||||||
ZONE_GLOBAL = "global"
|
ZONE_GLOBAL = "global"
|
||||||
|
|
||||||
|
|
||||||
class _State(str, enum.Enum):
|
class _State(enum.Enum):
|
||||||
"""States of a task."""
|
"""States of a task."""
|
||||||
|
|
||||||
INIT = "INIT"
|
INIT = "INIT"
|
||||||
@ -160,11 +160,16 @@ class _GlobalTaskContext:
|
|||||||
self._wait_zone: asyncio.Event = asyncio.Event()
|
self._wait_zone: asyncio.Event = asyncio.Event()
|
||||||
self._state: _State = _State.INIT
|
self._state: _State = _State.INIT
|
||||||
self._cool_down: float = cool_down
|
self._cool_down: float = cool_down
|
||||||
|
self._cancelling = 0
|
||||||
|
|
||||||
async def __aenter__(self) -> Self:
|
async def __aenter__(self) -> Self:
|
||||||
self._manager.global_tasks.append(self)
|
self._manager.global_tasks.append(self)
|
||||||
self._start_timer()
|
self._start_timer()
|
||||||
self._state = _State.ACTIVE
|
self._state = _State.ACTIVE
|
||||||
|
# Remember if the task was already cancelling
|
||||||
|
# so when we __aexit__ we can decide if we should
|
||||||
|
# raise asyncio.TimeoutError or let the cancellation propagate
|
||||||
|
self._cancelling = self._task.cancelling()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
@ -177,7 +182,15 @@ class _GlobalTaskContext:
|
|||||||
self._manager.global_tasks.remove(self)
|
self._manager.global_tasks.remove(self)
|
||||||
|
|
||||||
# Timeout on exit
|
# Timeout on exit
|
||||||
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
|
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
|
||||||
|
# The timeout was hit, and the task was cancelled
|
||||||
|
# so we need to uncancel the task since the cancellation
|
||||||
|
# should not leak out of the context manager
|
||||||
|
if self._task.uncancel() > self._cancelling:
|
||||||
|
# If the task was already cancelling don't raise
|
||||||
|
# asyncio.TimeoutError and instead return None
|
||||||
|
# to allow the cancellation to propagate
|
||||||
|
return None
|
||||||
raise TimeoutError
|
raise TimeoutError
|
||||||
|
|
||||||
self._state = _State.EXIT
|
self._state = _State.EXIT
|
||||||
@ -266,6 +279,7 @@ class _ZoneTaskContext:
|
|||||||
self._time_left: float = timeout
|
self._time_left: float = timeout
|
||||||
self._expiration_time: float | None = None
|
self._expiration_time: float | None = None
|
||||||
self._timeout_handler: asyncio.Handle | None = None
|
self._timeout_handler: asyncio.Handle | None = None
|
||||||
|
self._cancelling = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> _State:
|
def state(self) -> _State:
|
||||||
@ -280,6 +294,11 @@ class _ZoneTaskContext:
|
|||||||
if self._zone.freezes_done:
|
if self._zone.freezes_done:
|
||||||
self._start_timer()
|
self._start_timer()
|
||||||
|
|
||||||
|
# Remember if the task was already cancelling
|
||||||
|
# so when we __aexit__ we can decide if we should
|
||||||
|
# raise asyncio.TimeoutError or let the cancellation propagate
|
||||||
|
self._cancelling = self._task.cancelling()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
@ -292,7 +311,15 @@ class _ZoneTaskContext:
|
|||||||
self._stop_timer()
|
self._stop_timer()
|
||||||
|
|
||||||
# Timeout on exit
|
# Timeout on exit
|
||||||
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
|
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
|
||||||
|
# The timeout was hit, and the task was cancelled
|
||||||
|
# so we need to uncancel the task since the cancellation
|
||||||
|
# should not leak out of the context manager
|
||||||
|
if self._task.uncancel() > self._cancelling:
|
||||||
|
# If the task was already cancelling don't raise
|
||||||
|
# asyncio.TimeoutError and instead return None
|
||||||
|
# to allow the cancellation to propagate
|
||||||
|
return None
|
||||||
raise TimeoutError
|
raise TimeoutError
|
||||||
|
|
||||||
self._state = _State.EXIT
|
self._state = _State.EXIT
|
||||||
|
@ -146,6 +146,62 @@ async def test_simple_global_timeout_freeze_with_executor_job(
|
|||||||
await hass.async_add_executor_job(time.sleep, 0.3)
|
await hass.async_add_executor_job(time.sleep, 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_simple_global_timeout_does_not_leak_upward(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test a global timeout does not leak upward."""
|
||||||
|
timeout = TimeoutManager()
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
assert current_task is not None
|
||||||
|
cancelling_inside_timeout = None
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
async with timeout.async_timeout(0.1):
|
||||||
|
cancelling_inside_timeout = current_task.cancelling()
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
assert cancelling_inside_timeout == 0
|
||||||
|
# After the context manager exits, the task should no longer be cancelling
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_simple_global_timeout_does_swallow_cancellation(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test a global timeout does not swallow cancellation."""
|
||||||
|
timeout = TimeoutManager()
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
assert current_task is not None
|
||||||
|
cancelling_inside_timeout = None
|
||||||
|
|
||||||
|
async def task_with_timeout() -> None:
|
||||||
|
nonlocal cancelling_inside_timeout
|
||||||
|
new_task = asyncio.current_task()
|
||||||
|
assert new_task is not None
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
cancelling_inside_timeout = new_task.cancelling()
|
||||||
|
async with timeout.async_timeout(0.1):
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
# After the context manager exits, the task should no longer be cancelling
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
|
||||||
|
task = asyncio.create_task(task_with_timeout())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
assert task.cancelling() == 1
|
||||||
|
|
||||||
|
assert cancelling_inside_timeout == 0
|
||||||
|
# Cancellation should not leak into the current task
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
# Cancellation should not be swallowed if the task is cancelled
|
||||||
|
# and it also times out
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
assert task.cancelling() == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_simple_global_timeout_freeze_reset() -> None:
|
async def test_simple_global_timeout_freeze_reset() -> None:
|
||||||
"""Test a simple global timeout freeze reset."""
|
"""Test a simple global timeout freeze reset."""
|
||||||
timeout = TimeoutManager()
|
timeout = TimeoutManager()
|
||||||
@ -166,6 +222,62 @@ async def test_simple_zone_timeout() -> None:
|
|||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_simple_zone_timeout_does_not_leak_upward(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test a zone timeout does not leak upward."""
|
||||||
|
timeout = TimeoutManager()
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
assert current_task is not None
|
||||||
|
cancelling_inside_timeout = None
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
async with timeout.async_timeout(0.1, "test"):
|
||||||
|
cancelling_inside_timeout = current_task.cancelling()
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
assert cancelling_inside_timeout == 0
|
||||||
|
# After the context manager exits, the task should no longer be cancelling
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_simple_zone_timeout_does_swallow_cancellation(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test a zone timeout does not swallow cancellation."""
|
||||||
|
timeout = TimeoutManager()
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
assert current_task is not None
|
||||||
|
cancelling_inside_timeout = None
|
||||||
|
|
||||||
|
async def task_with_timeout() -> None:
|
||||||
|
nonlocal cancelling_inside_timeout
|
||||||
|
new_task = asyncio.current_task()
|
||||||
|
assert new_task is not None
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
async with timeout.async_timeout(0.1, "test"):
|
||||||
|
cancelling_inside_timeout = current_task.cancelling()
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
# After the context manager exits, the task should no longer be cancelling
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
|
||||||
|
task = asyncio.create_task(task_with_timeout())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
assert task.cancelling() == 1
|
||||||
|
|
||||||
|
# Cancellation should not leak into the current task
|
||||||
|
assert cancelling_inside_timeout == 0
|
||||||
|
assert current_task.cancelling() == 0
|
||||||
|
# Cancellation should not be swallowed if the task is cancelled
|
||||||
|
# and it also times out
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
assert task.cancelling() == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_multiple_zone_timeout() -> None:
|
async def test_multiple_zone_timeout() -> None:
|
||||||
"""Test a simple zone timeout."""
|
"""Test a simple zone timeout."""
|
||||||
timeout = TimeoutManager()
|
timeout = TimeoutManager()
|
||||||
@ -327,7 +439,7 @@ async def test_simple_zone_timeout_freeze_without_timeout_exeption() -> None:
|
|||||||
await asyncio.sleep(0.4)
|
await asyncio.sleep(0.4)
|
||||||
|
|
||||||
|
|
||||||
async def test_simple_zone_timeout_zone_with_timeout_exeption() -> None:
|
async def test_simple_zone_timeout_zone_with_timeout_exception() -> None:
|
||||||
"""Test a simple zone timeout freeze on a zone that does not have a timeout set."""
|
"""Test a simple zone timeout freeze on a zone that does not have a timeout set."""
|
||||||
timeout = TimeoutManager()
|
timeout = TimeoutManager()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user