mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Add TTS response timeout for idle state (#146984)
* Add TTS response timeout for idle state * Consider time spent sending TTS audio in timeout
This commit is contained in:
parent
79cc3bffc6
commit
3b611b9b03
@ -6,6 +6,7 @@ import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final
|
||||
import wave
|
||||
|
||||
@ -36,6 +37,7 @@ from homeassistant.components.assist_satellite import (
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
|
||||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
|
||||
from .data import WyomingService
|
||||
@ -53,6 +55,7 @@ _PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
_TTS_SAMPLE_RATE: Final = 22050
|
||||
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||
_TTS_TIMEOUT_EXTRA: Final = 1.0
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
@ -125,6 +128,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
|
||||
self._played_event_received: asyncio.Event | None = None
|
||||
|
||||
# Randomly set on each pipeline loop run.
|
||||
# Used to ensure TTS timeout is acted on correctly.
|
||||
self._run_loop_id: str | None = None
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
@ -511,6 +518,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
wake_word_phrase: str | None = None
|
||||
run_pipeline: RunPipeline | None = None
|
||||
send_ping = True
|
||||
self._run_loop_id = ulid_now()
|
||||
|
||||
# Read events and check for pipeline end in parallel
|
||||
pipeline_ended_task = self.config_entry.async_create_background_task(
|
||||
@ -698,38 +706,52 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
f"Cannot stream audio format to satellite: {tts_result.extension}"
|
||||
)
|
||||
|
||||
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||
# Track the total duration of TTS audio for response timeout
|
||||
total_seconds = 0.0
|
||||
start_time = time.monotonic()
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||
try:
|
||||
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||
|
||||
timestamp = 0
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
timestamp=timestamp,
|
||||
).event()
|
||||
)
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||
|
||||
# Stream audio chunks
|
||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||
chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=audio_bytes,
|
||||
timestamp=timestamp,
|
||||
timestamp = 0
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
timestamp=timestamp,
|
||||
).event()
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
timestamp += chunk.seconds
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
# Stream audio chunks
|
||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||
chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=audio_bytes,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
timestamp += chunk.seconds
|
||||
total_seconds += chunk.seconds
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
finally:
|
||||
send_duration = time.monotonic() - start_time
|
||||
timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA)
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._tts_timeout(timeout_seconds, self._run_loop_id),
|
||||
name="wyoming TTS timeout",
|
||||
)
|
||||
|
||||
async def _stt_stream(self) -> AsyncGenerator[bytes]:
|
||||
"""Yield audio chunks from a queue."""
|
||||
@ -744,6 +766,18 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
|
||||
yield chunk
|
||||
|
||||
async def _tts_timeout(
|
||||
self, timeout_seconds: float, run_loop_id: str | None
|
||||
) -> None:
|
||||
"""Force state change to IDLE in case TTS played event isn't received."""
|
||||
await asyncio.sleep(timeout_seconds + _TTS_TIMEOUT_EXTRA)
|
||||
|
||||
if run_loop_id != self._run_loop_id:
|
||||
# On a different pipeline run now
|
||||
return
|
||||
|
||||
self.tts_response_finished()
|
||||
|
||||
@callback
|
||||
def _handle_timer(
|
||||
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
|
||||
|
@ -1365,3 +1365,110 @@ async def test_announce(
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_tts_timeout(
|
||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||
) -> None:
|
||||
"""Test entity state goes back to IDLE on a timeout."""
|
||||
events = [
|
||||
Info(satellite=SATELLITE_INFO.satellite).event(),
|
||||
RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
|
||||
]
|
||||
|
||||
pipeline_kwargs: dict[str, Any] = {}
|
||||
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
|
||||
None
|
||||
)
|
||||
run_pipeline_called = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||
pipeline_kwargs = kwargs
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
|
||||
response_finished = asyncio.Event()
|
||||
|
||||
def tts_response_finished(self):
|
||||
response_finished.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.tts_response_finished",
|
||||
tts_response_finished,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite._TTS_TIMEOUT_EXTRA",
|
||||
0,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
assert device is not None
|
||||
|
||||
satellite_entry = next(
|
||||
(
|
||||
maybe_entry
|
||||
for maybe_entry in er.async_entries_for_device(
|
||||
entity_registry, device.device_id
|
||||
)
|
||||
if maybe_entry.domain == assist_satellite.DOMAIN
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert satellite_entry is not None
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
# Reset so we can check the pipeline is automatically restarted below
|
||||
run_pipeline_called.clear()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
assert pipeline_kwargs.get("device_id") == device.device_id
|
||||
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_START,
|
||||
{
|
||||
"tts_input": "test text to speak",
|
||||
"voice": "test voice",
|
||||
},
|
||||
)
|
||||
)
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_END,
|
||||
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
# tts_response_finished should be called on timeout
|
||||
await response_finished.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
Loading…
x
Reference in New Issue
Block a user