mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Add TTS streaming to Wyoming satellites (#147438)
* Add TTS streaming using intent-progress * Handle incomplete header
This commit is contained in:
parent
0f112bb9c4
commit
3dc8676b99
@ -132,6 +132,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
# Used to ensure TTS timeout is acted on correctly.
|
# Used to ensure TTS timeout is acted on correctly.
|
||||||
self._run_loop_id: str | None = None
|
self._run_loop_id: str | None = None
|
||||||
|
|
||||||
|
# TTS streaming
|
||||||
|
self._tts_stream_token: str | None = None
|
||||||
|
self._is_tts_streaming: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline_entity_id(self) -> str | None:
|
def pipeline_entity_id(self) -> str | None:
|
||||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||||
@ -179,11 +183,20 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
"""Set state based on pipeline stage."""
|
"""Set state based on pipeline stage."""
|
||||||
assert self._client is not None
|
assert self._client is not None
|
||||||
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
if event.type == assist_pipeline.PipelineEventType.RUN_START:
|
||||||
|
if event.data and (tts_output := event.data["tts_output"]):
|
||||||
|
# Get stream token early.
|
||||||
|
# If "tts_start_streaming" is True in INTENT_PROGRESS event, we
|
||||||
|
# can start streaming TTS before the TTS_END event.
|
||||||
|
self._tts_stream_token = tts_output["token"]
|
||||||
|
self._is_tts_streaming = False
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||||
# Pipeline run is complete
|
# Pipeline run is complete
|
||||||
self._is_pipeline_running = False
|
self._is_pipeline_running = False
|
||||||
self._pipeline_ended_event.set()
|
self._pipeline_ended_event.set()
|
||||||
self.device.set_is_active(False)
|
self.device.set_is_active(False)
|
||||||
|
self._tts_stream_token = None
|
||||||
|
self._is_tts_streaming = False
|
||||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
@ -245,6 +258,20 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
self._client.write_event(Transcript(text=stt_text).event()),
|
self._client.write_event(Transcript(text=stt_text).event()),
|
||||||
f"{self.entity_id} {event.type}",
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.INTENT_PROGRESS:
|
||||||
|
if (
|
||||||
|
event.data
|
||||||
|
and event.data.get("tts_start_streaming")
|
||||||
|
and self._tts_stream_token
|
||||||
|
and (stream := tts.async_get_stream(self.hass, self._tts_stream_token))
|
||||||
|
):
|
||||||
|
# Start streaming TTS early (before TTS_END).
|
||||||
|
self._is_tts_streaming = True
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._stream_tts(stream),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||||
# Text-to-speech text
|
# Text-to-speech text
|
||||||
if event.data:
|
if event.data:
|
||||||
@ -267,8 +294,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
if (
|
if (
|
||||||
event.data
|
event.data
|
||||||
and (tts_output := event.data["tts_output"])
|
and (tts_output := event.data["tts_output"])
|
||||||
|
and not self._is_tts_streaming
|
||||||
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
|
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
|
||||||
):
|
):
|
||||||
|
# Send TTS only if we haven't already started streaming it in INTENT_PROGRESS.
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._stream_tts(stream),
|
self._stream_tts(stream),
|
||||||
@ -711,15 +740,26 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
header_data = b""
|
||||||
|
header_complete = False
|
||||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
sample_rate: int | None = None
|
||||||
sample_rate = wav_file.getframerate()
|
sample_width: int | None = None
|
||||||
sample_width = wav_file.getsampwidth()
|
sample_channels: int | None = None
|
||||||
sample_channels = wav_file.getnchannels()
|
|
||||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
|
|
||||||
|
async for data_chunk in tts_result.async_stream_result():
|
||||||
|
if not header_complete:
|
||||||
|
# Accumulate data until we can parse the header and get
|
||||||
|
# sample rate, etc.
|
||||||
|
header_data += data_chunk
|
||||||
|
# Most WAVE headers are 44 bytes in length
|
||||||
|
if (len(header_data) >= 44) and (
|
||||||
|
audio_info := _try_parse_wav_header(header_data)
|
||||||
|
):
|
||||||
|
# Overwrite chunk with audio after header
|
||||||
|
sample_rate, sample_width, sample_channels, data_chunk = (
|
||||||
|
audio_info
|
||||||
|
)
|
||||||
await self._client.write_event(
|
await self._client.write_event(
|
||||||
AudioStart(
|
AudioStart(
|
||||||
rate=sample_rate,
|
rate=sample_rate,
|
||||||
@ -728,19 +768,31 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
).event()
|
).event()
|
||||||
)
|
)
|
||||||
|
header_complete = True
|
||||||
|
|
||||||
# Stream audio chunks
|
if not data_chunk:
|
||||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
# No audio after header
|
||||||
chunk = AudioChunk(
|
continue
|
||||||
|
else:
|
||||||
|
# Header is incomplete
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Streaming audio
|
||||||
|
assert sample_rate is not None
|
||||||
|
assert sample_width is not None
|
||||||
|
assert sample_channels is not None
|
||||||
|
|
||||||
|
audio_chunk = AudioChunk(
|
||||||
rate=sample_rate,
|
rate=sample_rate,
|
||||||
width=sample_width,
|
width=sample_width,
|
||||||
channels=sample_channels,
|
channels=sample_channels,
|
||||||
audio=audio_bytes,
|
audio=data_chunk,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
await self._client.write_event(chunk.event())
|
|
||||||
timestamp += chunk.milliseconds
|
await self._client.write_event(audio_chunk.event())
|
||||||
total_seconds += chunk.seconds
|
timestamp += audio_chunk.milliseconds
|
||||||
|
total_seconds += audio_chunk.seconds
|
||||||
|
|
||||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||||
_LOGGER.debug("TTS streaming complete")
|
_LOGGER.debug("TTS streaming complete")
|
||||||
@ -812,3 +864,25 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass, self._client.write_event(event), "wyoming timer event"
|
self.hass, self._client.write_event(event), "wyoming timer event"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_parse_wav_header(header_data: bytes) -> tuple[int, int, int, bytes] | None:
|
||||||
|
"""Try to parse a WAV header from a buffer.
|
||||||
|
|
||||||
|
If successful, return (rate, width, channels, audio).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with io.BytesIO(header_data) as wav_io:
|
||||||
|
wav_file: wave.Wave_read = wave.open(wav_io, "rb")
|
||||||
|
with wav_file:
|
||||||
|
return (
|
||||||
|
wav_file.getframerate(),
|
||||||
|
wav_file.getsampwidth(),
|
||||||
|
wav_file.getnchannels(),
|
||||||
|
wav_file.readframes(wav_file.getnframes()),
|
||||||
|
)
|
||||||
|
except wave.Error:
|
||||||
|
# Ignore errors and return None
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
@ -1472,3 +1472,184 @@ async def test_tts_timeout(
|
|||||||
# Stop the satellite
|
# Stop the satellite
|
||||||
await hass.config_entries.async_unload(entry.entry_id)
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_tts_streaming(hass: HomeAssistant) -> None:
|
||||||
|
"""Test running a streaming TTS pipeline with a satellite."""
|
||||||
|
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||||
|
|
||||||
|
events = [
|
||||||
|
RunPipeline(start_stage=PipelineStage.ASR, end_stage=PipelineStage.TTS).event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
pipeline_kwargs: dict[str, Any] = {}
|
||||||
|
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
run_pipeline_called = asyncio.Event()
|
||||||
|
audio_chunk_received = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context,
|
||||||
|
event_callback,
|
||||||
|
stt_metadata,
|
||||||
|
stt_stream,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||||
|
pipeline_kwargs = kwargs
|
||||||
|
pipeline_event_callback = event_callback
|
||||||
|
|
||||||
|
run_pipeline_called.set()
|
||||||
|
async for chunk in stt_stream:
|
||||||
|
if chunk:
|
||||||
|
audio_chunk_received.set()
|
||||||
|
break
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||||
|
SatelliteAsyncTcpClient(events),
|
||||||
|
) as mock_client,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||||
|
):
|
||||||
|
entry = await setup_config_entry(hass)
|
||||||
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||||
|
assert device is not None
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.connect_event.wait()
|
||||||
|
await mock_client.run_satellite_event.wait()
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await run_pipeline_called.wait()
|
||||||
|
|
||||||
|
assert pipeline_event_callback is not None
|
||||||
|
assert pipeline_kwargs.get("device_id") == device.device_id
|
||||||
|
|
||||||
|
# Send TTS info early
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.RUN_START,
|
||||||
|
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Speech-to-text started
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_START,
|
||||||
|
{"metadata": {"language": "en"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.transcribe_event.wait()
|
||||||
|
|
||||||
|
# Push in some audio
|
||||||
|
mock_client.inject_event(
|
||||||
|
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
|
||||||
|
)
|
||||||
|
|
||||||
|
# User started speaking
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.voice_started_event.wait()
|
||||||
|
|
||||||
|
# User stopped speaking
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.voice_stopped_event.wait()
|
||||||
|
|
||||||
|
# Speech-to-text transcription
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
{"stt_output": {"text": "test transcript"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.transcript_event.wait()
|
||||||
|
|
||||||
|
# Intent progress starts TTS streaming early with info received in the
|
||||||
|
# run-start event.
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.INTENT_PROGRESS,
|
||||||
|
{"tts_start_streaming": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TTS events are sent now. In practice, these would be streamed as text
|
||||||
|
# chunks are generated.
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.tts_audio_start_event.wait()
|
||||||
|
await mock_client.tts_audio_chunk_event.wait()
|
||||||
|
await mock_client.tts_audio_stop_event.wait()
|
||||||
|
|
||||||
|
# Verify audio chunk from test WAV
|
||||||
|
assert mock_client.tts_audio_chunk is not None
|
||||||
|
assert mock_client.tts_audio_chunk.rate == 22050
|
||||||
|
assert mock_client.tts_audio_chunk.width == 2
|
||||||
|
assert mock_client.tts_audio_chunk.channels == 1
|
||||||
|
assert mock_client.tts_audio_chunk.audio == b"1234"
|
||||||
|
|
||||||
|
# Text-to-speech text
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_START,
|
||||||
|
{
|
||||||
|
"tts_input": "test text to speak",
|
||||||
|
"voice": "test voice",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# synthesize event is sent with complete message for non-streaming clients
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.synthesize_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.synthesize is not None
|
||||||
|
assert mock_client.synthesize.text == "test text to speak"
|
||||||
|
assert mock_client.synthesize.voice is not None
|
||||||
|
assert mock_client.synthesize.voice.name == "test voice"
|
||||||
|
|
||||||
|
# Because we started streaming TTS after intent progress, we should not
|
||||||
|
# stream it again on tts-end.
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts"
|
||||||
|
) as mock_stream_tts:
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_stream_tts.assert_not_called()
|
||||||
|
|
||||||
|
# Pipeline finished
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop the satellite
|
||||||
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user