mirror of
https://github.com/home-assistant/core.git
synced 2025-11-13 04:50:17 +00:00
Compare commits
3 Commits
claude/tri
...
tts-cleanu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15f930e6cc | ||
|
|
f8a496dff2 | ||
|
|
3b1b33d7f2 |
@@ -20,9 +20,6 @@ import hass_nabucasa
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
|
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
|
||||||
from homeassistant.components.tts import (
|
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
|
||||||
)
|
|
||||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
|
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -1276,33 +1273,19 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
# Synthesize audio and get URL
|
|
||||||
tts_media_id = tts_generate_media_source_id(
|
|
||||||
self.hass,
|
|
||||||
tts_input,
|
|
||||||
engine=self.tts_stream.engine,
|
|
||||||
language=self.tts_stream.language,
|
|
||||||
options=self.tts_stream.options,
|
|
||||||
)
|
|
||||||
except Exception as src_error:
|
|
||||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
|
||||||
raise TextToSpeechError(
|
|
||||||
code="tts-failed",
|
|
||||||
message="Unexpected error during text-to-speech",
|
|
||||||
) from src_error
|
|
||||||
|
|
||||||
self.tts_stream.async_set_message(tts_input)
|
self.tts_stream.async_set_message(tts_input)
|
||||||
|
|
||||||
tts_output = {
|
|
||||||
"media_id": tts_media_id,
|
|
||||||
"token": self.tts_stream.token,
|
|
||||||
"url": self.tts_stream.url,
|
|
||||||
"mime_type": self.tts_stream.content_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
PipelineEvent(
|
||||||
|
PipelineEventType.TTS_END,
|
||||||
|
{
|
||||||
|
"tts_output": {
|
||||||
|
"token": self.tts_stream.token,
|
||||||
|
"url": self.tts_stream.url,
|
||||||
|
"mime_type": self.tts_stream.content_type,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
||||||
|
|||||||
@@ -92,9 +92,6 @@ class AssistSatelliteAnnouncement:
|
|||||||
media_id: str
|
media_id: str
|
||||||
"""Media ID to be played."""
|
"""Media ID to be played."""
|
||||||
|
|
||||||
original_media_id: str
|
|
||||||
"""The raw media ID before processing."""
|
|
||||||
|
|
||||||
tts_token: str | None
|
tts_token: str | None
|
||||||
"""The TTS token of the media."""
|
"""The TTS token of the media."""
|
||||||
|
|
||||||
@@ -501,9 +498,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||||
tts_token: str | None = None
|
tts_token: str | None = None
|
||||||
|
|
||||||
if media_id:
|
if not media_id:
|
||||||
original_media_id = media_id
|
|
||||||
else:
|
|
||||||
media_id_source = "tts"
|
media_id_source = "tts"
|
||||||
# Synthesize audio and get URL
|
# Synthesize audio and get URL
|
||||||
pipeline_id = self._resolve_pipeline()
|
pipeline_id = self._resolve_pipeline()
|
||||||
@@ -530,13 +525,6 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
|
|
||||||
tts_token = stream.token
|
tts_token = stream.token
|
||||||
media_id = stream.url
|
media_id = stream.url
|
||||||
original_media_id = tts.generate_media_source_id(
|
|
||||||
self.hass,
|
|
||||||
message,
|
|
||||||
engine=engine,
|
|
||||||
language=pipeline.tts_language,
|
|
||||||
options=tts_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
if media_source.is_media_source_id(media_id):
|
if media_source.is_media_source_id(media_id):
|
||||||
if not media_id_source:
|
if not media_id_source:
|
||||||
@@ -572,7 +560,6 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
return AssistSatelliteAnnouncement(
|
return AssistSatelliteAnnouncement(
|
||||||
message=message,
|
message=message,
|
||||||
media_id=media_id,
|
media_id=media_id,
|
||||||
original_media_id=original_media_id,
|
|
||||||
tts_token=tts_token,
|
tts_token=tts_token,
|
||||||
media_id_source=media_id_source,
|
media_id_source=media_id_source,
|
||||||
preannounce_media_id=preannounce_media_id,
|
preannounce_media_id=preannounce_media_id,
|
||||||
|
|||||||
@@ -63,7 +63,6 @@ from .const import (
|
|||||||
from .entity import TextToSpeechEntity, TTSAudioRequest
|
from .entity import TextToSpeechEntity, TTSAudioRequest
|
||||||
from .helper import get_engine_instance
|
from .helper import get_engine_instance
|
||||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
|
||||||
from .models import Voice
|
from .models import Voice
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -83,8 +82,6 @@ __all__ = [
|
|||||||
"TtsAudioType",
|
"TtsAudioType",
|
||||||
"Voice",
|
"Voice",
|
||||||
"async_default_engine",
|
"async_default_engine",
|
||||||
"async_get_media_source_audio",
|
|
||||||
"generate_media_source_id",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -267,19 +264,6 @@ def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
|
|||||||
return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token)
|
return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token)
|
||||||
|
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
"""Get TTS audio as extension, data."""
|
|
||||||
manager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
cache = manager.async_cache_message_in_memory(
|
|
||||||
**media_source_id_to_kwargs(media_source_id)
|
|
||||||
)
|
|
||||||
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
|
||||||
return cache.extension, data
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||||
"""Return a set with the union of languages supported by tts engines."""
|
"""Return a set with the union of languages supported by tts engines."""
|
||||||
|
|||||||
@@ -84,7 +84,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -183,7 +182,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -282,7 +280,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -405,7 +402,6 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
|
|||||||
@@ -80,7 +80,6 @@
|
|||||||
# name: test_audio_pipeline.6
|
# name: test_audio_pipeline.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -171,7 +170,6 @@
|
|||||||
# name: test_audio_pipeline_debug.6
|
# name: test_audio_pipeline_debug.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -274,7 +272,6 @@
|
|||||||
# name: test_audio_pipeline_with_enhancements.6
|
# name: test_audio_pipeline_with_enhancements.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@@ -387,7 +384,6 @@
|
|||||||
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
|
|||||||
@@ -190,7 +190,6 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||||
original_media_id="media-source://bla",
|
|
||||||
tts_token="test-token",
|
tts_token="test-token",
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
),
|
),
|
||||||
@@ -204,7 +203,6 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
original_media_id="media-source://given",
|
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
media_id_source="media_id",
|
media_id_source="media_id",
|
||||||
),
|
),
|
||||||
@@ -214,7 +212,6 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/bla.mp3",
|
media_id="http://example.com/bla.mp3",
|
||||||
original_media_id="http://example.com/bla.mp3",
|
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
),
|
),
|
||||||
@@ -227,7 +224,6 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/bla.mp3",
|
media_id="http://example.com/bla.mp3",
|
||||||
original_media_id="http://example.com/bla.mp3",
|
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
preannounce_media_id="http://example.com/preannounce.mp3",
|
preannounce_media_id="http://example.com/preannounce.mp3",
|
||||||
@@ -250,23 +246,7 @@ async def test_announce(
|
|||||||
assert entity.state == AssistSatelliteState.RESPONDING
|
assert entity.state == AssistSatelliteState.RESPONDING
|
||||||
await original_announce(announcement)
|
await original_announce(announcement)
|
||||||
|
|
||||||
def tts_generate_media_source_id(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
message: str,
|
|
||||||
engine: str | None = None,
|
|
||||||
language: str | None = None,
|
|
||||||
options: dict | None = None,
|
|
||||||
cache: bool | None = None,
|
|
||||||
):
|
|
||||||
# Check that TTS options are passed here
|
|
||||||
assert options == {"test-option": "test-value", "voice": "test-voice"}
|
|
||||||
return "media-source://bla"
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
new=tts_generate_media_source_id,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud",
|
return_value="tts.cloud",
|
||||||
@@ -550,7 +530,6 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||||
tts_token="test-token",
|
tts_token="test-token",
|
||||||
original_media_id="media-source://generated",
|
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -568,7 +547,6 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
original_media_id="media-source://given",
|
|
||||||
media_id_source="media_id",
|
media_id_source="media_id",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -585,7 +563,6 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/given.mp3",
|
media_id="http://example.com/given.mp3",
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
original_media_id="http://example.com/given.mp3",
|
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -602,7 +579,6 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/given.mp3",
|
media_id="http://example.com/given.mp3",
|
||||||
tts_token=None,
|
tts_token=None,
|
||||||
original_media_id="http://example.com/given.mp3",
|
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
preannounce_media_id="http://example.com/preannounce.mp3",
|
preannounce_media_id="http://example.com/preannounce.mp3",
|
||||||
),
|
),
|
||||||
@@ -633,10 +609,6 @@ async def test_start_conversation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://generated",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud",
|
return_value="tts.cloud",
|
||||||
|
|||||||
@@ -19,10 +19,6 @@ async def mock_tts(hass: HomeAssistant):
|
|||||||
"""Mock TTS service."""
|
"""Mock TTS service."""
|
||||||
assert await async_setup_component(hass, "tts", {})
|
assert await async_setup_component(hass, "tts", {})
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_create_stream",
|
"homeassistant.components.tts.async_create_stream",
|
||||||
return_value=MockResultStream(hass, "wav", b""),
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
|
|||||||
@@ -1196,10 +1196,6 @@ async def test_announce_message(
|
|||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud_tts",
|
return_value="tts.cloud_tts",
|
||||||
@@ -1373,10 +1369,6 @@ async def test_announce_message_with_preannounce(
|
|||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud_tts",
|
return_value="tts.cloud_tts",
|
||||||
@@ -1493,10 +1485,6 @@ async def test_start_conversation_message(
|
|||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud_tts",
|
return_value="tts.cloud_tts",
|
||||||
@@ -1708,10 +1696,6 @@ async def test_start_conversation_message_with_preannounce(
|
|||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="tts.cloud_tts",
|
return_value="tts.cloud_tts",
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from homeassistant.components.tts import (
|
|||||||
Voice,
|
Voice,
|
||||||
_get_cache_files,
|
_get_cache_files,
|
||||||
)
|
)
|
||||||
|
from homeassistant.components.tts.media_source import media_source_id_to_kwargs
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
@@ -44,6 +45,19 @@ SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
|||||||
TEST_DOMAIN = "test"
|
TEST_DOMAIN = "test"
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_media_source_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
"""Get TTS audio as extension, data."""
|
||||||
|
manager = hass.data[DATA_TTS_MANAGER]
|
||||||
|
cache = manager.async_cache_message_in_memory(
|
||||||
|
**media_source_id_to_kwargs(media_source_id)
|
||||||
|
)
|
||||||
|
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||||
|
return cache.extension, data
|
||||||
|
|
||||||
|
|
||||||
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
|
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
|
||||||
"""Mock the list TTS cache function."""
|
"""Mock the list TTS cache function."""
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from homeassistant.components.media_player import (
|
|||||||
SERVICE_PLAY_MEDIA,
|
SERVICE_PLAY_MEDIA,
|
||||||
MediaType,
|
MediaType,
|
||||||
)
|
)
|
||||||
|
from homeassistant.components.tts.media_source import generate_media_source_id
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
|
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -32,6 +33,7 @@ from .common import (
|
|||||||
MockTTS,
|
MockTTS,
|
||||||
MockTTSEntity,
|
MockTTSEntity,
|
||||||
MockTTSProvider,
|
MockTTSProvider,
|
||||||
|
async_get_media_source_audio,
|
||||||
get_media_source_url,
|
get_media_source_url,
|
||||||
mock_config_entry_setup,
|
mock_config_entry_setup,
|
||||||
mock_setup,
|
mock_setup,
|
||||||
@@ -820,7 +822,7 @@ async def test_service_receive_voice(
|
|||||||
assert req.status == HTTPStatus.OK
|
assert req.status == HTTPStatus.OK
|
||||||
assert await req.read() == tts_data
|
assert await req.read() == tts_data
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
extension, data = await async_get_media_source_audio(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
)
|
)
|
||||||
assert extension == "mp3"
|
assert extension == "mp3"
|
||||||
@@ -1314,7 +1316,7 @@ async def test_generate_media_source_id(
|
|||||||
result_query: str,
|
result_query: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test generating a media source ID."""
|
"""Test generating a media source ID."""
|
||||||
media_source_id = tts.generate_media_source_id(
|
media_source_id = generate_media_source_id(
|
||||||
hass, "msg", engine, language, options, cache
|
hass, "msg", engine, language, options, cache
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1352,7 +1354,7 @@ async def test_generate_media_source_id_invalid_options(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test generating a media source ID."""
|
"""Test generating a media source ID."""
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
tts.generate_media_source_id(hass, "msg", engine, language, options, None)
|
generate_media_source_id(hass, "msg", engine, language, options, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -1404,7 +1406,7 @@ async def test_legacy_fetching_in_async(
|
|||||||
await mock_setup(hass, ProviderWithAsyncFetching(DEFAULT_LANG))
|
await mock_setup(hass, ProviderWithAsyncFetching(DEFAULT_LANG))
|
||||||
|
|
||||||
# Test async_get_media_source_audio
|
# Test async_get_media_source_audio
|
||||||
media_source_id = tts.generate_media_source_id(
|
media_source_id = generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"test message",
|
"test message",
|
||||||
"test",
|
"test",
|
||||||
@@ -1412,12 +1414,8 @@ async def test_legacy_fetching_in_async(
|
|||||||
cache=None,
|
cache=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
task = hass.async_create_task(
|
task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
|
||||||
tts.async_get_media_source_audio(hass, media_source_id)
|
task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
|
||||||
)
|
|
||||||
task2 = hass.async_create_task(
|
|
||||||
tts.async_get_media_source_audio(hass, media_source_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
url = await get_media_source_url(hass, media_source_id)
|
url = await get_media_source_url(hass, media_source_id)
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@@ -1438,17 +1436,17 @@ async def test_legacy_fetching_in_async(
|
|||||||
assert await req.read() == b"test"
|
assert await req.read() == b"test"
|
||||||
|
|
||||||
# Test error is not cached
|
# Test error is not cached
|
||||||
media_source_id = tts.generate_media_source_id(
|
media_source_id = generate_media_source_id(
|
||||||
hass, "test message 2", "test", "en_US", None, None
|
hass, "test message 2", "test", "en_US", None, None
|
||||||
)
|
)
|
||||||
tts_audio = asyncio.Future()
|
tts_audio = asyncio.Future()
|
||||||
tts_audio.set_exception(HomeAssistantError("test error"))
|
tts_audio.set_exception(HomeAssistantError("test error"))
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
assert await tts.async_get_media_source_audio(hass, media_source_id)
|
assert await async_get_media_source_audio(hass, media_source_id)
|
||||||
|
|
||||||
tts_audio = asyncio.Future()
|
tts_audio = asyncio.Future()
|
||||||
tts_audio.set_result(b"test 2")
|
tts_audio.set_result(b"test 2")
|
||||||
assert await tts.async_get_media_source_audio(hass, media_source_id) == (
|
assert await async_get_media_source_audio(hass, media_source_id) == (
|
||||||
"mp3",
|
"mp3",
|
||||||
b"test 2",
|
b"test 2",
|
||||||
)
|
)
|
||||||
@@ -1471,7 +1469,7 @@ async def test_fetching_in_async(
|
|||||||
await mock_config_entry_setup(hass, EntityWithAsyncFetching(DEFAULT_LANG))
|
await mock_config_entry_setup(hass, EntityWithAsyncFetching(DEFAULT_LANG))
|
||||||
|
|
||||||
# Test async_get_media_source_audio
|
# Test async_get_media_source_audio
|
||||||
media_source_id = tts.generate_media_source_id(
|
media_source_id = generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"test message",
|
"test message",
|
||||||
"tts.test",
|
"tts.test",
|
||||||
@@ -1479,12 +1477,8 @@ async def test_fetching_in_async(
|
|||||||
cache=None,
|
cache=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
task = hass.async_create_task(
|
task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
|
||||||
tts.async_get_media_source_audio(hass, media_source_id)
|
task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
|
||||||
)
|
|
||||||
task2 = hass.async_create_task(
|
|
||||||
tts.async_get_media_source_audio(hass, media_source_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
url = await get_media_source_url(hass, media_source_id)
|
url = await get_media_source_url(hass, media_source_id)
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@@ -1505,17 +1499,17 @@ async def test_fetching_in_async(
|
|||||||
assert await req.read() == b"test"
|
assert await req.read() == b"test"
|
||||||
|
|
||||||
# Test error is not cached
|
# Test error is not cached
|
||||||
media_source_id = tts.generate_media_source_id(
|
media_source_id = generate_media_source_id(
|
||||||
hass, "test message 2", "tts.test", "en_US", None, None
|
hass, "test message 2", "tts.test", "en_US", None, None
|
||||||
)
|
)
|
||||||
tts_audio = asyncio.Future()
|
tts_audio = asyncio.Future()
|
||||||
tts_audio.set_exception(HomeAssistantError("test error"))
|
tts_audio.set_exception(HomeAssistantError("test error"))
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
assert await tts.async_get_media_source_audio(hass, media_source_id)
|
assert await async_get_media_source_audio(hass, media_source_id)
|
||||||
|
|
||||||
tts_audio = asyncio.Future()
|
tts_audio = asyncio.Future()
|
||||||
tts_audio.set_result(b"test 2")
|
tts_audio.set_result(b"test 2")
|
||||||
assert await tts.async_get_media_source_audio(hass, media_source_id) == (
|
assert await async_get_media_source_audio(hass, media_source_id) == (
|
||||||
"mp3",
|
"mp3",
|
||||||
b"test 2",
|
b"test 2",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -832,7 +832,6 @@ async def test_announce(
|
|||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token=mock_tts_result_stream.token,
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -883,7 +882,6 @@ async def test_voip_id_is_ip_address(
|
|||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token=mock_tts_result_stream.token,
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -939,7 +937,6 @@ async def test_announce_timeout(
|
|||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token=mock_tts_result_stream.token,
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -981,7 +978,6 @@ async def test_start_conversation(
|
|||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token=mock_tts_result_stream.token,
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1125,10 +1121,6 @@ async def test_start_conversation_user_doesnt_pick_up(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value="test tts",
|
return_value="test tts",
|
||||||
|
|||||||
@@ -11,12 +11,17 @@ from syrupy import SnapshotAssertion
|
|||||||
from wyoming.audio import AudioChunk, AudioStop
|
from wyoming.audio import AudioChunk, AudioStop
|
||||||
|
|
||||||
from homeassistant.components import tts, wyoming
|
from homeassistant.components import tts, wyoming
|
||||||
|
|
||||||
|
# pylint: disable=hass-component-root-import
|
||||||
|
from homeassistant.components.tts.media_source import generate_media_source_id
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.entity_component import DATA_INSTANCES
|
from homeassistant.helpers.entity_component import DATA_INSTANCES
|
||||||
|
|
||||||
from . import MockAsyncTcpClient
|
from . import MockAsyncTcpClient
|
||||||
|
|
||||||
|
from tests.components.tts.common import async_get_media_source_audio
|
||||||
|
|
||||||
|
|
||||||
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
||||||
"""Test supported properties."""
|
"""Test supported properties."""
|
||||||
@@ -59,9 +64,9 @@ async def test_get_tts_audio(
|
|||||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
MockAsyncTcpClient(audio_events),
|
MockAsyncTcpClient(audio_events),
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
extension, data = await async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"Hello world",
|
"Hello world",
|
||||||
"tts.test_tts",
|
"tts.test_tts",
|
||||||
@@ -96,9 +101,9 @@ async def test_get_tts_audio_different_formats(
|
|||||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
MockAsyncTcpClient(audio_events),
|
MockAsyncTcpClient(audio_events),
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
extension, data = await async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"Hello world",
|
"Hello world",
|
||||||
"tts.test_tts",
|
"tts.test_tts",
|
||||||
@@ -130,9 +135,9 @@ async def test_get_tts_audio_different_formats(
|
|||||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
MockAsyncTcpClient(audio_events),
|
MockAsyncTcpClient(audio_events),
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
extension, data = await async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"Hello world",
|
"Hello world",
|
||||||
"tts.test_tts",
|
"tts.test_tts",
|
||||||
@@ -182,9 +187,9 @@ async def test_get_tts_audio_audio_oserror(
|
|||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
await tts.async_get_media_source_audio(
|
await async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
generate_media_source_id(
|
||||||
hass, "Hello world", "tts.test_tts", hass.config.language
|
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -204,9 +209,9 @@ async def test_voice_speaker(
|
|||||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
MockAsyncTcpClient(audio_events),
|
MockAsyncTcpClient(audio_events),
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
await tts.async_get_media_source_audio(
|
await async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
generate_media_source_id(
|
||||||
hass,
|
hass,
|
||||||
"Hello world",
|
"Hello world",
|
||||||
"tts.test_tts",
|
"tts.test_tts",
|
||||||
|
|||||||
Reference in New Issue
Block a user