Add TTS response timeout for idle state (#146984)

* Add TTS response timeout for idle state

* Consider time spent sending TTS audio in timeout
This commit is contained in:
Michael Hansen 2025-06-17 08:39:18 -05:00 committed by GitHub
parent 79cc3bffc6
commit 3b611b9b03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 168 additions and 27 deletions

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import io import io
import logging import logging
import time
from typing import Any, Final from typing import Any, Final
import wave import wave
@ -36,6 +37,7 @@ from homeassistant.components.assist_satellite import (
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.ulid import ulid_now
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
from .data import WyomingService from .data import WyomingService
@ -53,6 +55,7 @@ _PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1 _PIPELINE_FINISH_TIMEOUT: Final = 1
_TTS_SAMPLE_RATE: Final = 22050 _TTS_SAMPLE_RATE: Final = 22050
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples _ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
_TTS_TIMEOUT_EXTRA: Final = 1.0
# Wyoming stage -> Assist stage # Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -125,6 +128,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
self._played_event_received: asyncio.Event | None = None self._played_event_received: asyncio.Event | None = None
# Randomly set on each pipeline loop run.
# Used to ensure TTS timeout is acted on correctly.
self._run_loop_id: str | None = None
@property @property
def pipeline_entity_id(self) -> str | None: def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation.""" """Return the entity ID of the pipeline to use for the next conversation."""
@ -511,6 +518,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
wake_word_phrase: str | None = None wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None run_pipeline: RunPipeline | None = None
send_ping = True send_ping = True
self._run_loop_id = ulid_now()
# Read events and check for pipeline end in parallel # Read events and check for pipeline end in parallel
pipeline_ended_task = self.config_entry.async_create_background_task( pipeline_ended_task = self.config_entry.async_create_background_task(
@ -698,38 +706,52 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
f"Cannot stream audio format to satellite: {tts_result.extension}" f"Cannot stream audio format to satellite: {tts_result.extension}"
) )
data = b"".join([chunk async for chunk in tts_result.async_stream_result()]) # Track the total duration of TTS audio for response timeout
total_seconds = 0.0
start_time = time.monotonic()
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: try:
sample_rate = wav_file.getframerate() data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
timestamp = 0 with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
await self._client.write_event( sample_rate = wav_file.getframerate()
AudioStart( sample_width = wav_file.getsampwidth()
rate=sample_rate, sample_channels = wav_file.getnchannels()
width=sample_width, _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
channels=sample_channels,
timestamp=timestamp,
).event()
)
# Stream audio chunks timestamp = 0
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK): await self._client.write_event(
chunk = AudioChunk( AudioStart(
rate=sample_rate, rate=sample_rate,
width=sample_width, width=sample_width,
channels=sample_channels, channels=sample_channels,
audio=audio_bytes, timestamp=timestamp,
timestamp=timestamp, ).event()
) )
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event()) # Stream audio chunks
_LOGGER.debug("TTS streaming complete") while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
total_seconds += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
finally:
send_duration = time.monotonic() - start_time
timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA)
self.config_entry.async_create_background_task(
self.hass,
self._tts_timeout(timeout_seconds, self._run_loop_id),
name="wyoming TTS timeout",
)
async def _stt_stream(self) -> AsyncGenerator[bytes]: async def _stt_stream(self) -> AsyncGenerator[bytes]:
"""Yield audio chunks from a queue.""" """Yield audio chunks from a queue."""
@ -744,6 +766,18 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
yield chunk yield chunk
async def _tts_timeout(
    self, timeout_seconds: float, run_loop_id: str | None
) -> None:
    """Force state change to IDLE in case TTS played event isn't received.

    Sleeps for ``timeout_seconds`` plus ``_TTS_TIMEOUT_EXTRA`` and then calls
    ``tts_response_finished()``, but only if ``run_loop_id`` still equals the
    satellite's current ``_run_loop_id``.

    Args:
        timeout_seconds: Base delay, in seconds, before the fallback fires.
        run_loop_id: Value of ``_run_loop_id`` captured when this timeout
            was scheduled.
    """
    # NOTE(review): the caller visible in this file already adds
    # _TTS_TIMEOUT_EXTRA when computing timeout_seconds, so the margin
    # appears to be applied twice -- confirm this is intended.
    delay = timeout_seconds + _TTS_TIMEOUT_EXTRA
    await asyncio.sleep(delay)

    # Only act if the satellite is still on the pipeline loop run that
    # scheduled this timeout; otherwise the timeout is stale.
    still_same_run = run_loop_id == self._run_loop_id
    if still_same_run:
        self.tts_response_finished()
@callback @callback
def _handle_timer( def _handle_timer(
self, event_type: intent.TimerEventType, timer: intent.TimerInfo self, event_type: intent.TimerEventType, timer: intent.TimerInfo

View File

@ -1365,3 +1365,110 @@ async def test_announce(
# Stop the satellite # Stop the satellite
await hass.config_entries.async_unload(entry.entry_id) await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_tts_timeout(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test entity state goes back to IDLE on a timeout.

    The satellite client is replaced with ``SatelliteAsyncTcpClient(events)``
    and the pipeline entry point with a stub that captures the pipeline event
    callback. ``WyomingAssistSatellite.tts_response_finished`` is replaced
    with a stub that sets an event, and ``_TTS_TIMEOUT_EXTRA`` and
    ``_PING_SEND_DELAY`` are patched to 0. After TTS_START and TTS_END
    pipeline events are delivered, the test waits (up to 1 second) for the
    stubbed ``tts_response_finished`` to be called.
    """
    # Events handed to the fake client: satellite info, then a request to run
    # a pipeline that both starts and ends at the TTS stage.
    events = [
        Info(satellite=SATELLITE_INFO.satellite).event(),
        RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
    ]

    # Filled in by the stubbed pipeline function below.
    pipeline_kwargs: dict[str, Any] = {}
    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        # Stand-in for the real pipeline: record the keyword arguments and the
        # event callback, then signal that the pipeline was started.
        nonlocal pipeline_kwargs, pipeline_event_callback
        pipeline_kwargs = kwargs
        pipeline_event_callback = event_callback

        run_pipeline_called.set()

    # Set when the patched tts_response_finished is invoked.
    response_finished = asyncio.Event()

    def tts_response_finished(self):
        # Replacement for WyomingAssistSatellite.tts_response_finished; only
        # records that it was called.
        response_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
        patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.tts_response_finished",
            tts_response_finished,
        ),
        # Remove the extra timeout margin so the test does not wait for it.
        patch(
            "homeassistant.components.wyoming.assist_satellite._TTS_TIMEOUT_EXTRA",
            0,
        ),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
        assert device is not None

        # Look up the assist_satellite entity registered for this device.
        satellite_entry = next(
            (
                maybe_entry
                for maybe_entry in er.async_entries_for_device(
                    entity_registry, device.device_id
                )
                if maybe_entry.domain == assist_satellite.DOMAIN
            ),
            None,
        )
        assert satellite_entry is not None

        # Wait for the stubbed pipeline to be invoked.
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        # Reset the event.
        # NOTE(review): nothing below waits on run_pipeline_called again, so a
        # pipeline restart is not actually asserted in this test.
        run_pipeline_called.clear()

        assert pipeline_event_callback is not None
        assert pipeline_kwargs.get("device_id") == device.device_id

        # Deliver the TTS start event through the captured callback.
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_START,
                {
                    "tts_input": "test text to speak",
                    "voice": "test voice",
                },
            )
        )

        # presumably makes the test WAV available under the stream's token so
        # the satellite can fetch it -- confirm against MockResultStream
        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_END,
                {"tts_output": {"token": mock_tts_result_stream.token}},
            )
        )
        async with asyncio.timeout(1):
            # tts_response_finished should be called on timeout
            await response_finished.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()