mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Fix pipeline restart in VoIP (#126668)
This commit is contained in:
parent
739165585a
commit
86f8901c96
@ -14,11 +14,7 @@ import wave
|
|||||||
from voip_utils import RtpDatagramProtocol
|
from voip_utils import RtpDatagramProtocol
|
||||||
|
|
||||||
from homeassistant.components import tts
|
from homeassistant.components import tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||||
PipelineEvent,
|
|
||||||
PipelineEventType,
|
|
||||||
PipelineNotFound,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
@ -31,7 +27,6 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||||
from .devices import VoIPDevice
|
from .devices import VoIPDevice
|
||||||
from .entity import VoIPEntity
|
from .entity import VoIPEntity
|
||||||
from .util import queue_to_iterable
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import DomainData
|
from . import DomainData
|
||||||
@ -101,9 +96,9 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
|
|
||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
|
|
||||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
self._audio_chunk_timeout: float = 2.0
|
self._audio_chunk_timeout: float = 2.0
|
||||||
self._pipeline_task: asyncio.Task | None = None
|
self._run_pipeline_task: asyncio.Task | None = None
|
||||||
self._pipeline_had_error: bool = False
|
self._pipeline_had_error: bool = False
|
||||||
self._tts_done = asyncio.Event()
|
self._tts_done = asyncio.Event()
|
||||||
self._tts_extra_timeout: float = 1.0
|
self._tts_extra_timeout: float = 1.0
|
||||||
@ -161,11 +156,11 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
|
|
||||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||||
"""Handle raw audio chunk."""
|
"""Handle raw audio chunk."""
|
||||||
if self._pipeline_task is None:
|
if self._run_pipeline_task is None:
|
||||||
self._clear_audio_queue()
|
|
||||||
|
|
||||||
# Run pipeline until voice command finishes, then start over
|
# Run pipeline until voice command finishes, then start over
|
||||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
self._clear_audio_queue()
|
||||||
|
self._tts_done.clear()
|
||||||
|
self._run_pipeline_task = self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._run_pipeline(),
|
self._run_pipeline(),
|
||||||
"voip_pipeline_run",
|
"voip_pipeline_run",
|
||||||
@ -173,27 +168,28 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
|
|
||||||
self._audio_queue.put_nowait(audio_bytes)
|
self._audio_queue.put_nowait(audio_bytes)
|
||||||
|
|
||||||
async def _run_pipeline(
|
async def _run_pipeline(self) -> None:
|
||||||
self,
|
_LOGGER.debug("Starting pipeline")
|
||||||
) -> None:
|
|
||||||
"""Forward audio to pipeline STT and handle TTS."""
|
|
||||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||||
self.voip_device.set_is_active(True)
|
self.voip_device.set_is_active(True)
|
||||||
|
|
||||||
|
async def stt_stream():
|
||||||
|
while True:
|
||||||
|
async with asyncio.timeout(self._audio_chunk_timeout):
|
||||||
|
chunk = await self._audio_queue.get()
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
# Play listening tone at the start of each cycle
|
# Play listening tone at the start of each cycle
|
||||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._tts_done.clear()
|
await self.async_accept_pipeline_from_satellite(
|
||||||
|
audio_stream=stt_stream(),
|
||||||
# Run pipeline with a timeout
|
)
|
||||||
_LOGGER.debug("Starting pipeline")
|
|
||||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
|
||||||
await self.async_accept_pipeline_from_satellite(
|
|
||||||
audio_stream=queue_to_iterable(
|
|
||||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._pipeline_had_error:
|
if self._pipeline_had_error:
|
||||||
self._pipeline_had_error = False
|
self._pipeline_had_error = False
|
||||||
@ -204,20 +200,15 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
# This is set in _send_tts and has a timeout that's based on the
|
# This is set in _send_tts and has a timeout that's based on the
|
||||||
# length of the TTS audio.
|
# length of the TTS audio.
|
||||||
await self._tts_done.wait()
|
await self._tts_done.wait()
|
||||||
|
except TimeoutError:
|
||||||
_LOGGER.debug("Pipeline finished")
|
self.disconnect() # caller hung up
|
||||||
except PipelineNotFound:
|
|
||||||
_LOGGER.warning("Pipeline not found")
|
|
||||||
except (asyncio.CancelledError, TimeoutError):
|
|
||||||
# Expected after caller hangs up
|
|
||||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
|
||||||
self.disconnect()
|
|
||||||
self._clear_audio_queue()
|
|
||||||
finally:
|
finally:
|
||||||
self.voip_device.set_is_active(False)
|
# Stop audio stream
|
||||||
|
await self._audio_queue.put(None)
|
||||||
|
|
||||||
# Allow pipeline to run again
|
self.voip_device.set_is_active(False)
|
||||||
self._pipeline_task = None
|
self._run_pipeline_task = None
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
|
||||||
def _clear_audio_queue(self) -> None:
|
def _clear_audio_queue(self) -> None:
|
||||||
"""Ensure audio queue is empty."""
|
"""Ensure audio queue is empty."""
|
||||||
@ -247,6 +238,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
elif event.type == PipelineEventType.ERROR:
|
elif event.type == PipelineEventType.ERROR:
|
||||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||||
self._pipeline_had_error = True
|
self._pipeline_had_error = True
|
||||||
|
_LOGGER.warning(event)
|
||||||
|
|
||||||
async def _send_tts(self, media_id: str) -> None:
|
async def _send_tts(self, media_id: str) -> None:
|
||||||
"""Send TTS audio to caller via RTP."""
|
"""Send TTS audio to caller via RTP."""
|
||||||
@ -264,6 +256,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
|
|
||||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||||
# Don't overlap TTS and processing beep
|
# Don't overlap TTS and processing beep
|
||||||
|
_LOGGER.debug("Waiting for processing tone")
|
||||||
await self._processing_tone_done.wait()
|
await self._processing_tone_done.wait()
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io:
|
with io.BytesIO(data) as wav_io:
|
||||||
@ -297,12 +290,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
_LOGGER.warning("TTS timeout")
|
_LOGGER.warning("TTS timeout")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Signal pipeline to restart
|
|
||||||
self._tts_done.set()
|
|
||||||
|
|
||||||
# Update satellite state
|
# Update satellite state
|
||||||
self.tts_response_finished()
|
self.tts_response_finished()
|
||||||
|
|
||||||
|
# Signal pipeline to restart
|
||||||
|
self._tts_done.set()
|
||||||
|
|
||||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||||
"""Send audio in executor."""
|
"""Send audio in executor."""
|
||||||
await self.hass.async_add_executor_job(
|
await self.hass.async_add_executor_job(
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
"""Voip util functions."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from asyncio import Queue, timeout as async_timeout
|
|
||||||
from collections.abc import AsyncIterable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from typing_extensions import TypeVar
|
|
||||||
|
|
||||||
_DataT = TypeVar("_DataT", default=Any)
|
|
||||||
|
|
||||||
|
|
||||||
async def queue_to_iterable(
|
|
||||||
queue: Queue[_DataT], timeout: float | None = None
|
|
||||||
) -> AsyncIterable[_DataT]:
|
|
||||||
"""Stream items from a queue until None with an optional timeout per item."""
|
|
||||||
if timeout is None:
|
|
||||||
while (item := await queue.get()) is not None:
|
|
||||||
yield item
|
|
||||||
else:
|
|
||||||
async with async_timeout(timeout):
|
|
||||||
item = await queue.get()
|
|
||||||
|
|
||||||
while item is not None:
|
|
||||||
yield item
|
|
||||||
async with async_timeout(timeout):
|
|
||||||
item = await queue.get()
|
|
@ -1,47 +0,0 @@
|
|||||||
"""Test VoIP utils."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from homeassistant.components.voip.util import queue_to_iterable
|
|
||||||
|
|
||||||
|
|
||||||
async def test_queue_to_iterable() -> None:
|
|
||||||
"""Test queue_to_iterable."""
|
|
||||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
|
||||||
expected_items = list(range(10))
|
|
||||||
|
|
||||||
for i in expected_items:
|
|
||||||
await queue.put(i)
|
|
||||||
|
|
||||||
# Will terminate the stream
|
|
||||||
await queue.put(None)
|
|
||||||
|
|
||||||
actual_items = [item async for item in queue_to_iterable(queue)]
|
|
||||||
|
|
||||||
assert expected_items == actual_items
|
|
||||||
|
|
||||||
# Check timeout
|
|
||||||
assert queue.empty()
|
|
||||||
|
|
||||||
# Time out on first item
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
|
||||||
# Should time out very quickly
|
|
||||||
async for _item in queue_to_iterable(queue, timeout=0.01):
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
# Check timeout on second item
|
|
||||||
assert queue.empty()
|
|
||||||
await queue.put(12345)
|
|
||||||
|
|
||||||
# Time out on second item
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
|
||||||
# Should time out very quickly
|
|
||||||
async for item in queue_to_iterable(queue, timeout=0.01):
|
|
||||||
if item != 12345:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
assert queue.empty()
|
|
Loading…
x
Reference in New Issue
Block a user