mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 23:27:37 +00:00
Cancel running pipeline on new pipeline or announcement (#125687)
* Cancel running pipeline * Incorporate feedback * Change to async_create_task
This commit is contained in:
parent
c01bdd860a
commit
8e0b2b752c
@ -3,6 +3,7 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
|
import contextlib
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@ -73,6 +74,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
_is_announcing = False
|
_is_announcing = False
|
||||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
_attr_tts_options: dict[str, Any] | None = None
|
_attr_tts_options: dict[str, Any] | None = None
|
||||||
|
_pipeline_task: asyncio.Task | None = None
|
||||||
|
|
||||||
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
@ -131,6 +133,8 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
|
|
||||||
Calls async_announce with message and media id.
|
Calls async_announce with message and media id.
|
||||||
"""
|
"""
|
||||||
|
await self._cancel_running_pipeline()
|
||||||
|
|
||||||
if message is None:
|
if message is None:
|
||||||
message = ""
|
message = ""
|
||||||
|
|
||||||
@ -176,7 +180,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
await self.async_announce(message, media_id)
|
await self.async_announce(message, media_id)
|
||||||
finally:
|
finally:
|
||||||
self._is_announcing = False
|
self._is_announcing = False
|
||||||
self.tts_response_finished()
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
async def async_announce(self, message: str, media_id: str) -> None:
|
async def async_announce(self, message: str, media_id: str) -> None:
|
||||||
"""Announce media on the satellite.
|
"""Announce media on the satellite.
|
||||||
@ -193,6 +197,8 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
wake_word_phrase: str | None = None,
|
wake_word_phrase: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||||
|
await self._cancel_running_pipeline()
|
||||||
|
|
||||||
if self._wake_word_intercept_future and start_stage in (
|
if self._wake_word_intercept_future and start_stage in (
|
||||||
PipelineStage.WAKE_WORD,
|
PipelineStage.WAKE_WORD,
|
||||||
PipelineStage.STT,
|
PipelineStage.STT,
|
||||||
@ -248,7 +254,10 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
# Set entity state based on pipeline events
|
# Set entity state based on pipeline events
|
||||||
self._run_has_tts = False
|
self._run_has_tts = False
|
||||||
|
|
||||||
await async_pipeline_from_audio_stream(
|
assert self.platform.config_entry is not None
|
||||||
|
self._pipeline_task = self.platform.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
async_pipeline_from_audio_stream(
|
||||||
self.hass,
|
self.hass,
|
||||||
context=self._context,
|
context=self._context,
|
||||||
event_callback=self._internal_on_pipeline_event,
|
event_callback=self._internal_on_pipeline_event,
|
||||||
@ -271,8 +280,24 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
),
|
),
|
||||||
start_stage=start_stage,
|
start_stage=start_stage,
|
||||||
end_stage=end_stage,
|
end_stage=end_stage,
|
||||||
|
),
|
||||||
|
f"{self.entity_id}_pipeline",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._pipeline_task
|
||||||
|
finally:
|
||||||
|
self._pipeline_task = None
|
||||||
|
|
||||||
|
async def _cancel_running_pipeline(self) -> None:
|
||||||
|
"""Cancel the current pipeline if it's running."""
|
||||||
|
if self._pipeline_task is not None:
|
||||||
|
self._pipeline_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._pipeline_task
|
||||||
|
|
||||||
|
self._pipeline_task = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
"""Handle pipeline events."""
|
"""Handle pipeline events."""
|
||||||
|
@ -93,6 +93,55 @@ async def test_entity_state(
|
|||||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
|
||||||
|
async def test_new_pipeline_cancels_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a new pipeline run cancels any running pipeline."""
|
||||||
|
pipeline1_started = asyncio.Event()
|
||||||
|
pipeline1_finished = asyncio.Event()
|
||||||
|
pipeline1_cancelled = asyncio.Event()
|
||||||
|
pipeline2_finished = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
if not pipeline1_started.is_set():
|
||||||
|
# First pipeline run
|
||||||
|
pipeline1_started.set()
|
||||||
|
|
||||||
|
# Wait for pipeline to be cancelled
|
||||||
|
try:
|
||||||
|
await pipeline1_finished.wait()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pipeline1_cancelled.set()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Second pipeline run
|
||||||
|
pipeline2_finished.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
hass.async_create_task(
|
||||||
|
entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await pipeline1_started.wait()
|
||||||
|
|
||||||
|
# Start a second pipeline
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
await pipeline1_cancelled.wait()
|
||||||
|
await pipeline2_finished.wait()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("service_data", "expected_params"),
|
("service_data", "expected_params"),
|
||||||
[
|
[
|
||||||
@ -210,6 +259,48 @@ async def test_announce_busy(
|
|||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce_cancels_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
) -> None:
|
||||||
|
"""Test that announcements cancel any running pipeline."""
|
||||||
|
media_id = "https://www.home-assistant.io/resolved.mp3"
|
||||||
|
pipeline_started = asyncio.Event()
|
||||||
|
pipeline_finished = asyncio.Event()
|
||||||
|
pipeline_cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
pipeline_started.set()
|
||||||
|
|
||||||
|
# Wait for pipeline to be cancelled
|
||||||
|
try:
|
||||||
|
await pipeline_finished.wait()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pipeline_cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch.object(entity, "async_announce") as mock_async_announce,
|
||||||
|
):
|
||||||
|
hass.async_create_task(
|
||||||
|
entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await pipeline_started.wait()
|
||||||
|
await entity.async_internal_announce(None, media_id)
|
||||||
|
await pipeline_cancelled.wait()
|
||||||
|
|
||||||
|
mock_async_announce.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
async def test_context_refresh(
|
async def test_context_refresh(
|
||||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user