Add streaming to cloud TTS (#148925)

This commit is contained in:
Michael Hansen 2025-07-21 10:33:23 -05:00 committed by GitHub
parent 3c70932357
commit 3f42911af4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 218 additions and 61 deletions

View File

@ -17,6 +17,8 @@ from homeassistant.components.tts import (
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TTSAudioRequest,
TTSAudioResponse,
TtsAudioType,
Voice,
)
@ -332,7 +334,7 @@ class CloudTTSEntity(TextToSpeechEntity):
def default_options(self) -> dict[str, str]:
"""Return a dict include default options."""
return {
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
ATTR_AUDIO_OUTPUT: AudioOutput.MP3.value,
}
@property
@ -433,6 +435,29 @@ class CloudTTSEntity(TextToSpeechEntity):
return (options[ATTR_AUDIO_OUTPUT], data)
async def async_stream_tts_audio(
    self, request: TTSAudioRequest
) -> TTSAudioResponse:
    """Stream speech for the text chunks carried by the request.

    The request's language, together with its optional voice and gender
    options, is turned into voice arguments for the cloud streaming call.
    Whatever that call returns is wrapped in a response tagged as WAV.
    """
    language = request.language
    options = request.options

    # The configured voice only applies to the configured language; any
    # other language falls back to that language's default voice.
    if language == self._language:
        fallback_voice = self._voice
    else:
        fallback_voice = DEFAULT_VOICES[language]

    voice_args = _prepare_voice_args(
        hass=self.hass,
        language=language,
        voice=options.get(ATTR_VOICE, fallback_voice),
        gender=options.get(ATTR_GENDER),
    )
    audio_stream = self.cloud.voice.process_tts_stream(
        text_stream=request.message_gen, **voice_args
    )
    return TTSAudioResponse(AudioOutput.WAV.value, audio_stream)
class CloudProvider(Provider):
"""Home Assistant Cloud speech API provider."""
@ -526,9 +551,11 @@ class CloudProvider(Provider):
language=language,
voice=options.get(
ATTR_VOICE,
self._voice
if language == self._language
else DEFAULT_VOICES[language],
(
self._voice
if language == self._language
else DEFAULT_VOICES[language]
),
),
gender=options.get(ATTR_GENDER),
),

View File

@ -1,10 +1,12 @@
"""Tests for cloud tts."""
from collections.abc import AsyncGenerator, Callable, Coroutine
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Coroutine
from copy import deepcopy
from http import HTTPStatus
import io
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import wave
from hass_nabucasa.voice import VoiceError, VoiceTokenError
from hass_nabucasa.voice_data import TTS_VOICES
@ -239,6 +241,12 @@ async def test_get_tts_audio(
side_effect=mock_process_tts_side_effect,
)
cloud.voice.process_tts = mock_process_tts
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
if mock_process_tts_side_effect:
mock_process_tts_stream.side_effect = mock_process_tts_side_effect
cloud.voice.process_tts_stream = mock_process_tts_stream
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
@ -262,13 +270,27 @@ async def test_get_tts_audio(
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
# Force streaming
await client.get(response["path"])
if data.get("engine_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == "en-US"
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
assert mock_process_tts_stream.call_args.kwargs["voice"] == "JennyNeural"
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
@pytest.mark.parametrize(
@ -321,10 +343,10 @@ async def test_get_tts_audio_logged_out(
@pytest.mark.parametrize(
("mock_process_tts_return_value", "mock_process_tts_side_effect"),
("mock_process_tts_side_effect"),
[
(b"", None),
(None, VoiceError("Boom!")),
(None,),
(VoiceError("Boom!"),),
],
)
async def test_tts_entity(
@ -332,15 +354,13 @@ async def test_tts_entity(
hass_client: ClientSessionGenerator,
entity_registry: EntityRegistry,
cloud: MagicMock,
mock_process_tts_return_value: bytes | None,
mock_process_tts_side_effect: Exception | None,
) -> None:
"""Test text-to-speech entity."""
mock_process_tts = AsyncMock(
return_value=mock_process_tts_return_value,
side_effect=mock_process_tts_side_effect,
)
cloud.voice.process_tts = mock_process_tts
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
if mock_process_tts_side_effect:
mock_process_tts_stream.side_effect = mock_process_tts_side_effect
cloud.voice.process_tts_stream = mock_process_tts_stream
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
@ -372,13 +392,14 @@ async def test_tts_entity(
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
# Force streaming
await client.get(response["path"])
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == "en-US"
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
assert mock_process_tts_stream.call_args.kwargs["voice"] == "JennyNeural"
state = hass.states.get(entity_id)
assert state
@ -482,6 +503,8 @@ async def test_deprecated_voice(
return_value=b"",
)
cloud.voice.process_tts = mock_process_tts
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
cloud.voice.process_tts_stream = mock_process_tts_stream
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
@ -509,18 +532,34 @@ async def test_deprecated_voice(
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
# Force streaming
await client.get(response["path"])
if data.get("engine_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == language
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
assert mock_process_tts_stream.call_args.kwargs["voice"] == replacement_voice
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue(
"cloud", f"deprecated_voice_{replacement_voice}"
)
assert issue is None
mock_process_tts.reset_mock()
mock_process_tts_stream.reset_mock()
# Test with deprecated voice.
data["options"] = {"voice": deprecated_voice}
@ -538,15 +577,30 @@ async def test_deprecated_voice(
}
await hass.async_block_till_done()
# Force streaming
await client.get(response["path"])
issue_id = f"deprecated_voice_{deprecated_voice}"
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
if data.get("engine_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == language
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
assert mock_process_tts_stream.call_args.kwargs["voice"] == replacement_voice
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue("cloud", issue_id)
assert issue is not None
assert issue.breaks_in_ha_version == "2024.8.0"
@ -623,6 +677,8 @@ async def test_deprecated_gender(
return_value=b"",
)
cloud.voice.process_tts = mock_process_tts
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
cloud.voice.process_tts_stream = mock_process_tts_stream
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
@ -649,15 +705,30 @@ async def test_deprecated_gender(
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
# Force streaming
await client.get(response["path"])
if data.get("engine_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == language
assert mock_process_tts_stream.call_args.kwargs["voice"] == "XiaoxiaoNeural"
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue("cloud", "deprecated_gender")
assert issue is None
mock_process_tts.reset_mock()
mock_process_tts_stream.reset_mock()
# Test with deprecated gender option.
data["options"] = {"gender": gender_option}
@ -675,15 +746,30 @@ async def test_deprecated_gender(
}
await hass.async_block_till_done()
# Force streaming
await client.get(response["path"])
issue_id = "deprecated_gender"
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] == gender_option
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
if data.get("engine_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert mock_process_tts_stream.call_args.kwargs["language"] == language
assert mock_process_tts_stream.call_args.kwargs["gender"] == gender_option
assert mock_process_tts_stream.call_args.kwargs["voice"] == "XiaoxiaoNeural"
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] == gender_option
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue("cloud", issue_id)
assert issue is not None
assert issue.breaks_in_ha_version == "2024.10.0"
@ -772,6 +858,8 @@ async def test_tts_services(
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
mock_process_tts = AsyncMock(return_value=b"")
cloud.voice.process_tts = mock_process_tts
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
cloud.voice.process_tts_stream = mock_process_tts_stream
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
@ -793,9 +881,51 @@ async def test_tts_services(
assert response.status == HTTPStatus.OK
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == service_data[ATTR_LANGUAGE]
assert mock_process_tts.call_args.kwargs["voice"] == "GadisNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
if service_data.get("entity_id", "").startswith("tts."):
# Streaming
assert mock_process_tts_stream.call_count == 1
assert mock_process_tts_stream.call_args is not None
assert (
mock_process_tts_stream.call_args.kwargs["language"]
== service_data[ATTR_LANGUAGE]
)
assert mock_process_tts_stream.call_args.kwargs["voice"] == "GadisNeural"
else:
# Non-streaming
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert (
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
)
assert (
mock_process_tts.call_args.kwargs["language"] == service_data[ATTR_LANGUAGE]
)
assert mock_process_tts.call_args.kwargs["voice"] == "GadisNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
def _make_stream_mock(expected_text: str) -> MagicMock:
    """Create a mock TTS stream generator with just a WAV header.

    The returned mock stands in for ``cloud.voice.process_tts_stream``. Each
    call produces an async generator that drains ``text_stream``, asserts that
    the joined chunks equal ``expected_text``, and then yields a single chunk:
    the header of an empty 24 kHz, 16-bit, mono WAV file.
    """
    with io.BytesIO() as wav_io:
        # No frames are written; leaving the writer's context emits the header.
        with wave.open(wav_io, "wb") as wav_writer:
            wav_writer.setframerate(24000)
            wav_writer.setsampwidth(2)
            wav_writer.setnchannels(1)

        # getvalue() returns the whole buffer regardless of the stream
        # position, so no rewind is needed; it must run before wav_io closes.
        wav_bytes = wav_io.getvalue()

    async def fake_process_tts_stream(
        *, text_stream: AsyncIterable[str], **kwargs: Any
    ) -> AsyncGenerator[bytes, None]:
        """Check the streamed text, then yield the WAV header."""
        # Verify text
        actual_text = "".join([text_chunk async for text_chunk in text_stream])
        assert actual_text == expected_text

        # WAV header
        yield wav_bytes

    return MagicMock(side_effect=fake_process_tts_stream)