mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Wait for TTS before restarting pipeline (#91962)
This commit is contained in:
parent
36f90cda92
commit
5a57602163
@ -101,6 +101,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
self._context = context
|
||||
self._conversation_id: str | None = None
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._tts_done = asyncio.Event()
|
||||
self._session_id: str | None = None
|
||||
self._tone_bytes: bytes | None = None
|
||||
self._processing_bytes: bytes | None = None
|
||||
@ -152,6 +153,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
return
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
self._tts_done.clear()
|
||||
|
||||
async def stt_stream():
|
||||
try:
|
||||
@ -193,6 +195,10 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
tts_audio_output="raw",
|
||||
)
|
||||
|
||||
# Block until TTS is done speaking
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except asyncio.TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
@ -269,26 +275,30 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
# Send TTS audio to caller over RTP
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_media(media_id),
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
|
||||
async def _send_media(self, media_id: str) -> None:
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
if self.transport is None:
|
||||
return
|
||||
try:
|
||||
if self.transport is None:
|
||||
return
|
||||
|
||||
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Assume TTS audio is 16Khz 16-bit mono
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
|
||||
)
|
||||
# Assume TTS audio is 16Khz 16-bit mono
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
|
||||
)
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _play_listening_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is listening."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user