mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Add streaming to cloud TTS (#148925)
This commit is contained in:
parent
3c70932357
commit
3f42911af4
@ -17,6 +17,8 @@ from homeassistant.components.tts import (
|
||||
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
|
||||
Provider,
|
||||
TextToSpeechEntity,
|
||||
TTSAudioRequest,
|
||||
TTSAudioResponse,
|
||||
TtsAudioType,
|
||||
Voice,
|
||||
)
|
||||
@ -332,7 +334,7 @@ class CloudTTSEntity(TextToSpeechEntity):
|
||||
def default_options(self) -> dict[str, str]:
|
||||
"""Return a dict include default options."""
|
||||
return {
|
||||
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
|
||||
ATTR_AUDIO_OUTPUT: AudioOutput.MP3.value,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -433,6 +435,29 @@ class CloudTTSEntity(TextToSpeechEntity):
|
||||
|
||||
return (options[ATTR_AUDIO_OUTPUT], data)
|
||||
|
||||
async def async_stream_tts_audio(
    self, request: TTSAudioRequest
) -> TTSAudioResponse:
    """Stream speech for the message chunks carried by the request.

    The voice is taken from the request options when present. Otherwise the
    entity's own voice is used if the request language matches the entity
    language, and the default voice for the request language if it does not.
    The response is built with the WAV audio output value.
    """
    language = request.language
    options = request.options

    # Resolved up front, exactly as a ``dict.get`` default would be, so an
    # unknown language fails here even when a voice option was supplied.
    if language == self._language:
        fallback_voice = self._voice
    else:
        fallback_voice = DEFAULT_VOICES[language]

    voice_args = _prepare_voice_args(
        hass=self.hass,
        language=language,
        voice=options.get(ATTR_VOICE, fallback_voice),
        gender=options.get(ATTR_GENDER),
    )

    # NOTE(review): presumably an async generator of audio chunks - confirm
    # against the cloud voice client.
    audio_stream = self.cloud.voice.process_tts_stream(
        text_stream=request.message_gen,
        **voice_args,
    )

    return TTSAudioResponse(AudioOutput.WAV.value, audio_stream)
|
||||
|
||||
|
||||
class CloudProvider(Provider):
|
||||
"""Home Assistant Cloud speech API provider."""
|
||||
@ -526,9 +551,11 @@ class CloudProvider(Provider):
|
||||
language=language,
|
||||
voice=options.get(
|
||||
ATTR_VOICE,
|
||||
self._voice
|
||||
if language == self._language
|
||||
else DEFAULT_VOICES[language],
|
||||
(
|
||||
self._voice
|
||||
if language == self._language
|
||||
else DEFAULT_VOICES[language]
|
||||
),
|
||||
),
|
||||
gender=options.get(ATTR_GENDER),
|
||||
),
|
||||
|
@ -1,10 +1,12 @@
|
||||
"""Tests for cloud tts."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Coroutine
|
||||
from copy import deepcopy
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import wave
|
||||
|
||||
from hass_nabucasa.voice import VoiceError, VoiceTokenError
|
||||
from hass_nabucasa.voice_data import TTS_VOICES
|
||||
@ -239,6 +241,12 @@ async def test_get_tts_audio(
|
||||
side_effect=mock_process_tts_side_effect,
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
|
||||
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
|
||||
if mock_process_tts_side_effect:
|
||||
mock_process_tts_stream.side_effect = mock_process_tts_side_effect
|
||||
cloud.voice.process_tts_stream = mock_process_tts_stream
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
@ -262,13 +270,27 @@ async def test_get_tts_audio(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
if data.get("engine_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == "JennyNeural"
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -321,10 +343,10 @@ async def test_get_tts_audio_logged_out(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_process_tts_return_value", "mock_process_tts_side_effect"),
|
||||
("mock_process_tts_side_effect"),
|
||||
[
|
||||
(b"", None),
|
||||
(None, VoiceError("Boom!")),
|
||||
(None,),
|
||||
(VoiceError("Boom!"),),
|
||||
],
|
||||
)
|
||||
async def test_tts_entity(
|
||||
@ -332,15 +354,13 @@ async def test_tts_entity(
|
||||
hass_client: ClientSessionGenerator,
|
||||
entity_registry: EntityRegistry,
|
||||
cloud: MagicMock,
|
||||
mock_process_tts_return_value: bytes | None,
|
||||
mock_process_tts_side_effect: Exception | None,
|
||||
) -> None:
|
||||
"""Test text-to-speech entity."""
|
||||
mock_process_tts = AsyncMock(
|
||||
return_value=mock_process_tts_return_value,
|
||||
side_effect=mock_process_tts_side_effect,
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
|
||||
if mock_process_tts_side_effect:
|
||||
mock_process_tts_stream.side_effect = mock_process_tts_side_effect
|
||||
cloud.voice.process_tts_stream = mock_process_tts_stream
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
@ -372,13 +392,14 @@ async def test_tts_entity(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == "JennyNeural"
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
@ -482,6 +503,8 @@ async def test_deprecated_voice(
|
||||
return_value=b"",
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
|
||||
cloud.voice.process_tts_stream = mock_process_tts_stream
|
||||
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
@ -509,18 +532,34 @@ async def test_deprecated_voice(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
if data.get("engine_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == replacement_voice
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
issue = issue_registry.async_get_issue(
|
||||
"cloud", f"deprecated_voice_{replacement_voice}"
|
||||
)
|
||||
assert issue is None
|
||||
mock_process_tts.reset_mock()
|
||||
mock_process_tts_stream.reset_mock()
|
||||
|
||||
# Test with deprecated voice.
|
||||
data["options"] = {"voice": deprecated_voice}
|
||||
@ -538,15 +577,30 @@ async def test_deprecated_voice(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
issue_id = f"deprecated_voice_{deprecated_voice}"
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
if data.get("engine_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts_stream.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == replacement_voice
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] is None
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
issue = issue_registry.async_get_issue("cloud", issue_id)
|
||||
assert issue is not None
|
||||
assert issue.breaks_in_ha_version == "2024.8.0"
|
||||
@ -623,6 +677,8 @@ async def test_deprecated_gender(
|
||||
return_value=b"",
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
|
||||
cloud.voice.process_tts_stream = mock_process_tts_stream
|
||||
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
@ -649,15 +705,30 @@ async def test_deprecated_gender(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
if data.get("engine_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
issue = issue_registry.async_get_issue("cloud", "deprecated_gender")
|
||||
assert issue is None
|
||||
mock_process_tts.reset_mock()
|
||||
mock_process_tts_stream.reset_mock()
|
||||
|
||||
# Test with deprecated gender option.
|
||||
data["options"] = {"gender": gender_option}
|
||||
@ -675,15 +746,30 @@ async def test_deprecated_gender(
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Force streaming
|
||||
await client.get(response["path"])
|
||||
|
||||
issue_id = "deprecated_gender"
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] == gender_option
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
if data.get("engine_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert mock_process_tts_stream.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts_stream.call_args.kwargs["gender"] == gender_option
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["language"] == language
|
||||
assert mock_process_tts.call_args.kwargs["gender"] == gender_option
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
issue = issue_registry.async_get_issue("cloud", issue_id)
|
||||
assert issue is not None
|
||||
assert issue.breaks_in_ha_version == "2024.10.0"
|
||||
@ -772,6 +858,8 @@ async def test_tts_services(
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
mock_process_tts = AsyncMock(return_value=b"")
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
mock_process_tts_stream = _make_stream_mock("There is someone at the door.")
|
||||
cloud.voice.process_tts_stream = mock_process_tts_stream
|
||||
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
@ -793,9 +881,51 @@ async def test_tts_services(
|
||||
assert response.status == HTTPStatus.OK
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == service_data[ATTR_LANGUAGE]
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "GadisNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
if service_data.get("entity_id", "").startswith("tts."):
|
||||
# Streaming
|
||||
assert mock_process_tts_stream.call_count == 1
|
||||
assert mock_process_tts_stream.call_args is not None
|
||||
assert (
|
||||
mock_process_tts_stream.call_args.kwargs["language"]
|
||||
== service_data[ATTR_LANGUAGE]
|
||||
)
|
||||
assert mock_process_tts_stream.call_args.kwargs["voice"] == "GadisNeural"
|
||||
else:
|
||||
# Non-streaming
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
)
|
||||
assert (
|
||||
mock_process_tts.call_args.kwargs["language"] == service_data[ATTR_LANGUAGE]
|
||||
)
|
||||
assert mock_process_tts.call_args.kwargs["voice"] == "GadisNeural"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
|
||||
def _make_stream_mock(expected_text: str) -> MagicMock:
    """Create a mock TTS stream generator with just a WAV header.

    The returned mock stands in for ``process_tts_stream``: calling it gives
    an async generator that, once iterated, checks the streamed text and then
    produces one chunk holding a header-only 24 kHz, 16-bit, mono WAV file.

    Args:
        expected_text: Text that the concatenated ``text_stream`` chunks must
            equal; a mismatch raises ``AssertionError`` during iteration.

    Returns:
        A ``MagicMock`` whose ``side_effect`` is the fake stream function, so
        ``call_count`` and ``call_args`` stay available for assertions.
    """
    with io.BytesIO() as wav_io:
        # Closing the writer without any frames still flushes the WAV header.
        with wave.open(wav_io, "wb") as wav_writer:
            wav_writer.setframerate(24000)
            wav_writer.setsampwidth(2)
            wav_writer.setnchannels(1)

        # getvalue() returns the full buffer regardless of stream position,
        # so no seek is needed before reading it.
        wav_bytes = wav_io.getvalue()

    async def fake_process_tts_stream(
        *, text_stream: AsyncIterable[str], **kwargs: Any
    ) -> AsyncGenerator[bytes, None]:
        """Check the incoming text, then yield the WAV header."""
        # Verify text
        actual_text = "".join([text_chunk async for text_chunk in text_stream])
        assert actual_text == expected_text

        # WAV header
        yield wav_bytes

    return MagicMock(side_effect=fake_process_tts_stream)
|
||||
|
Loading…
x
Reference in New Issue
Block a user