Wait for TTS before restarting pipeline (#91962)

This commit is contained in:
Michael Hansen 2023-04-24 12:27:13 -05:00 committed by GitHub
parent 36f90cda92
commit 5a57602163
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -101,6 +101,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self._context = context
self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
self._tts_done = asyncio.Event()
self._session_id: str | None = None
self._tone_bytes: bytes | None = None
self._processing_bytes: bytes | None = None
@@ -152,6 +153,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
return
_LOGGER.debug("Starting pipeline")
self._tts_done.clear()
async def stt_stream():
try:
@@ -193,6 +195,10 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
tts_audio_output="raw",
)
# Block until TTS is done speaking
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except asyncio.TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Pipeline timeout")
@@ -269,26 +275,30 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
# Send TTS audio to caller over RTP
media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task(
self._send_media(media_id),
self._send_tts(media_id),
"voip_pipeline_tts",
)
async def _send_media(self, media_id: str) -> None:
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to caller via RTP."""
if self.transport is None:
return
try:
if self.transport is None:
return
_extension, audio_bytes = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
_extension, audio_bytes = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Assume TTS audio is 16Khz 16-bit mono
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
)
# Assume TTS audio is 16Khz 16-bit mono
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
)
finally:
# Signal pipeline to restart
self._tts_done.set()
async def _play_listening_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is listening."""