mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 05:47:10 +00:00
Support streaming TTS in wyoming (#147392)
* Support streaming TTS in wyoming * Add test * Refactor to avoid repeated task creation * Manually manage client lifecycle
This commit is contained in:
parent
3dc8676b99
commit
cefc8822b6
@ -1,13 +1,21 @@
|
||||
"""Support for Wyoming text-to-speech services."""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
import logging
|
||||
import wave
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
from wyoming.tts import (
|
||||
Synthesize,
|
||||
SynthesizeChunk,
|
||||
SynthesizeStart,
|
||||
SynthesizeStop,
|
||||
SynthesizeStopped,
|
||||
SynthesizeVoice,
|
||||
)
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.config_entry = config_entry
|
||||
self.service = service
|
||||
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||
|
||||
@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
return (None, None)
|
||||
|
||||
return ("wav", data)
|
||||
|
||||
def async_supports_streaming_input(self) -> bool:
|
||||
"""Return if the TTS engine supports streaming input."""
|
||||
return self._tts_service.supports_synthesize_streaming
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
self, request: tts.TTSAudioRequest
|
||||
) -> tts.TTSAudioResponse:
|
||||
"""Generate speech from an incoming message."""
|
||||
voice_name: str | None = request.options.get(tts.ATTR_VOICE)
|
||||
voice_speaker: str | None = request.options.get(ATTR_SPEAKER)
|
||||
voice: SynthesizeVoice | None = None
|
||||
if voice_name is not None:
|
||||
voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker)
|
||||
|
||||
client = AsyncTcpClient(self.service.host, self.service.port)
|
||||
await client.connect()
|
||||
|
||||
# Stream text chunks to client
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._write_tts_message(request.message_gen, client, voice),
|
||||
"wyoming tts write",
|
||||
)
|
||||
|
||||
async def data_gen():
|
||||
# Stream audio bytes from client
|
||||
try:
|
||||
async for data_chunk in self._read_tts_audio(client):
|
||||
yield data_chunk
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
return tts.TTSAudioResponse("wav", data_gen())
|
||||
|
||||
async def _write_tts_message(
|
||||
self,
|
||||
message_gen: AsyncGenerator[str],
|
||||
client: AsyncTcpClient,
|
||||
voice: SynthesizeVoice | None,
|
||||
) -> None:
|
||||
"""Write text chunks to the client."""
|
||||
try:
|
||||
# Start stream
|
||||
await client.write_event(SynthesizeStart(voice=voice).event())
|
||||
|
||||
# Accumulate entire message for synthesize event.
|
||||
message = ""
|
||||
async for message_chunk in message_gen:
|
||||
message += message_chunk
|
||||
|
||||
await client.write_event(SynthesizeChunk(text=message_chunk).event())
|
||||
|
||||
# Send entire message for backwards compatibility
|
||||
await client.write_event(Synthesize(text=message, voice=voice).event())
|
||||
|
||||
# End stream
|
||||
await client.write_event(SynthesizeStop().event())
|
||||
except (OSError, WyomingError):
|
||||
# Disconnected
|
||||
_LOGGER.warning("Unexpected disconnection from TTS client")
|
||||
|
||||
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
|
||||
"""Read audio events from the client and yield WAV audio chunks.
|
||||
|
||||
The WAV header is sent first with a frame count of 0 to indicate that
|
||||
we're streaming and don't know the number of frames ahead of time.
|
||||
"""
|
||||
wav_header_sent = False
|
||||
|
||||
try:
|
||||
while event := await client.read_event():
|
||||
if wav_header_sent and AudioChunk.is_type(event.type):
|
||||
# PCM audio
|
||||
yield AudioChunk.from_event(event).audio
|
||||
elif (not wav_header_sent) and AudioStart.is_type(event.type):
|
||||
# WAV header with nframes = 0 for streaming
|
||||
audio_start = AudioStart.from_event(event)
|
||||
with io.BytesIO() as wav_io:
|
||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
||||
with wav_file:
|
||||
wav_file.setframerate(audio_start.rate)
|
||||
wav_file.setsampwidth(audio_start.width)
|
||||
wav_file.setnchannels(audio_start.channels)
|
||||
|
||||
wav_io.seek(0)
|
||||
yield wav_io.getvalue()
|
||||
|
||||
wav_header_sent = True
|
||||
elif SynthesizeStopped.is_type(event.type):
|
||||
# All TTS audio has been received
|
||||
break
|
||||
except (OSError, WyomingError):
|
||||
# Disconnected
|
||||
_LOGGER.warning("Unexpected disconnection from TTS client")
|
||||
|
@ -69,6 +69,29 @@ TTS_INFO = Info(
|
||||
)
|
||||
]
|
||||
)
|
||||
TTS_STREAMING_INFO = Info(
|
||||
tts=[
|
||||
TtsProgram(
|
||||
name="Test Streaming TTS",
|
||||
description="Test Streaming TTS",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
voices=[
|
||||
TtsVoice(
|
||||
name="Test Voice",
|
||||
description="Test Voice",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
speakers=[TtsVoiceSpeaker(name="Test Speaker")],
|
||||
version=None,
|
||||
)
|
||||
],
|
||||
version=None,
|
||||
supports_synthesize_streaming=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
WAKE_WORD_INFO = Info(
|
||||
wake=[
|
||||
WakeProgram(
|
||||
@ -155,9 +178,15 @@ class MockAsyncTcpClient:
|
||||
self.port: int | None = None
|
||||
self.written: list[Event] = []
|
||||
self.responses = responses
|
||||
self.is_connected: bool | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect."""
|
||||
self.is_connected = True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect."""
|
||||
self.is_connected = False
|
||||
|
||||
async def write_event(self, event: Event):
|
||||
"""Send."""
|
||||
|
@ -19,6 +19,7 @@ from . import (
|
||||
SATELLITE_INFO,
|
||||
STT_INFO,
|
||||
TTS_INFO,
|
||||
TTS_STREAMING_INFO,
|
||||
WAKE_WORD_INFO,
|
||||
)
|
||||
|
||||
@ -148,6 +149,20 @@ async def init_wyoming_tts(
|
||||
return tts_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_streaming_tts(
|
||||
hass: HomeAssistant, tts_config_entry: ConfigEntry
|
||||
) -> ConfigEntry:
|
||||
"""Initialize Wyoming streaming TTS."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=TTS_STREAMING_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||
|
||||
return tts_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_wake_word(
|
||||
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
|
||||
|
@ -32,6 +32,43 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_streaming
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello ',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Word.',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello Word.',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_voice_speaker
|
||||
list([
|
||||
dict({
|
||||
|
@ -8,7 +8,8 @@ import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.tts import SynthesizeStopped
|
||||
|
||||
from homeassistant.components import tts, wyoming
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -43,11 +44,11 @@ async def test_get_tts_audio(
|
||||
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test get audio."""
|
||||
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||
assert entity is not None
|
||||
assert not entity.async_supports_streaming_input()
|
||||
|
||||
audio = bytes(100)
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
# Verify audio
|
||||
audio_events = [
|
||||
@ -215,3 +216,52 @@ async def test_voice_speaker(
|
||||
),
|
||||
)
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_streaming(
|
||||
hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test get audio with streaming."""
|
||||
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts")
|
||||
assert entity is not None
|
||||
assert entity.async_supports_streaming_input()
|
||||
|
||||
audio = bytes(100)
|
||||
|
||||
# Verify audio
|
||||
audio_events = [
|
||||
AudioStart(rate=16000, width=2, channels=1).event(),
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
SynthesizeStopped().event(),
|
||||
]
|
||||
|
||||
async def message_gen():
|
||||
yield "Hello "
|
||||
yield "Word."
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
stream = tts.async_create_stream(
|
||||
hass,
|
||||
"tts.test_streaming_tts",
|
||||
"en-US",
|
||||
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
|
||||
)
|
||||
stream.async_set_message_stream(message_gen())
|
||||
data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
|
||||
# Ensure client was disconnected properly
|
||||
assert mock_client.is_connected is False
|
||||
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 16000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getnframes() == 0 # streaming
|
||||
assert data[44:] == audio # WAV header is 44 bytes
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user