mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Add TTS response timeout for idle state (#146984)
* Add TTS response timeout for idle state * Consider time spent sending TTS audio in timeout
This commit is contained in:
parent
79cc3bffc6
commit
3b611b9b03
@ -6,6 +6,7 @@ import asyncio
|
|||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any, Final
|
from typing import Any, Final
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ from homeassistant.components.assist_satellite import (
|
|||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
from homeassistant.util.ulid import ulid_now
|
||||||
|
|
||||||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
|
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
@ -53,6 +55,7 @@ _PING_SEND_DELAY: Final = 2
|
|||||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||||
_TTS_SAMPLE_RATE: Final = 22050
|
_TTS_SAMPLE_RATE: Final = 22050
|
||||||
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||||
|
_TTS_TIMEOUT_EXTRA: Final = 1.0
|
||||||
|
|
||||||
# Wyoming stage -> Assist stage
|
# Wyoming stage -> Assist stage
|
||||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||||
@ -125,6 +128,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
|
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
|
||||||
self._played_event_received: asyncio.Event | None = None
|
self._played_event_received: asyncio.Event | None = None
|
||||||
|
|
||||||
|
# Randomly set on each pipeline loop run.
|
||||||
|
# Used to ensure TTS timeout is acted on correctly.
|
||||||
|
self._run_loop_id: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline_entity_id(self) -> str | None:
|
def pipeline_entity_id(self) -> str | None:
|
||||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||||
@ -511,6 +518,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
wake_word_phrase: str | None = None
|
wake_word_phrase: str | None = None
|
||||||
run_pipeline: RunPipeline | None = None
|
run_pipeline: RunPipeline | None = None
|
||||||
send_ping = True
|
send_ping = True
|
||||||
|
self._run_loop_id = ulid_now()
|
||||||
|
|
||||||
# Read events and check for pipeline end in parallel
|
# Read events and check for pipeline end in parallel
|
||||||
pipeline_ended_task = self.config_entry.async_create_background_task(
|
pipeline_ended_task = self.config_entry.async_create_background_task(
|
||||||
@ -698,38 +706,52 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
f"Cannot stream audio format to satellite: {tts_result.extension}"
|
f"Cannot stream audio format to satellite: {tts_result.extension}"
|
||||||
)
|
)
|
||||||
|
|
||||||
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
# Track the total duration of TTS audio for response timeout
|
||||||
|
total_seconds = 0.0
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
try:
|
||||||
sample_rate = wav_file.getframerate()
|
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||||
sample_width = wav_file.getsampwidth()
|
|
||||||
sample_channels = wav_file.getnchannels()
|
|
||||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
|
||||||
|
|
||||||
timestamp = 0
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
await self._client.write_event(
|
sample_rate = wav_file.getframerate()
|
||||||
AudioStart(
|
sample_width = wav_file.getsampwidth()
|
||||||
rate=sample_rate,
|
sample_channels = wav_file.getnchannels()
|
||||||
width=sample_width,
|
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||||
channels=sample_channels,
|
|
||||||
timestamp=timestamp,
|
|
||||||
).event()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stream audio chunks
|
timestamp = 0
|
||||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
await self._client.write_event(
|
||||||
chunk = AudioChunk(
|
AudioStart(
|
||||||
rate=sample_rate,
|
rate=sample_rate,
|
||||||
width=sample_width,
|
width=sample_width,
|
||||||
channels=sample_channels,
|
channels=sample_channels,
|
||||||
audio=audio_bytes,
|
timestamp=timestamp,
|
||||||
timestamp=timestamp,
|
).event()
|
||||||
)
|
)
|
||||||
await self._client.write_event(chunk.event())
|
|
||||||
timestamp += chunk.seconds
|
|
||||||
|
|
||||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
# Stream audio chunks
|
||||||
_LOGGER.debug("TTS streaming complete")
|
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||||
|
chunk = AudioChunk(
|
||||||
|
rate=sample_rate,
|
||||||
|
width=sample_width,
|
||||||
|
channels=sample_channels,
|
||||||
|
audio=audio_bytes,
|
||||||
|
timestamp=timestamp,
|
||||||
|
)
|
||||||
|
await self._client.write_event(chunk.event())
|
||||||
|
timestamp += chunk.seconds
|
||||||
|
total_seconds += chunk.seconds
|
||||||
|
|
||||||
|
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||||
|
_LOGGER.debug("TTS streaming complete")
|
||||||
|
finally:
|
||||||
|
send_duration = time.monotonic() - start_time
|
||||||
|
timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA)
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._tts_timeout(timeout_seconds, self._run_loop_id),
|
||||||
|
name="wyoming TTS timeout",
|
||||||
|
)
|
||||||
|
|
||||||
async def _stt_stream(self) -> AsyncGenerator[bytes]:
|
async def _stt_stream(self) -> AsyncGenerator[bytes]:
|
||||||
"""Yield audio chunks from a queue."""
|
"""Yield audio chunks from a queue."""
|
||||||
@ -744,6 +766,18 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def _tts_timeout(
|
||||||
|
self, timeout_seconds: float, run_loop_id: str | None
|
||||||
|
) -> None:
|
||||||
|
"""Force state change to IDLE in case TTS played event isn't received."""
|
||||||
|
await asyncio.sleep(timeout_seconds + _TTS_TIMEOUT_EXTRA)
|
||||||
|
|
||||||
|
if run_loop_id != self._run_loop_id:
|
||||||
|
# On a different pipeline run now
|
||||||
|
return
|
||||||
|
|
||||||
|
self.tts_response_finished()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _handle_timer(
|
def _handle_timer(
|
||||||
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
|
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
|
||||||
|
@ -1365,3 +1365,110 @@ async def test_announce(
|
|||||||
# Stop the satellite
|
# Stop the satellite
|
||||||
await hass.config_entries.async_unload(entry.entry_id)
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_timeout(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test entity state goes back to IDLE on a timeout."""
|
||||||
|
events = [
|
||||||
|
Info(satellite=SATELLITE_INFO.satellite).event(),
|
||||||
|
RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
pipeline_kwargs: dict[str, Any] = {}
|
||||||
|
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
run_pipeline_called = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context,
|
||||||
|
event_callback,
|
||||||
|
stt_metadata,
|
||||||
|
stt_stream,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||||
|
pipeline_kwargs = kwargs
|
||||||
|
pipeline_event_callback = event_callback
|
||||||
|
|
||||||
|
run_pipeline_called.set()
|
||||||
|
|
||||||
|
response_finished = asyncio.Event()
|
||||||
|
|
||||||
|
def tts_response_finished(self):
|
||||||
|
response_finished.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||||
|
SatelliteAsyncTcpClient(events),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.tts_response_finished",
|
||||||
|
tts_response_finished,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.assist_satellite._TTS_TIMEOUT_EXTRA",
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
entry = await setup_config_entry(hass)
|
||||||
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||||
|
assert device is not None
|
||||||
|
|
||||||
|
satellite_entry = next(
|
||||||
|
(
|
||||||
|
maybe_entry
|
||||||
|
for maybe_entry in er.async_entries_for_device(
|
||||||
|
entity_registry, device.device_id
|
||||||
|
)
|
||||||
|
if maybe_entry.domain == assist_satellite.DOMAIN
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert satellite_entry is not None
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await run_pipeline_called.wait()
|
||||||
|
|
||||||
|
# Reset so we can check the pipeline is automatically restarted below
|
||||||
|
run_pipeline_called.clear()
|
||||||
|
|
||||||
|
assert pipeline_event_callback is not None
|
||||||
|
assert pipeline_kwargs.get("device_id") == device.device_id
|
||||||
|
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_START,
|
||||||
|
{
|
||||||
|
"tts_input": "test text to speak",
|
||||||
|
"voice": "test voice",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
# tts_response_finished should be called on timeout
|
||||||
|
await response_finished.wait()
|
||||||
|
|
||||||
|
# Stop the satellite
|
||||||
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user