Support streaming TTS in wyoming (#147392)

* Support streaming TTS in wyoming

* Add test

* Refactor to avoid repeated task creation

* Manually manage client lifecycle
This commit is contained in:
Michael Hansen 2025-06-24 12:04:40 -05:00 committed by GitHub
parent 3dc8676b99
commit cefc8822b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 242 additions and 7 deletions

View File

@ -1,13 +1,21 @@
"""Support for Wyoming text-to-speech services.""" """Support for Wyoming text-to-speech services."""
from collections import defaultdict from collections import defaultdict
from collections.abc import AsyncGenerator
import io import io
import logging import logging
import wave import wave
from wyoming.audio import AudioChunk, AudioStop from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.tts import (
Synthesize,
SynthesizeChunk,
SynthesizeStart,
SynthesizeStop,
SynthesizeStopped,
SynthesizeVoice,
)
from homeassistant.components import tts from homeassistant.components import tts
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
service: WyomingService, service: WyomingService,
) -> None: ) -> None:
"""Set up provider.""" """Set up provider."""
self.config_entry = config_entry
self.service = service self.service = service
self._tts_service = next(tts for tts in service.info.tts if tts.installed) self._tts_service = next(tts for tts in service.info.tts if tts.installed)
@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
return (None, None) return (None, None)
return ("wav", data) return ("wav", data)
def async_supports_streaming_input(self) -> bool:
    """Report whether the selected TTS program accepts streamed text input.

    The answer is read directly from the ``supports_synthesize_streaming``
    attribute of the TTS program stored on this entity.
    """
    tts_program = self._tts_service
    return tts_program.supports_synthesize_streaming
async def async_stream_tts_audio(
    self, request: tts.TTSAudioRequest
) -> tts.TTSAudioResponse:
    """Turn a streamed text message into streamed WAV audio.

    A new client connection is opened for the request. Text chunks are
    written to it from a background task while the returned response's
    generator yields audio read back from the same connection. The
    connection is closed in the generator's ``finally`` clause.
    """
    voice_name: str | None = request.options.get(tts.ATTR_VOICE)
    voice_speaker: str | None = request.options.get(ATTR_SPEAKER)

    # A voice is only requested when a voice name was supplied
    voice: SynthesizeVoice | None = (
        None
        if voice_name is None
        else SynthesizeVoice(name=voice_name, speaker=voice_speaker)
    )

    client = AsyncTcpClient(self.service.host, self.service.port)
    await client.connect()

    # Text is written from a background task so that audio can be read
    # from the client while the message is still arriving.
    self.config_entry.async_create_background_task(
        self.hass,
        self._write_tts_message(request.message_gen, client, voice),
        "wyoming tts write",
    )

    async def audio_stream():
        """Yield audio bytes from the client, then disconnect it."""
        try:
            async for audio_bytes in self._read_tts_audio(client):
                yield audio_bytes
        finally:
            await client.disconnect()

    return tts.TTSAudioResponse("wav", audio_stream())
async def _write_tts_message(
    self,
    message_gen: AsyncGenerator[str],
    client: AsyncTcpClient,
    voice: SynthesizeVoice | None,
) -> None:
    """Write text chunks to the client as synthesize events.

    Events are written in this order: one ``SynthesizeStart``, one
    ``SynthesizeChunk`` per text chunk, a single ``Synthesize`` carrying the
    full text, and finally ``SynthesizeStop``.

    Args:
        message_gen: Async generator producing the text chunks to speak.
        client: Client the events are written to.
        voice: Voice passed through to the start and full-message events.

    ``OSError`` and ``WyomingError`` are treated as a disconnection: a
    warning is logged and the method returns without raising.
    """
    try:
        # Start stream
        await client.write_event(SynthesizeStart(voice=voice).event())

        # Collect the chunks and join them once at the end instead of
        # growing a string with += on every iteration.
        message_chunks: list[str] = []
        async for message_chunk in message_gen:
            message_chunks.append(message_chunk)
            await client.write_event(SynthesizeChunk(text=message_chunk).event())

        # Send entire message for backwards compatibility
        await client.write_event(
            Synthesize(text="".join(message_chunks), voice=voice).event()
        )

        # End stream
        await client.write_event(SynthesizeStop().event())
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
    """Read audio events from the client and yield WAV audio chunks.

    The WAV header is sent first with a frame count of 0 to indicate that
    we're streaming and don't know the number of frames ahead of time.

    The header format is taken from the first ``AudioStart`` event. If an
    ``AudioChunk`` arrives before any ``AudioStart``, the header is built
    from that chunk's own rate/width/channels so its audio is not dropped.

    Reading ends when ``SynthesizeStopped`` is received or ``read_event``
    returns a falsy value. ``OSError`` and ``WyomingError`` are treated as a
    disconnection: a warning is logged and the generator ends.
    """

    def wav_header(rate: int, width: int, channels: int) -> bytes:
        """Return WAV header bytes for the given format with no frames."""
        with io.BytesIO() as wav_io:
            wav_file: wave.Wave_write = wave.open(wav_io, "wb")
            with wav_file:
                wav_file.setframerate(rate)
                wav_file.setsampwidth(width)
                wav_file.setnchannels(channels)

            # getvalue() returns the whole buffer regardless of the stream
            # position, so no seek is needed first.
            return wav_io.getvalue()

    wav_header_sent = False

    try:
        while event := await client.read_event():
            if AudioChunk.is_type(event.type):
                chunk = AudioChunk.from_event(event)
                if not wav_header_sent:
                    # No AudioStart was seen; use the chunk's own format
                    yield wav_header(chunk.rate, chunk.width, chunk.channels)
                    wav_header_sent = True

                # PCM audio
                yield chunk.audio
            elif AudioStart.is_type(event.type):
                # Only the first format announcement produces a header
                if not wav_header_sent:
                    # WAV header with nframes = 0 for streaming
                    audio_start = AudioStart.from_event(event)
                    yield wav_header(
                        audio_start.rate, audio_start.width, audio_start.channels
                    )
                    wav_header_sent = True
            elif SynthesizeStopped.is_type(event.type):
                # All TTS audio has been received
                break
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")

View File

@ -69,6 +69,29 @@ TTS_INFO = Info(
) )
] ]
) )
# Service info describing a single installed TTS program that advertises
# streaming synthesis support, with one installed en-US voice.
TTS_STREAMING_INFO = Info(
    tts=[
        TtsProgram(
            name="Test Streaming TTS",
            description="Test Streaming TTS",
            attribution=TEST_ATTR,
            installed=True,
            version=None,
            supports_synthesize_streaming=True,
            voices=[
                TtsVoice(
                    name="Test Voice",
                    description="Test Voice",
                    attribution=TEST_ATTR,
                    installed=True,
                    version=None,
                    languages=["en-US"],
                    speakers=[TtsVoiceSpeaker(name="Test Speaker")],
                )
            ],
        )
    ]
)
WAKE_WORD_INFO = Info( WAKE_WORD_INFO = Info(
wake=[ wake=[
WakeProgram( WakeProgram(
@ -155,9 +178,15 @@ class MockAsyncTcpClient:
self.port: int | None = None self.port: int | None = None
self.written: list[Event] = [] self.written: list[Event] = []
self.responses = responses self.responses = responses
self.is_connected: bool | None = None
async def connect(self) -> None: async def connect(self) -> None:
"""Connect.""" """Connect."""
self.is_connected = True
async def disconnect(self) -> None:
"""Disconnect."""
self.is_connected = False
async def write_event(self, event: Event): async def write_event(self, event: Event):
"""Send.""" """Send."""

View File

@ -19,6 +19,7 @@ from . import (
SATELLITE_INFO, SATELLITE_INFO,
STT_INFO, STT_INFO,
TTS_INFO, TTS_INFO,
TTS_STREAMING_INFO,
WAKE_WORD_INFO, WAKE_WORD_INFO,
) )
@ -148,6 +149,20 @@ async def init_wyoming_tts(
return tts_config_entry return tts_config_entry
@pytest.fixture
async def init_wyoming_streaming_tts(
    hass: HomeAssistant, tts_config_entry: ConfigEntry
) -> ConfigEntry:
    """Set up the TTS config entry with streaming TTS service info."""
    # While the entry is set up, load_wyoming_info returns the streaming info
    info_patch = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=TTS_STREAMING_INFO,
    )
    with info_patch:
        entry_id = tts_config_entry.entry_id
        await hass.config_entries.async_setup(entry_id)

    return tts_config_entry
@pytest.fixture @pytest.fixture
async def init_wyoming_wake_word( async def init_wyoming_wake_word(
hass: HomeAssistant, wake_word_config_entry: ConfigEntry hass: HomeAssistant, wake_word_config_entry: ConfigEntry

View File

@ -32,6 +32,43 @@
}), }),
]) ])
# --- # ---
# name: test_get_tts_audio_streaming
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello ',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Word.',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello Word.',
}),
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---
# name: test_voice_speaker # name: test_voice_speaker
list([ list([
dict({ dict({

View File

@ -8,7 +8,8 @@ import wave
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from wyoming.audio import AudioChunk, AudioStop from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.tts import SynthesizeStopped
from homeassistant.components import tts, wyoming from homeassistant.components import tts, wyoming
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -43,11 +44,11 @@ async def test_get_tts_audio(
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
) -> None: ) -> None:
"""Test get audio.""" """Test get audio."""
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
assert entity is not None
assert not entity.async_supports_streaming_input()
audio = bytes(100) audio = bytes(100)
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
# Verify audio # Verify audio
audio_events = [ audio_events = [
@ -215,3 +216,52 @@ async def test_voice_speaker(
), ),
) )
assert mock_client.written == snapshot assert mock_client.written == snapshot
async def test_get_tts_audio_streaming(
    hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
) -> None:
    """Test streamed text input producing streamed WAV audio."""
    tts_entities = hass.data[DATA_INSTANCES]["tts"]
    entity = tts_entities.get_entity("tts.test_streaming_tts")
    assert entity is not None
    assert entity.async_supports_streaming_input()

    audio = bytes(100)

    # Events the mocked client hands back to the integration
    audio_events = [
        AudioStart(rate=16000, width=2, channels=1).event(),
        AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
        AudioStop().event(),
        SynthesizeStopped().event(),
    ]

    async def message_gen():
        """Yield the message in two text chunks."""
        for text_chunk in ("Hello ", "Word."):
            yield text_chunk

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        MockAsyncTcpClient(audio_events),
    ) as mock_client:
        stream = tts.async_create_stream(
            hass,
            "tts.test_streaming_tts",
            "en-US",
            options={tts.ATTR_PREFERRED_FORMAT: "wav"},
        )
        stream.async_set_message_stream(message_gen())

        received_chunks = []
        async for chunk in stream.async_stream_result():
            received_chunks.append(chunk)
        data = b"".join(received_chunks)

        # Ensure client was disconnected properly
        assert mock_client.is_connected is False

    assert data is not None
    with io.BytesIO(data) as wav_io:
        with wave.open(wav_io, "rb") as wav_file:
            assert wav_file.getframerate() == 16000
            assert wav_file.getsampwidth() == 2
            assert wav_file.getnchannels() == 1
            assert wav_file.getnframes() == 0  # streaming
            assert data[44:] == audio  # WAV header is 44 bytes

    assert mock_client.written == snapshot