diff --git a/homeassistant/components/wyoming/tts.py b/homeassistant/components/wyoming/tts.py index 79e431fee98..cf088c04d9f 100644 --- a/homeassistant/components/wyoming/tts.py +++ b/homeassistant/components/wyoming/tts.py @@ -1,13 +1,21 @@ """Support for Wyoming text-to-speech services.""" from collections import defaultdict +from collections.abc import AsyncGenerator import io import logging import wave -from wyoming.audio import AudioChunk, AudioStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.client import AsyncTcpClient -from wyoming.tts import Synthesize, SynthesizeVoice +from wyoming.tts import ( + Synthesize, + SynthesizeChunk, + SynthesizeStart, + SynthesizeStop, + SynthesizeStopped, + SynthesizeVoice, +) from homeassistant.components import tts from homeassistant.config_entries import ConfigEntry @@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity): service: WyomingService, ) -> None: """Set up provider.""" + self.config_entry = config_entry self.service = service self._tts_service = next(tts for tts in service.info.tts if tts.installed) @@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity): return (None, None) return ("wav", data) + + def async_supports_streaming_input(self) -> bool: + """Return if the TTS engine supports streaming input.""" + return self._tts_service.supports_synthesize_streaming + + async def async_stream_tts_audio( + self, request: tts.TTSAudioRequest + ) -> tts.TTSAudioResponse: + """Generate speech from an incoming message.""" + voice_name: str | None = request.options.get(tts.ATTR_VOICE) + voice_speaker: str | None = request.options.get(ATTR_SPEAKER) + voice: SynthesizeVoice | None = None + if voice_name is not None: + voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker) + + client = AsyncTcpClient(self.service.host, self.service.port) + await client.connect() + + # Stream text chunks to client + self.config_entry.async_create_background_task( + self.hass, + self._write_tts_message(request.message_gen, client, voice), + "wyoming tts write", + ) + + async def data_gen(): + # Stream audio bytes from client + try: + async for data_chunk in self._read_tts_audio(client): + yield data_chunk + finally: + await client.disconnect() + + return tts.TTSAudioResponse("wav", data_gen()) + + async def _write_tts_message( + self, + message_gen: AsyncGenerator[str], + client: AsyncTcpClient, + voice: SynthesizeVoice | None, + ) -> None: + """Write text chunks to the client.""" + try: + # Start stream + await client.write_event(SynthesizeStart(voice=voice).event()) + + # Accumulate entire message for synthesize event. + message = "" + async for message_chunk in message_gen: + message += message_chunk + + await client.write_event(SynthesizeChunk(text=message_chunk).event()) + + # Send entire message for backwards compatibility + await client.write_event(Synthesize(text=message, voice=voice).event()) + + # End stream + await client.write_event(SynthesizeStop().event()) + except (OSError, WyomingError): + # Disconnected + _LOGGER.warning("Unexpected disconnection from TTS client") + + async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]: + """Read audio events from the client and yield WAV audio chunks. + + The WAV header is sent first with a frame count of 0 to indicate that + we're streaming and don't know the number of frames ahead of time. + """ + wav_header_sent = False + + try: + while event := await client.read_event(): + if wav_header_sent and AudioChunk.is_type(event.type): + # PCM audio + yield AudioChunk.from_event(event).audio + elif (not wav_header_sent) and AudioStart.is_type(event.type): + # WAV header with nframes = 0 for streaming + audio_start = AudioStart.from_event(event) + with io.BytesIO() as wav_io: + wav_file: wave.Wave_write = wave.open(wav_io, "wb") + with wav_file: + wav_file.setframerate(audio_start.rate) + wav_file.setsampwidth(audio_start.width) + wav_file.setnchannels(audio_start.channels) + + wav_io.seek(0) + yield wav_io.getvalue() + + wav_header_sent = True + elif SynthesizeStopped.is_type(event.type): + # All TTS audio has been received + break + except (OSError, WyomingError): + # Disconnected + _LOGGER.warning("Unexpected disconnection from TTS client") diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 4540cdaabfd..de82dc08719 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -69,6 +69,29 @@ TTS_INFO = Info( ) ] ) +TTS_STREAMING_INFO = Info( + tts=[ + TtsProgram( + name="Test Streaming TTS", + description="Test Streaming TTS", + installed=True, + attribution=TEST_ATTR, + voices=[ + TtsVoice( + name="Test Voice", + description="Test Voice", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + speakers=[TtsVoiceSpeaker(name="Test Speaker")], + version=None, + ) + ], + version=None, + supports_synthesize_streaming=True, + ) + ] +) WAKE_WORD_INFO = Info( wake=[ WakeProgram( @@ -155,9 +178,15 @@ class MockAsyncTcpClient: self.port: int | None = None self.written: list[Event] = [] self.responses = responses + self.is_connected: bool | None = None async def connect(self) -> None: """Connect.""" + self.is_connected = True + + async def disconnect(self) -> None: + """Disconnect.""" + self.is_connected = False async def write_event(self, event: Event): """Send.""" diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 125edc547c6..2974bb4b013 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -19,6 +19,7 @@ from . import ( SATELLITE_INFO, STT_INFO, TTS_INFO, + TTS_STREAMING_INFO, WAKE_WORD_INFO, ) @@ -148,6 +149,20 @@ async def init_wyoming_tts( return tts_config_entry +@pytest.fixture +async def init_wyoming_streaming_tts( + hass: HomeAssistant, tts_config_entry: ConfigEntry +) -> ConfigEntry: + """Initialize Wyoming streaming TTS.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=TTS_STREAMING_INFO, + ): + await hass.config_entries.async_setup(tts_config_entry.entry_id) + + return tts_config_entry + + @pytest.fixture async def init_wyoming_wake_word( hass: HomeAssistant, wake_word_config_entry: ConfigEntry diff --git a/tests/components/wyoming/snapshots/test_tts.ambr b/tests/components/wyoming/snapshots/test_tts.ambr index 7ca5204e66c..53cc02eaacf 100644 --- a/tests/components/wyoming/snapshots/test_tts.ambr +++ b/tests/components/wyoming/snapshots/test_tts.ambr @@ -32,6 +32,43 @@ }), ]) # --- +# name: test_get_tts_audio_streaming + list([ + dict({ + 'data': dict({ + }), + 'payload': None, + 'type': 'synthesize-start', + }), + dict({ + 'data': dict({ + 'text': 'Hello ', + }), + 'payload': None, + 'type': 'synthesize-chunk', + }), + dict({ + 'data': dict({ + 'text': 'Word.', + }), + 'payload': None, + 'type': 'synthesize-chunk', + }), + dict({ + 'data': dict({ + 'text': 'Hello Word.', + }), + 'payload': None, + 'type': 'synthesize', + }), + dict({ + 'data': dict({ + }), + 'payload': None, + 'type': 'synthesize-stop', + }), + ]) +# --- # name: test_voice_speaker list([ dict({ diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py index c658bff1d0c..3374328f411 100644 --- a/tests/components/wyoming/test_tts.py +++ b/tests/components/wyoming/test_tts.py @@ -8,7 +8,8 @@ import wave import pytest from syrupy.assertion import SnapshotAssertion -from wyoming.audio import AudioChunk, AudioStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.tts import SynthesizeStopped from homeassistant.components import tts, wyoming from homeassistant.core import HomeAssistant @@ -43,11 +44,11 @@ async def test_get_tts_audio( hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion ) -> None: """Test get audio.""" + entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts") + assert entity is not None + assert not entity.async_supports_streaming_input() + audio = bytes(100) - audio_events = [ - AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), - AudioStop().event(), - ] # Verify audio audio_events = [ @@ -215,3 +216,52 @@ async def test_voice_speaker( ), ) assert mock_client.written == snapshot + + +async def test_get_tts_audio_streaming( + hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion +) -> None: + """Test get audio with streaming.""" + entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts") + assert entity is not None + assert entity.async_supports_streaming_input() + + audio = bytes(100) + + # Verify audio + audio_events = [ + AudioStart(rate=16000, width=2, channels=1).event(), + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + SynthesizeStopped().event(), + ] + + async def message_gen(): + yield "Hello " + yield "Word." + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + stream = tts.async_create_stream( + hass, + "tts.test_streaming_tts", + "en-US", + options={tts.ATTR_PREFERRED_FORMAT: "wav"}, + ) + stream.async_set_message_stream(message_gen()) + data = b"".join([chunk async for chunk in stream.async_stream_result()]) + + # Ensure client was disconnected properly + assert mock_client.is_connected is False + + assert data is not None + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + assert wav_file.getframerate() == 16000 + assert wav_file.getsampwidth() == 2 + assert wav_file.getnchannels() == 1 + assert wav_file.getnframes() == 0 # streaming + assert data[44:] == audio # WAV header is 44 bytes + + assert mock_client.written == snapshot