TTS to always stream when available (#148695)

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2025-07-14 20:23:43 +02:00 committed by GitHub
parent c27a67db82
commit 124931b2ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 107 additions and 9 deletions

View File

@@ -382,7 +382,7 @@ async def _async_convert_audio(
assert process.stderr assert process.stderr
stderr_data = await process.stderr.read() stderr_data = await process.stderr.read()
_LOGGER.error(stderr_data.decode()) _LOGGER.error(stderr_data.decode())
raise RuntimeError( raise HomeAssistantError(
f"Unexpected error while running ffmpeg with arguments: {command}. " f"Unexpected error while running ffmpeg with arguments: {command}. "
"See log for details." "See log for details."
) )
@@ -976,7 +976,7 @@ class SpeechManager:
if engine_instance.name is None or engine_instance.name is UNDEFINED: if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.") raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str): if isinstance(engine_instance, Provider):
if isinstance(message_or_stream, str): if isinstance(message_or_stream, str):
message = message_or_stream message = message_or_stream
else: else:
@@ -996,8 +996,18 @@ class SpeechManager:
data_gen = make_data_generator(data) data_gen = make_data_generator(data)
else: else:
if isinstance(message_or_stream, str):
async def gen_stream() -> AsyncGenerator[str]:
yield message_or_stream
stream = gen_stream()
else:
stream = message_or_stream
tts_result = await engine_instance.internal_async_stream_tts_audio( tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_or_stream) TTSAudioRequest(language, options, stream)
) )
extension = tts_result.extension extension = tts_result.extension
data_gen = tts_result.data_gen data_gen = tts_result.data_gen

View File

@@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_log_tts_streaming[to_stream_deltas0-0-] # name: test_chat_log_tts_streaming[to_stream_deltas0-1-hello, how are you?]
list([ list([
dict({ dict({
'data': dict({ 'data': dict({

View File

@@ -1550,9 +1550,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
"?", "?",
], ],
), ),
# We are not streaming, so 0 chunks via streaming method # We always stream when possible, so 1 chunk via streaming method
0, 1,
"", "hello, how are you?",
), ),
# Size above STREAM_RESPONSE_CHUNKS # Size above STREAM_RESPONSE_CHUNKS
( (

View File

@@ -1835,7 +1835,7 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
async def bad_data_gen(): async def bad_data_gen():
yield bytes(0) yield bytes(0)
with pytest.raises(RuntimeError): with pytest.raises(HomeAssistantError):
# Simulate a bad WAV file # Simulate a bad WAV file
async for _chunk in tts._async_convert_audio( async for _chunk in tts._async_convert_audio(
hass, "wav", bad_data_gen(), "mp3" hass, "wav", bad_data_gen(), "mp3"

View File

@@ -1,6 +1,19 @@
# serializer version: 1 # serializer version: 1
# name: test_get_tts_audio # name: test_get_tts_audio
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -8,10 +21,29 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_different_formats # name: test_get_tts_audio_different_formats
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -19,10 +51,29 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_different_formats.1 # name: test_get_tts_audio_different_formats.1
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -30,6 +81,12 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_streaming # name: test_get_tts_audio_streaming
@@ -71,6 +128,23 @@
# --- # ---
# name: test_voice_speaker # name: test_voice_speaker
list([ list([
dict({
'data': dict({
'voice': dict({
'name': 'voice1',
'speaker': 'speaker1',
}),
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -82,5 +156,11 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---

View File

@@ -52,6 +52,7 @@ async def test_get_tts_audio(
# Verify audio # Verify audio
audio_events = [ audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
] ]
@@ -77,7 +78,10 @@ async def test_get_tts_audio(
assert wav_file.getframerate() == 16000 assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2 assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1 assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
# nframes = 0 due to streaming
assert len(data) == len(audio) + 44 # WAVE header is 44 bytes
assert data[44:] == audio
assert mock_client.written == snapshot assert mock_client.written == snapshot
@@ -88,6 +92,7 @@ async def test_get_tts_audio_different_formats(
"""Test changing preferred audio format.""" """Test changing preferred audio format."""
audio = bytes(16000 * 2 * 1) # one second audio = bytes(16000 * 2 * 1) # one second
audio_events = [ audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
] ]
@@ -123,6 +128,7 @@ async def test_get_tts_audio_different_formats(
# MP3 is the default # MP3 is the default
audio_events = [ audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
] ]
@@ -167,6 +173,7 @@ async def test_get_tts_audio_audio_oserror(
"""Test get audio and error raising.""" """Test get audio and error raising."""
audio = bytes(100) audio = bytes(100)
audio_events = [ audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
] ]
@@ -197,6 +204,7 @@ async def test_voice_speaker(
"""Test using a different voice and speaker.""" """Test using a different voice and speaker."""
audio = bytes(100) audio = bytes(100)
audio_events = [ audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
] ]