TTS to always stream when available (#148695)

Co-authored-by: Michael Hansen <mike@rhasspy.org>

parent c27a67db82
commit 124931b2ee
@@ -382,7 +382,7 @@ async def _async_convert_audio(
         assert process.stderr
         stderr_data = await process.stderr.read()
         _LOGGER.error(stderr_data.decode())
-        raise RuntimeError(
+        raise HomeAssistantError(
             f"Unexpected error while running ffmpeg with arguments: {command}. "
             "See log for details."
         )
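The hunk above only swaps the exception type: an ffmpeg failure during audio conversion now surfaces as HomeAssistantError instead of a bare RuntimeError, so callers and tests only have to handle Home Assistant's own exception. A minimal standalone sketch of that pattern follows; the helper name, command handling, and logger are illustrative assumptions, not the integration's real _async_convert_audio.

import asyncio
import logging

from homeassistant.exceptions import HomeAssistantError

_LOGGER = logging.getLogger(__name__)


async def _run_ffmpeg_or_raise(command: list[str]) -> bytes:
    """Run ffmpeg, returning stdout or raising HomeAssistantError on failure."""
    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout_data, stderr_data = await process.communicate()
    if process.returncode != 0:
        # Log the full ffmpeg stderr, but raise a short, user-facing error.
        _LOGGER.error(stderr_data.decode())
        raise HomeAssistantError(
            f"Unexpected error while running ffmpeg with arguments: {command}. "
            "See log for details."
        )
    return stdout_data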
@@ -976,7 +976,7 @@ class SpeechManager:
         if engine_instance.name is None or engine_instance.name is UNDEFINED:
             raise HomeAssistantError("TTS engine name is not set.")
 
-        if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
+        if isinstance(engine_instance, Provider):
             if isinstance(message_or_stream, str):
                 message = message_or_stream
             else:
@@ -996,8 +996,18 @@
             data_gen = make_data_generator(data)
 
         else:
+            if isinstance(message_or_stream, str):
+
+                async def gen_stream() -> AsyncGenerator[str]:
+                    yield message_or_stream
+
+                stream = gen_stream()
+
+            else:
+                stream = message_or_stream
+
             tts_result = await engine_instance.internal_async_stream_tts_audio(
-                TTSAudioRequest(language, options, message_or_stream)
+                TTSAudioRequest(language, options, stream)
             )
             extension = tts_result.extension
             data_gen = tts_result.data_gen
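The new branch above guarantees that a streaming-capable engine always receives an async generator: a plain string message is wrapped in a one-shot generator before being handed to TTSAudioRequest. A standalone sketch of that wrapping, using stand-in names rather than Home Assistant's SpeechManager:

import asyncio
from collections.abc import AsyncGenerator


def to_stream(message_or_stream: str | AsyncGenerator[str]) -> AsyncGenerator[str]:
    """Return the input as a text stream, wrapping a plain string if needed."""
    if isinstance(message_or_stream, str):
        # One-shot generator: the whole message arrives as a single chunk.
        async def gen_stream() -> AsyncGenerator[str]:
            yield message_or_stream

        return gen_stream()
    return message_or_stream


async def main() -> None:
    async for chunk in to_stream("hello, how are you?"):
        print(chunk)


asyncio.run(main())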
@@ -1,5 +1,5 @@
 # serializer version: 1
-# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
+# name: test_chat_log_tts_streaming[to_stream_deltas0-1-hello, how are you?]
   list([
     dict({
       'data': dict({
@@ -1550,9 +1550,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
                     "?",
                 ],
             ),
-            # We are not streaming, so 0 chunks via streaming method
-            0,
-            "",
+            # We always stream when possible, so 1 chunk via streaming method
+            1,
+            "hello, how are you?",
         ),
         # Size above STREAM_RESPONSE_CHUNKS
         (
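The updated expectations above reflect the new behavior: with streaming always used when possible, the test now expects one chunk delivered through the streaming method and the full streamed text "hello, how are you?". A hypothetical sketch of how such a parametrization can be shaped; the names and the trivial test body are illustrative, not the real pipeline test:

import pytest


@pytest.mark.parametrize(
    ("to_stream_deltas", "expected_stream_chunks", "expected_streamed_text"),
    [
        # We always stream when possible, so 1 chunk via streaming method
        (["hello,", " how are you", "?"], 1, "hello, how are you?"),
    ],
)
def test_streaming_expectations(
    to_stream_deltas: list[str],
    expected_stream_chunks: int,
    expected_streamed_text: str,
) -> None:
    # Stand-in assertion; the real test drives the assist pipeline instead.
    assert expected_stream_chunks == 1
    assert "".join(to_stream_deltas) == expected_streamed_text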
@@ -1835,7 +1835,7 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
     async def bad_data_gen():
         yield bytes(0)
 
-    with pytest.raises(RuntimeError):
+    with pytest.raises(HomeAssistantError):
         # Simulate a bad WAV file
         async for _chunk in tts._async_convert_audio(
             hass, "wav", bad_data_gen(), "mp3"
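Because the conversion helper now raises HomeAssistantError, the test asserts that type instead of RuntimeError. A standalone sketch of the same assertion pattern, with a stand-in converter that always fails (not tts._async_convert_audio) and assuming an async-capable pytest plugin such as pytest-asyncio:

from collections.abc import AsyncGenerator

import pytest

from homeassistant.exceptions import HomeAssistantError


async def _always_failing_convert(
    data_gen: AsyncGenerator[bytes],
) -> AsyncGenerator[bytes]:
    """Stand-in converter that fails on the first chunk it receives."""
    async for _chunk in data_gen:
        raise HomeAssistantError("Audio conversion failed. See log for details.")
    yield b""  # only reached if the input stream is empty


async def test_bad_audio_raises_home_assistant_error() -> None:
    async def bad_data_gen():
        yield bytes(0)

    with pytest.raises(HomeAssistantError):
        async for _chunk in _always_failing_convert(bad_data_gen()):
            pass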
@@ -1,6 +1,19 @@
 # serializer version: 1
 # name: test_get_tts_audio
   list([
     dict({
       'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-start',
+    }),
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize-chunk',
+    }),
+    dict({
+      'data': dict({
         'text': 'Hello world',
@@ -8,10 +21,29 @@
       'payload': None,
       'type': 'synthesize',
     }),
+    dict({
+      'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-stop',
+    }),
   ])
 # ---
 # name: test_get_tts_audio_different_formats
   list([
     dict({
       'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-start',
+    }),
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize-chunk',
+    }),
+    dict({
+      'data': dict({
         'text': 'Hello world',
@@ -19,10 +51,29 @@
       'payload': None,
       'type': 'synthesize',
     }),
+    dict({
+      'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-stop',
+    }),
   ])
 # ---
 # name: test_get_tts_audio_different_formats.1
   list([
     dict({
       'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-start',
+    }),
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize-chunk',
+    }),
+    dict({
+      'data': dict({
         'text': 'Hello world',
@@ -30,6 +81,12 @@
       'payload': None,
       'type': 'synthesize',
     }),
+    dict({
+      'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-stop',
+    }),
   ])
 # ---
 # name: test_get_tts_audio_streaming
@@ -71,6 +128,23 @@
 # ---
 # name: test_voice_speaker
   list([
     dict({
       'data': dict({
+        'voice': dict({
+          'name': 'voice1',
+          'speaker': 'speaker1',
+        }),
+      }),
+      'payload': None,
+      'type': 'synthesize-start',
+    }),
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize-chunk',
+    }),
+    dict({
+      'data': dict({
         'text': 'Hello world',
@@ -82,5 +156,11 @@
       'payload': None,
       'type': 'synthesize',
     }),
+    dict({
+      'data': dict({
+      }),
+      'payload': None,
+      'type': 'synthesize-stop',
+    }),
   ])
 # ---
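Taken together, the snapshot additions above capture the streaming handshake now written to the Wyoming client: a synthesize-start event (carrying the voice, if one is set), a synthesize-chunk event per text chunk, the existing synthesize event with the full text, and a closing synthesize-stop. A minimal sketch of that sequence as plain event dicts mirroring the snapshot layout, not the wyoming library's own event classes:

def streaming_synthesize_events(text: str, voice: dict | None = None) -> list[dict]:
    """Build the event sequence the snapshots above record, as plain dicts."""
    start_data: dict = {}
    if voice is not None:
        start_data["voice"] = voice  # e.g. {"name": "voice1", "speaker": "speaker1"}
    return [
        {"type": "synthesize-start", "data": start_data, "payload": None},
        {"type": "synthesize-chunk", "data": {"text": text}, "payload": None},
        # The plain synthesize event with the full text is still present.
        {"type": "synthesize", "data": {"text": text}, "payload": None},
        {"type": "synthesize-stop", "data": {}, "payload": None},
    ]


assert [e["type"] for e in streaming_synthesize_events("Hello world")] == [
    "synthesize-start",
    "synthesize-chunk",
    "synthesize",
    "synthesize-stop",
]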
@@ -52,6 +52,7 @@ async def test_get_tts_audio(
 
     # Verify audio
     audio_events = [
+        AudioStart(rate=16000, width=2, channels=1).event(),
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
         AudioStop().event(),
     ]
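Each Wyoming test now primes the mock server with an explicit AudioStart event ahead of the chunk and stop events, presumably so the streamed WAV header can be written up front from AudioStart's rate, width, and channels. A short sketch of building that event list with the wyoming package, using one second of 16 kHz, 16-bit mono silence as one of the tests does:

from wyoming.audio import AudioChunk, AudioStart, AudioStop

audio = bytes(16000 * 2 * 1)  # one second of 16 kHz, 16-bit mono silence
audio_events = [
    AudioStart(rate=16000, width=2, channels=1).event(),
    AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
    AudioStop().event(),
]
assert len(audio_events) == 3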
@@ -77,7 +78,10 @@ async def test_get_tts_audio(
         assert wav_file.getframerate() == 16000
         assert wav_file.getsampwidth() == 2
         assert wav_file.getnchannels() == 1
-        assert wav_file.readframes(wav_file.getnframes()) == audio
+
+        # nframes = 0 due to streaming
+        assert len(data) == len(audio) + 44  # WAVE header is 44 bytes
+        assert data[44:] == audio
 
     assert mock_client.written == snapshot
 
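The new assertions rely on WAV layout arithmetic: a canonical PCM WAV file is a fixed 44-byte RIFF header followed by raw sample data, and when the header is written before the audio is known (streaming), its frame and data sizes can be left at zero, which is why nframes is no longer inspected. A self-contained sketch of that arithmetic, independent of Home Assistant:

import struct

RATE, WIDTH, CHANNELS = 16000, 2, 1
audio = bytes(RATE * WIDTH * CHANNELS)  # one second of silence

# RIFF chunk descriptor with an unknown (zero) overall size.
header = b"RIFF" + struct.pack("<I", 0) + b"WAVE"
# "fmt " chunk: 16-byte PCM format block.
header += b"fmt " + struct.pack(
    "<IHHIIHH", 16, 1, CHANNELS, RATE, RATE * WIDTH * CHANNELS, WIDTH * CHANNELS, WIDTH * 8
)
# "data" chunk header with an unknown (zero) data size.
header += b"data" + struct.pack("<I", 0)

assert len(header) == 44  # WAVE header is 44 bytes
data = header + audio
assert len(data) == len(audio) + 44
assert data[44:] == audio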
@@ -88,6 +92,7 @@ async def test_get_tts_audio_different_formats(
     """Test changing preferred audio format."""
     audio = bytes(16000 * 2 * 1)  # one second
     audio_events = [
+        AudioStart(rate=16000, width=2, channels=1).event(),
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
         AudioStop().event(),
     ]
@@ -123,6 +128,7 @@ async def test_get_tts_audio_different_formats(
 
     # MP3 is the default
     audio_events = [
+        AudioStart(rate=16000, width=2, channels=1).event(),
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
         AudioStop().event(),
     ]
@@ -167,6 +173,7 @@ async def test_get_tts_audio_audio_oserror(
     """Test get audio and error raising."""
     audio = bytes(100)
     audio_events = [
+        AudioStart(rate=16000, width=2, channels=1).event(),
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
         AudioStop().event(),
     ]
@@ -197,6 +204,7 @@ async def test_voice_speaker(
     """Test using a different voice and speaker."""
     audio = bytes(100)
     audio_events = [
+        AudioStart(rate=16000, width=2, channels=1).event(),
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
         AudioStop().event(),
     ]