Add TTS response timeout for idle state (#146984)

* Add TTS response timeout for idle state

* Consider time spent sending TTS audio in timeout
This commit is contained in:
Michael Hansen 2025-06-17 08:39:18 -05:00 committed by GitHub
parent 79cc3bffc6
commit 3b611b9b03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 168 additions and 27 deletions

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import AsyncGenerator
import io
import logging
import time
from typing import Any, Final
import wave
@ -36,6 +37,7 @@ from homeassistant.components.assist_satellite import (
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.ulid import ulid_now
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
from .data import WyomingService
@ -53,6 +55,7 @@ _PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
_TTS_SAMPLE_RATE: Final = 22050
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
# Extra seconds of slack used when computing the TTS response timeout
_TTS_TIMEOUT_EXTRA: Final = 1.0
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -125,6 +128,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
self._played_event_received: asyncio.Event | None = None
# Randomly set on each pipeline loop run.
# Used to ensure TTS timeout is acted on correctly.
self._run_loop_id: str | None = None
@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
@ -511,6 +518,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None
send_ping = True
self._run_loop_id = ulid_now()
# Read events and check for pipeline end in parallel
pipeline_ended_task = self.config_entry.async_create_background_task(
@ -698,38 +706,52 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
f"Cannot stream audio format to satellite: {tts_result.extension}"
)
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
# Track the total duration of TTS audio for response timeout
total_seconds = 0.0
start_time = time.monotonic()
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
try:
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
timestamp = 0
await self._client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
timestamp=timestamp,
).event()
)
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
# Stream audio chunks
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
timestamp = 0
await self._client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
timestamp=timestamp,
).event()
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
# Stream audio chunks
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
total_seconds += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
finally:
send_duration = time.monotonic() - start_time
timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA)
self.config_entry.async_create_background_task(
self.hass,
self._tts_timeout(timeout_seconds, self._run_loop_id),
name="wyoming TTS timeout",
)
async def _stt_stream(self) -> AsyncGenerator[bytes]:
"""Yield audio chunks from a queue."""
@ -744,6 +766,18 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
yield chunk
async def _tts_timeout(
    self, timeout_seconds: float, run_loop_id: str | None
) -> None:
    """Force state change to IDLE in case TTS played event isn't received.

    Sleeps for ``timeout_seconds`` plus ``_TTS_TIMEOUT_EXTRA`` and then calls
    ``tts_response_finished()``, but only when ``run_loop_id`` still matches
    the id stored on the entity; a mismatch means another pipeline loop run
    has started in the meantime and this timeout is stale.
    """
    # NOTE(review): the visible caller already folds _TTS_TIMEOUT_EXTRA into
    # timeout_seconds, so the slack is effectively applied twice - confirm
    # this is intended.
    delay = timeout_seconds + _TTS_TIMEOUT_EXTRA
    await asyncio.sleep(delay)

    if run_loop_id == self._run_loop_id:
        # Still on the same pipeline loop run, so the timeout applies
        self.tts_response_finished()
@callback
def _handle_timer(
self, event_type: intent.TimerEventType, timer: intent.TimerInfo

View File

@ -1365,3 +1365,110 @@ async def test_announce(
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_tts_timeout(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test that the TTS timeout calls tts_response_finished.

    The satellite is started with a TTS-only pipeline run, TTS start/end
    pipeline events are fed through the captured event callback, and the test
    then waits for the (patched) ``tts_response_finished`` to be invoked
    without any "played" event coming from the satellite.
    """
    satellite_events = [
        Info(satellite=SATELLITE_INFO.satellite).event(),
        RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
    ]

    # State captured from the fake pipeline runner below
    pipeline_kwargs: dict[str, Any] = {}
    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Record the pipeline arguments instead of running a real pipeline."""
        nonlocal pipeline_kwargs, pipeline_event_callback
        pipeline_event_callback = event_callback
        pipeline_kwargs = kwargs
        run_pipeline_called.set()

    response_finished = asyncio.Event()

    def tts_response_finished(self):
        """Signal the test instead of running the entity's implementation."""
        response_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(satellite_events),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
        patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.tts_response_finished",
            tts_response_finished,
        ),
        # Remove the extra slack so the timeout fires within the test's limit
        patch(
            "homeassistant.components.wyoming.assist_satellite._TTS_TIMEOUT_EXTRA",
            0,
        ),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
        assert device is not None

        # Find the assist satellite entity registered for the device
        satellite_entry = None
        for candidate in er.async_entries_for_device(
            entity_registry, device.device_id
        ):
            if candidate.domain == assist_satellite.DOMAIN:
                satellite_entry = candidate
                break
        assert satellite_entry is not None

        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        # The first pipeline run has been observed; clear the flag
        run_pipeline_called.clear()

        assert pipeline_event_callback is not None
        assert pipeline_kwargs.get("device_id") == device.device_id

        tts_start_event = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_START,
            {
                "tts_input": "test text to speak",
                "voice": "test voice",
            },
        )
        pipeline_event_callback(tts_start_event)

        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
        tts_end_event = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_END,
            {"tts_output": {"token": mock_tts_result_stream.token}},
        )
        pipeline_event_callback(tts_end_event)

        async with asyncio.timeout(1):
            # tts_response_finished should be called on timeout
            await response_finished.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()