mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 13:57:10 +00:00
Support streaming TTS in wyoming (#147392)
* Support streaming TTS in wyoming
* Add test
* Refactor to avoid repeated task creation
* Manually manage client lifecycle
This commit is contained in:
parent
3dc8676b99
commit
cefc8822b6
@ -1,13 +1,21 @@
|
|||||||
"""Support for Wyoming text-to-speech services."""
|
"""Support for Wyoming text-to-speech services."""
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
from wyoming.audio import AudioChunk, AudioStop
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||||
from wyoming.client import AsyncTcpClient
|
from wyoming.client import AsyncTcpClient
|
||||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
from wyoming.tts import (
|
||||||
|
Synthesize,
|
||||||
|
SynthesizeChunk,
|
||||||
|
SynthesizeStart,
|
||||||
|
SynthesizeStop,
|
||||||
|
SynthesizeStopped,
|
||||||
|
SynthesizeVoice,
|
||||||
|
)
|
||||||
|
|
||||||
from homeassistant.components import tts
|
from homeassistant.components import tts
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
|||||||
service: WyomingService,
|
service: WyomingService,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up provider."""
|
"""Set up provider."""
|
||||||
|
self.config_entry = config_entry
|
||||||
self.service = service
|
self.service = service
|
||||||
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||||
|
|
||||||
@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
|||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
return ("wav", data)
|
return ("wav", data)
|
||||||
|
|
||||||
|
def async_supports_streaming_input(self) -> bool:
    """Report whether text can be streamed into this TTS engine.

    The answer is the ``supports_synthesize_streaming`` attribute of the
    TTS program stored in ``self._tts_service``.
    """
    tts_program = self._tts_service
    return tts_program.supports_synthesize_streaming
|
||||||
|
|
||||||
|
async def async_stream_tts_audio(
    self, request: tts.TTSAudioRequest
) -> tts.TTSAudioResponse:
    """Turn a streamed text message into a streamed WAV audio response.

    A dedicated client connection is opened for this request. Text chunks
    from ``request.message_gen`` are written to it by a background task,
    while the returned response reads audio back from the same client. The
    client is disconnected when the audio generator finishes or is closed.
    """
    options = request.options
    requested_voice: str | None = options.get(tts.ATTR_VOICE)
    requested_speaker: str | None = options.get(ATTR_SPEAKER)

    # A voice is only sent when a voice name was requested; the speaker
    # alone is not enough to select one.
    voice: SynthesizeVoice | None = (
        None
        if requested_voice is None
        else SynthesizeVoice(name=requested_voice, speaker=requested_speaker)
    )

    # The client lifecycle is managed by hand: it has to outlive this call
    # because the response generator keeps reading from it afterwards.
    client = AsyncTcpClient(self.service.host, self.service.port)
    await client.connect()

    # Writer side: push text chunks to the server concurrently with reading.
    self.config_entry.async_create_background_task(
        self.hass,
        self._write_tts_message(request.message_gen, client, voice),
        "wyoming tts write",
    )

    async def audio_chunks():
        # Reader side: relay audio bytes and always release the connection.
        try:
            async for audio_bytes in self._read_tts_audio(client):
                yield audio_bytes
        finally:
            await client.disconnect()

    return tts.TTSAudioResponse("wav", audio_chunks())
|
||||||
|
|
||||||
|
async def _write_tts_message(
    self,
    message_gen: AsyncGenerator[str],
    client: AsyncTcpClient,
    voice: SynthesizeVoice | None,
) -> None:
    """Write text chunks to the client.

    Events are written in this order: one ``SynthesizeStart``, one
    ``SynthesizeChunk`` per text chunk produced by ``message_gen``, a single
    ``Synthesize`` carrying the complete text, and finally one
    ``SynthesizeStop``.

    Args:
        message_gen: Async generator yielding the text to speak, in order.
        client: Connected client the events are written to.
        voice: Voice to request, or ``None`` to send no voice.

    ``OSError`` and ``WyomingError`` raised while writing are logged as a
    warning and not re-raised.
    """
    try:
        # Start stream
        await client.write_event(SynthesizeStart(voice=voice).event())

        # Collect the chunks so the entire message can be sent afterwards;
        # joining once avoids rebuilding a growing string on every chunk.
        message_chunks: list[str] = []
        async for message_chunk in message_gen:
            message_chunks.append(message_chunk)

            await client.write_event(SynthesizeChunk(text=message_chunk).event())

        # Send entire message for backwards compatibility
        await client.write_event(
            Synthesize(text="".join(message_chunks), voice=voice).event()
        )

        # End stream
        await client.write_event(SynthesizeStop().event())
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")
|
||||||
|
|
||||||
|
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
    """Yield WAV audio built from the events read from the client.

    The first item yielded is a WAV header, produced when the audio-start
    event arrives. No frames are written into it, so its frame count is 0:
    the stream length is not known ahead of time. Audio chunks that arrive
    after the header are yielded as raw PCM bytes. Reading ends when the
    client returns no event or a synthesize-stopped event is received.

    ``OSError`` and ``WyomingError`` raised while reading are logged as a
    warning and end the stream.
    """
    header_emitted = False

    try:
        while True:
            event = await client.read_event()
            if not event:
                # Nothing more to read
                break

            if header_emitted and AudioChunk.is_type(event.type):
                # PCM audio
                yield AudioChunk.from_event(event).audio
                continue

            if (not header_emitted) and AudioStart.is_type(event.type):
                # WAV header with nframes = 0 for streaming
                stream_format = AudioStart.from_event(event)
                with io.BytesIO() as header_io:
                    header_writer: wave.Wave_write = wave.open(header_io, "wb")
                    with header_writer:
                        header_writer.setframerate(stream_format.rate)
                        header_writer.setsampwidth(stream_format.width)
                        header_writer.setnchannels(stream_format.channels)

                    # Closing the writer above is what emits the header.
                    wav_header = header_io.getvalue()

                yield wav_header
                header_emitted = True
                continue

            if SynthesizeStopped.is_type(event.type):
                # All TTS audio has been received
                break
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")
|
||||||
|
@ -69,6 +69,29 @@ TTS_INFO = Info(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
# Service info describing one installed TTS program that advertises
# streaming synthesis, with a single installed en-US voice.
TTS_STREAMING_INFO = Info(
    tts=[
        TtsProgram(
            name="Test Streaming TTS",
            description="Test Streaming TTS",
            attribution=TEST_ATTR,
            installed=True,
            version=None,
            supports_synthesize_streaming=True,
            voices=[
                TtsVoice(
                    name="Test Voice",
                    description="Test Voice",
                    attribution=TEST_ATTR,
                    installed=True,
                    version=None,
                    languages=["en-US"],
                    speakers=[TtsVoiceSpeaker(name="Test Speaker")],
                )
            ],
        )
    ]
)
|
||||||
WAKE_WORD_INFO = Info(
|
WAKE_WORD_INFO = Info(
|
||||||
wake=[
|
wake=[
|
||||||
WakeProgram(
|
WakeProgram(
|
||||||
@ -155,9 +178,15 @@ class MockAsyncTcpClient:
|
|||||||
self.port: int | None = None
|
self.port: int | None = None
|
||||||
self.written: list[Event] = []
|
self.written: list[Event] = []
|
||||||
self.responses = responses
|
self.responses = responses
|
||||||
|
self.is_connected: bool | None = None
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Connect."""
|
"""Connect."""
|
||||||
|
self.is_connected = True
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
    """Record that the mock client has been disconnected."""
    self.is_connected = False
|
||||||
|
|
||||||
async def write_event(self, event: Event):
|
async def write_event(self, event: Event):
|
||||||
"""Send."""
|
"""Send."""
|
||||||
|
@ -19,6 +19,7 @@ from . import (
|
|||||||
SATELLITE_INFO,
|
SATELLITE_INFO,
|
||||||
STT_INFO,
|
STT_INFO,
|
||||||
TTS_INFO,
|
TTS_INFO,
|
||||||
|
TTS_STREAMING_INFO,
|
||||||
WAKE_WORD_INFO,
|
WAKE_WORD_INFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -148,6 +149,20 @@ async def init_wyoming_tts(
|
|||||||
return tts_config_entry
|
return tts_config_entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def init_wyoming_streaming_tts(
    hass: HomeAssistant, tts_config_entry: ConfigEntry
) -> ConfigEntry:
    """Set up the TTS config entry against streaming-capable service info."""
    # Make the integration load the streaming TTS info during setup.
    streaming_info_patch = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=TTS_STREAMING_INFO,
    )
    with streaming_info_patch:
        await hass.config_entries.async_setup(tts_config_entry.entry_id)

    return tts_config_entry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def init_wyoming_wake_word(
|
async def init_wyoming_wake_word(
|
||||||
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
|
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
|
||||||
|
@ -32,6 +32,43 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_get_tts_audio_streaming
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize-start',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello ',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize-chunk',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Word.',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize-chunk',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello Word.',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize-stop',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_voice_speaker
|
# name: test_voice_speaker
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
@ -8,7 +8,8 @@ import wave
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
from wyoming.audio import AudioChunk, AudioStop
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||||
|
from wyoming.tts import SynthesizeStopped
|
||||||
|
|
||||||
from homeassistant.components import tts, wyoming
|
from homeassistant.components import tts, wyoming
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -43,11 +44,11 @@ async def test_get_tts_audio(
|
|||||||
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
|
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test get audio."""
|
"""Test get audio."""
|
||||||
|
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||||
|
assert entity is not None
|
||||||
|
assert not entity.async_supports_streaming_input()
|
||||||
|
|
||||||
audio = bytes(100)
|
audio = bytes(100)
|
||||||
audio_events = [
|
|
||||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
|
||||||
AudioStop().event(),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Verify audio
|
# Verify audio
|
||||||
audio_events = [
|
audio_events = [
|
||||||
@ -215,3 +216,52 @@ async def test_voice_speaker(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert mock_client.written == snapshot
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio_streaming(
    hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
) -> None:
    """Check streamed text yields streamed WAV audio and the expected events."""
    entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts")
    assert entity is not None
    assert entity.async_supports_streaming_input()

    pcm_audio = bytes(100)
    wav_header_size = 44  # WAV header is 44 bytes

    # Events the mock client hands back to the reader
    server_events = [
        AudioStart(rate=16000, width=2, channels=1).event(),
        AudioChunk(audio=pcm_audio, rate=16000, width=2, channels=1).event(),
        AudioStop().event(),
        SynthesizeStopped().event(),
    ]

    async def message_gen():
        for text_chunk in ("Hello ", "Word."):
            yield text_chunk

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        MockAsyncTcpClient(server_events),
    ) as mock_client:
        stream = tts.async_create_stream(
            hass,
            "tts.test_streaming_tts",
            "en-US",
            options={tts.ATTR_PREFERRED_FORMAT: "wav"},
        )
        stream.async_set_message_stream(message_gen())
        received_chunks = [chunk async for chunk in stream.async_stream_result()]
        data = b"".join(received_chunks)

        # Ensure client was disconnected properly
        assert mock_client.is_connected is False

    assert data is not None
    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        assert wav_file.getframerate() == 16000
        assert wav_file.getsampwidth() == 2
        assert wav_file.getnchannels() == 1
        assert wav_file.getnframes() == 0  # streaming
        assert data[wav_header_size:] == pcm_audio

    assert mock_client.written == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user