From 5a5760216327432aa533eef62c33097c4724d7f5 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 24 Apr 2023 12:27:13 -0500 Subject: [PATCH] Wait for TTS before restarting pipeline (#91962) --- homeassistant/components/voip/voip.py | 36 +++++++++++++++++---------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index 67b1fcda7e5..0d66facec7b 100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -101,6 +101,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): self._context = context self._conversation_id: str | None = None self._pipeline_task: asyncio.Task | None = None + self._tts_done = asyncio.Event() self._session_id: str | None = None self._tone_bytes: bytes | None = None self._processing_bytes: bytes | None = None @@ -152,6 +153,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): return _LOGGER.debug("Starting pipeline") + self._tts_done.clear() async def stt_stream(): try: @@ -193,6 +195,10 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): tts_audio_output="raw", ) + # Block until TTS is done speaking + await self._tts_done.wait() + + _LOGGER.debug("Pipeline finished") except asyncio.TimeoutError: # Expected after caller hangs up _LOGGER.debug("Pipeline timeout") @@ -269,26 +275,30 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): # Send TTS audio to caller over RTP media_id = event.data["tts_output"]["media_id"] self.hass.async_create_background_task( - self._send_media(media_id), + self._send_tts(media_id), "voip_pipeline_tts", ) - async def _send_media(self, media_id: str) -> None: + async def _send_tts(self, media_id: str) -> None: """Send TTS audio to caller via RTP.""" - if self.transport is None: - return + try: + if self.transport is None: + return - _extension, audio_bytes = await tts.async_get_media_source_audio( - self.hass, - media_id, - ) + _extension, audio_bytes = await tts.async_get_media_source_audio( + self.hass, + media_id, + ) - _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes)) + _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes)) - # Assume TTS audio is 16Khz 16-bit mono - await self.hass.async_add_executor_job( - partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS) - ) + # Assume TTS audio is 16Khz 16-bit mono + await self.hass.async_add_executor_job( + partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS) + ) + finally: + # Signal pipeline to restart + self._tts_done.set() async def _play_listening_tone(self) -> None: """Play a tone to indicate that Home Assistant is listening."""