mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Wait for TTS before restarting pipeline (#91962)
This commit is contained in:
parent
36f90cda92
commit
5a57602163
@ -101,6 +101,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
self._context = context
|
self._context = context
|
||||||
self._conversation_id: str | None = None
|
self._conversation_id: str | None = None
|
||||||
self._pipeline_task: asyncio.Task | None = None
|
self._pipeline_task: asyncio.Task | None = None
|
||||||
|
self._tts_done = asyncio.Event()
|
||||||
self._session_id: str | None = None
|
self._session_id: str | None = None
|
||||||
self._tone_bytes: bytes | None = None
|
self._tone_bytes: bytes | None = None
|
||||||
self._processing_bytes: bytes | None = None
|
self._processing_bytes: bytes | None = None
|
||||||
@ -152,6 +153,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
return
|
return
|
||||||
|
|
||||||
_LOGGER.debug("Starting pipeline")
|
_LOGGER.debug("Starting pipeline")
|
||||||
|
self._tts_done.clear()
|
||||||
|
|
||||||
async def stt_stream():
|
async def stt_stream():
|
||||||
try:
|
try:
|
||||||
@ -193,6 +195,10 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
tts_audio_output="raw",
|
tts_audio_output="raw",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Block until TTS is done speaking
|
||||||
|
await self._tts_done.wait()
|
||||||
|
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Expected after caller hangs up
|
# Expected after caller hangs up
|
||||||
_LOGGER.debug("Pipeline timeout")
|
_LOGGER.debug("Pipeline timeout")
|
||||||
@ -269,26 +275,30 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
# Send TTS audio to caller over RTP
|
# Send TTS audio to caller over RTP
|
||||||
media_id = event.data["tts_output"]["media_id"]
|
media_id = event.data["tts_output"]["media_id"]
|
||||||
self.hass.async_create_background_task(
|
self.hass.async_create_background_task(
|
||||||
self._send_media(media_id),
|
self._send_tts(media_id),
|
||||||
"voip_pipeline_tts",
|
"voip_pipeline_tts",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_media(self, media_id: str) -> None:
|
async def _send_tts(self, media_id: str) -> None:
|
||||||
"""Send TTS audio to caller via RTP."""
|
"""Send TTS audio to caller via RTP."""
|
||||||
if self.transport is None:
|
try:
|
||||||
return
|
if self.transport is None:
|
||||||
|
return
|
||||||
|
|
||||||
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
||||||
self.hass,
|
self.hass,
|
||||||
media_id,
|
media_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||||
|
|
||||||
# Assume TTS audio is 16Khz 16-bit mono
|
# Assume TTS audio is 16Khz 16-bit mono
|
||||||
await self.hass.async_add_executor_job(
|
await self.hass.async_add_executor_job(
|
||||||
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
|
partial(self.send_audio, audio_bytes, **_RTP_AUDIO_SETTINGS)
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
# Signal pipeline to restart
|
||||||
|
self._tts_done.set()
|
||||||
|
|
||||||
async def _play_listening_tone(self) -> None:
|
async def _play_listening_tone(self) -> None:
|
||||||
"""Play a tone to indicate that Home Assistant is listening."""
|
"""Play a tone to indicate that Home Assistant is listening."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user