Cancel running pipeline on new pipeline or announcement (#125687)

* Cancel running pipeline

* Incorporate feedback

* Change to async_create_task
This commit is contained in:
Michael Hansen 2024-09-10 19:56:15 -05:00 committed by GitHub
parent c01bdd860a
commit 8e0b2b752c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 22 deletions

View File

@ -3,6 +3,7 @@
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import contextlib
from enum import StrEnum from enum import StrEnum
import logging import logging
import time import time
@ -73,6 +74,7 @@ class AssistSatelliteEntity(entity.Entity):
_is_announcing = False _is_announcing = False
_wake_word_intercept_future: asyncio.Future[str | None] | None = None _wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None _attr_tts_options: dict[str, Any] | None = None
_pipeline_task: asyncio.Task | None = None
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD __assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
@ -131,6 +133,8 @@ class AssistSatelliteEntity(entity.Entity):
Calls async_announce with message and media id. Calls async_announce with message and media id.
""" """
await self._cancel_running_pipeline()
if message is None: if message is None:
message = "" message = ""
@ -176,7 +180,7 @@ class AssistSatelliteEntity(entity.Entity):
await self.async_announce(message, media_id) await self.async_announce(message, media_id)
finally: finally:
self._is_announcing = False self._is_announcing = False
self.tts_response_finished() self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
async def async_announce(self, message: str, media_id: str) -> None: async def async_announce(self, message: str, media_id: str) -> None:
"""Announce media on the satellite. """Announce media on the satellite.
@ -193,6 +197,8 @@ class AssistSatelliteEntity(entity.Entity):
wake_word_phrase: str | None = None, wake_word_phrase: str | None = None,
) -> None: ) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite.""" """Triggers an Assist pipeline in Home Assistant from a satellite."""
await self._cancel_running_pipeline()
if self._wake_word_intercept_future and start_stage in ( if self._wake_word_intercept_future and start_stage in (
PipelineStage.WAKE_WORD, PipelineStage.WAKE_WORD,
PipelineStage.STT, PipelineStage.STT,
@ -248,7 +254,10 @@ class AssistSatelliteEntity(entity.Entity):
# Set entity state based on pipeline events # Set entity state based on pipeline events
self._run_has_tts = False self._run_has_tts = False
await async_pipeline_from_audio_stream( assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass, self.hass,
context=self._context, context=self._context,
event_callback=self._internal_on_pipeline_event, event_callback=self._internal_on_pipeline_event,
@ -271,8 +280,24 @@ class AssistSatelliteEntity(entity.Entity):
), ),
start_stage=start_stage, start_stage=start_stage,
end_stage=end_stage, end_stage=end_stage,
),
f"{self.entity_id}_pipeline",
) )
try:
await self._pipeline_task
finally:
self._pipeline_task = None
async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running."""
if self._pipeline_task is not None:
self._pipeline_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._pipeline_task
self._pipeline_task = None
@abstractmethod @abstractmethod
def on_pipeline_event(self, event: PipelineEvent) -> None: def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events.""" """Handle pipeline events."""

View File

@ -93,6 +93,55 @@ async def test_entity_state(
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_new_pipeline_cancels_pipeline(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test that a new pipeline run cancels any running pipeline."""
pipeline1_started = asyncio.Event()
pipeline1_finished = asyncio.Event()
pipeline1_cancelled = asyncio.Event()
pipeline2_finished = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
if not pipeline1_started.is_set():
# First pipeline run
pipeline1_started.set()
# Wait for pipeline to be cancelled
try:
await pipeline1_finished.wait()
except asyncio.CancelledError:
pipeline1_cancelled.set()
raise
else:
# Second pipeline run
pipeline2_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
):
hass.async_create_task(
entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
)
async with asyncio.timeout(1):
await pipeline1_started.wait()
# Start a second pipeline
await entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
await pipeline1_cancelled.wait()
await pipeline2_finished.wait()
@pytest.mark.parametrize( @pytest.mark.parametrize(
("service_data", "expected_params"), ("service_data", "expected_params"),
[ [
@ -210,6 +259,48 @@ async def test_announce_busy(
await announce_task await announce_task
async def test_announce_cancels_pipeline(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test that announcements cancel any running pipeline."""
media_id = "https://www.home-assistant.io/resolved.mp3"
pipeline_started = asyncio.Event()
pipeline_finished = asyncio.Event()
pipeline_cancelled = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
pipeline_started.set()
# Wait for pipeline to be cancelled
try:
await pipeline_finished.wait()
except asyncio.CancelledError:
pipeline_cancelled.set()
raise
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(entity, "async_announce") as mock_async_announce,
):
hass.async_create_task(
entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
)
async with asyncio.timeout(1):
await pipeline_started.wait()
await entity.async_internal_announce(None, media_id)
await pipeline_cancelled.wait()
mock_async_announce.assert_called_once()
async def test_context_refresh( async def test_context_refresh(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None: ) -> None: