TTS to always stream when available (#148695)

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2025-07-14 20:23:43 +02:00 committed by GitHub
parent c27a67db82
commit 124931b2ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 107 additions and 9 deletions

View File

@ -382,7 +382,7 @@ async def _async_convert_audio(
assert process.stderr
stderr_data = await process.stderr.read()
_LOGGER.error(stderr_data.decode())
raise RuntimeError(
raise HomeAssistantError(
f"Unexpected error while running ffmpeg with arguments: {command}. "
"See log for details."
)
@ -976,7 +976,7 @@ class SpeechManager:
if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
if isinstance(engine_instance, Provider):
if isinstance(message_or_stream, str):
message = message_or_stream
else:
@ -996,8 +996,18 @@ class SpeechManager:
data_gen = make_data_generator(data)
else:
if isinstance(message_or_stream, str):
async def gen_stream() -> AsyncGenerator[str]:
yield message_or_stream
stream = gen_stream()
else:
stream = message_or_stream
tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_or_stream)
TTSAudioRequest(language, options, stream)
)
extension = tts_result.extension
data_gen = tts_result.data_gen

View File

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
# name: test_chat_log_tts_streaming[to_stream_deltas0-1-hello, how are you?]
list([
dict({
'data': dict({

View File

@ -1550,9 +1550,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
"?",
],
),
# We are not streaming, so 0 chunks via streaming method
0,
"",
# We always stream when possible, so 1 chunk via streaming method
1,
"hello, how are you?",
),
# Size above STREAM_RESPONSE_CHUNKS
(

View File

@ -1835,7 +1835,7 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
async def bad_data_gen():
yield bytes(0)
with pytest.raises(RuntimeError):
with pytest.raises(HomeAssistantError):
# Simulate a bad WAV file
async for _chunk in tts._async_convert_audio(
hass, "wav", bad_data_gen(), "mp3"

View File

@ -1,6 +1,19 @@
# serializer version: 1
# name: test_get_tts_audio
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello world',
@ -8,10 +21,29 @@
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---
# name: test_get_tts_audio_different_formats
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello world',
@ -19,10 +51,29 @@
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---
# name: test_get_tts_audio_different_formats.1
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello world',
@ -30,6 +81,12 @@
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---
# name: test_get_tts_audio_streaming
@ -71,6 +128,23 @@
# ---
# name: test_voice_speaker
list([
dict({
'data': dict({
'voice': dict({
'name': 'voice1',
'speaker': 'speaker1',
}),
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello world',
@ -82,5 +156,11 @@
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---

View File

@ -52,6 +52,7 @@ async def test_get_tts_audio(
# Verify audio
audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
@ -77,7 +78,10 @@ async def test_get_tts_audio(
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
# nframes = 0 due to streaming
assert len(data) == len(audio) + 44 # WAVE header is 44 bytes
assert data[44:] == audio
assert mock_client.written == snapshot
@ -88,6 +92,7 @@ async def test_get_tts_audio_different_formats(
"""Test changing preferred audio format."""
audio = bytes(16000 * 2 * 1) # one second
audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
@ -123,6 +128,7 @@ async def test_get_tts_audio_different_formats(
# MP3 is the default
audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
@ -167,6 +173,7 @@ async def test_get_tts_audio_audio_oserror(
"""Test get audio and error raising."""
audio = bytes(100)
audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
@ -197,6 +204,7 @@ async def test_voice_speaker(
"""Test using a different voice and speaker."""
audio = bytes(100)
audio_events = [
AudioStart(rate=16000, width=2, channels=1).event(),
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]