diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a028fa638df..34a4f82c388 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -20,9 +20,6 @@ import hass_nabucasa import voluptuous as vol from homeassistant.components import conversation, stt, tts, wake_word, websocket_api -from homeassistant.components.tts import ( - generate_media_source_id as tts_generate_media_source_id, -) from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -1275,26 +1272,10 @@ class PipelineRun: ) ) - try: - # Synthesize audio and get URL - tts_media_id = tts_generate_media_source_id( - self.hass, - tts_input, - engine=self.tts_stream.engine, - language=self.tts_stream.language, - options=self.tts_stream.options, - ) - except Exception as src_error: - _LOGGER.exception("Unexpected error during text-to-speech") - raise TextToSpeechError( - code="tts-failed", - message="Unexpected error during text-to-speech", - ) from src_error - self.tts_stream.async_set_message(tts_input) tts_output = { - "media_id": tts_media_id, + "media_id": self.tts_stream.media_source_id, "url": self.tts_stream.url, "mime_type": self.tts_stream.content_type, } diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 98ce76cafde..23bdfb5cdda 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -27,6 +27,10 @@ import voluptuous as vol from homeassistant.components import ffmpeg, websocket_api from homeassistant.components.http import HomeAssistantView +from homeassistant.components.media_source import ( + Unresolvable, + generate_media_source_id as ms_generate_media_source_id, +) from homeassistant.config_entries import ConfigEntry from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT from homeassistant.core import ( @@ -188,10 +192,19 @@ async def async_get_media_source_audio( ) -> tuple[str, bytes]: """Get TTS audio as extension, data.""" manager = hass.data[DATA_TTS_MANAGER] - cache_key = manager.async_cache_message_in_memory( - **media_source_id_to_kwargs(media_source_id) - ) - return await manager.async_get_tts_audio(cache_key) + + if not media_source_id.startswith("media-source://tts/temporary/"): + cache_key = manager.async_cache_message_in_memory( + **media_source_id_to_kwargs(media_source_id) + ) + return await manager.async_get_tts_audio(cache_key) + + token = media_source_id.partition("/")[2] + if (stream := manager.token_to_stream.get(token)) is None: + raise Unresolvable("Token from media source not found") + + data = b"".join([chunk async for chunk in stream.async_stream_result()]) + return stream.extension, data @callback @@ -394,6 +407,11 @@ class ResultStream: """Get the URL to stream the result.""" return f"/api/tts_proxy/{self.token}" + @cached_property + def media_source_id(self) -> str: + """Get the media source ID for the result.""" + return ms_generate_media_source_id(DOMAIN, f"temporary/{self.token}") + @cached_property def _result_cache_key(self) -> asyncio.Future[str]: """Get the future that returns the cache key.""" diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index aa2cd6e7555..4771153cb73 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import TypedDict +from typing import TypedDict, TYPE_CHECKING from yarl import URL @@ -22,12 +22,15 @@ from homeassistant.exceptions import HomeAssistantError from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN from .helper import get_engine_instance +if TYPE_CHECKING: + from . import SpeechManager + URL_QUERY_TTS_OPTIONS = "tts_options" async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource: """Set up tts media source.""" - return TTSMediaSource(hass) + return TTSMediaSource(hass, hass.data[DATA_TTS_MANAGER]) @callback @@ -109,22 +112,31 @@ class TTSMediaSource(MediaSource): """Provide text-to-speech providers as media sources.""" name: str = "Text-to-speech" + manager: SpeechManager - def __init__(self, hass: HomeAssistant) -> None: + def __init__(self, hass: HomeAssistant, manager: SpeechManager) -> None: """Initialize TTSMediaSource.""" super().__init__(DOMAIN) self.hass = hass + self.manager = manager async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" - try: - stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream( - **media_source_id_to_kwargs(item.identifier) - ) - except Unresolvable: - raise - except HomeAssistantError as err: - raise Unresolvable(str(err)) from err + if item.identifier.startswith("temporary/"): + token = item.identifier.partition("/")[2] + stream = self.manager.token_to_stream.get(token) + if stream is None: + raise Unresolvable("Temporary media not found") + + else: + try: + stream = self.manager.async_create_result_stream( + **media_source_id_to_kwargs(item.identifier) + ) + except Unresolvable: + raise + except HomeAssistantError as err: + raise Unresolvable(str(err)) from err return PlayMedia(stream.url, stream.content_type) @@ -134,6 +146,9 @@ class TTSMediaSource(MediaSource): ) -> BrowseMediaSource: """Return media.""" if item.identifier: + if item.identifier.startswith("temporary/"): + raise BrowseError("Temporary media cannot be browsed") + engine, _, params = item.identifier.partition("?") return self._engine_item(engine, params) diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index 9e50cc6b512..a5c727541d2 100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -302,3 +302,9 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs( "options": {"age": {"k1": [5, 6], "k2": "v2"}}, "use_file_cache": True, } + + +async def test_stream_media_sources(hass: HomeAssistant) -> None: + """Test ResultStream as media sources.""" + + # media-source://tts/temporary/AT1BH2ZsWHipW0pCy0cm7w.mp3