diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a205db4e615..5f811ac955b 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -20,9 +20,6 @@ import hass_nabucasa import voluptuous as vol from homeassistant.components import conversation, stt, tts, wake_word, websocket_api -from homeassistant.components.tts import ( - generate_media_source_id as tts_generate_media_source_id, -) from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -1276,26 +1273,10 @@ class PipelineRun: ) ) - try: - # Synthesize audio and get URL - tts_media_id = tts_generate_media_source_id( - self.hass, - tts_input, - engine=self.tts_stream.engine, - language=self.tts_stream.language, - options=self.tts_stream.options, - ) - except Exception as src_error: - _LOGGER.exception("Unexpected error during text-to-speech") - raise TextToSpeechError( - code="tts-failed", - message="Unexpected error during text-to-speech", - ) from src_error - self.tts_stream.async_set_message(tts_input) tts_output = { - "media_id": tts_media_id, + "media_id": self.tts_stream.media_source_id, "token": self.tts_stream.token, "url": self.tts_stream.url, "mime_type": self.tts_stream.content_type, diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 526be21ad76..da8a0f2324e 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -25,6 +25,9 @@ import voluptuous as vol from homeassistant.components import ffmpeg, websocket_api from homeassistant.components.http import HomeAssistantView +from homeassistant.components.media_source import ( + generate_media_source_id as ms_generate_media_source_id, +) from homeassistant.config_entries import ConfigEntry from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT from homeassistant.core import ( @@ -58,6 +61,7 @@ from .const import ( DEFAULT_CACHE_DIR, DEFAULT_TIME_MEMORY, DOMAIN, + MEDIA_SOURCE_STREAM_PATH, TtsAudioType, ) from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse @@ -273,9 +277,17 @@ async def async_get_media_source_audio( media_source_id: str, ) -> tuple[str, bytes]: """Get TTS audio as extension, data.""" + manager = hass.data[DATA_TTS_MANAGER] parsed = parse_media_source_id(media_source_id) - stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"]) - stream.async_set_message(parsed["message"]) + if "stream" in parsed: + stream = manager.async_get_result_stream( + parsed["stream"] # type: ignore[typeddict-item] + ) + if stream is None: + raise ValueError("Stream not found") + else: + stream = manager.async_create_result_stream(**parsed["options"]) + stream.async_set_message(parsed["message"]) data = b"".join([chunk async for chunk in stream.async_stream_result()]) return stream.extension, data @@ -478,6 +490,14 @@ class ResultStream: """Get the URL to stream the result.""" return f"/api/tts_proxy/{self.token}" + @cached_property + def media_source_id(self) -> str: + """Get the media source ID of this stream.""" + return ms_generate_media_source_id( + DOMAIN, + f"{MEDIA_SOURCE_STREAM_PATH}/{self.token}", + ) + @cached_property def _result_cache(self) -> asyncio.Future[TTSCache]: """Get the future that returns the cache.""" diff --git a/homeassistant/components/tts/const.py b/homeassistant/components/tts/const.py index 42c7d710ad4..830e0053cee 100644 --- a/homeassistant/components/tts/const.py +++ b/homeassistant/components/tts/const.py @@ -30,4 +30,6 @@ DATA_COMPONENT: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN) DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager") +MEDIA_SOURCE_STREAM_PATH = "-stream-" + type TtsAudioType = tuple[str | None, bytes | None] diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index f096e082364..91192fdca13 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -19,7 +19,7 @@ from homeassistant.components.media_source import ( from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN +from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN, MEDIA_SOURCE_STREAM_PATH from .helper import get_engine_instance URL_QUERY_TTS_OPTIONS = "tts_options" @@ -81,10 +81,22 @@ class ParsedMediaSourceId(TypedDict): message: str +class ParsedMediaSourceStreamId(TypedDict): + """Parsed media source ID for a stream.""" + + stream: str + + @callback -def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId: +def parse_media_source_id( + media_source_id: str, +) -> ParsedMediaSourceId | ParsedMediaSourceStreamId: """Turn a media source ID into options.""" parsed = URL(media_source_id) + + if parsed.path.startswith(f"{MEDIA_SOURCE_STREAM_PATH}/"): + return {"stream": parsed.path[len(MEDIA_SOURCE_STREAM_PATH) + 1 :]} + if URL_QUERY_TTS_OPTIONS in parsed.query: try: options = json.loads(parsed.query[URL_QUERY_TTS_OPTIONS]) @@ -122,17 +134,24 @@ class TTSMediaSource(MediaSource): async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" + manager = self.hass.data[DATA_TTS_MANAGER] try: parsed = parse_media_source_id(item.identifier) - stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream( - **parsed["options"] - ) - stream.async_set_message(parsed["message"]) + if "stream" in parsed: + stream = manager.async_get_result_stream( + parsed["stream"], # type: ignore[typeddict-item] + ) + else: + stream = manager.async_create_result_stream(**parsed["options"]) + stream.async_set_message(parsed["message"]) except Unresolvable: raise except HomeAssistantError as err: raise Unresolvable(str(err)) from err + if stream is None: + raise Unresolvable("Stream not found") + return PlayMedia(stream.url, stream.content_type) async def async_browse_media( diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 816430f58d0..5d2d25ddc5c 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -84,7 +84,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -183,7 +183,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -282,7 +282,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -405,7 +405,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr index bbe08a2adbe..f5940edbc76 100644 --- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -139,7 +139,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D', + 'media_id': 'media-source://tts/-stream-/mocked-token.mp3', 'mime_type': 'audio/mpeg', 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 41bdba9f3cd..827b9c71ba8 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -80,7 +80,7 @@ # name: test_audio_pipeline.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -171,7 +171,7 @@ # name: test_audio_pipeline_debug.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -274,7 +274,7 @@ # name: test_audio_pipeline_with_enhancements.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -387,7 +387,7 @@ # name: test_audio_pipeline_with_wake_word_no_timeout.8 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", + 'media_id': 'media-source://tts/-stream-/test_token.mp3', 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index eb4b09cab5b..8ec0de8765d 100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -17,6 +17,7 @@ from homeassistant.setup import async_setup_component from .common import ( DEFAULT_LANG, + MockResultStream, MockTTSEntity, MockTTSProvider, mock_config_entry_setup, @@ -198,6 +199,17 @@ async def test_resolving( assert language == "de_DE" assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"} + # Test with result stream + stream = MockResultStream(hass, "wav", b"") + media = await media_source.async_resolve_media(hass, stream.media_source_id, None) + assert media.url == stream.url + assert media.mime_type == stream.content_type + + with pytest.raises(media_source.Unresolvable): + await media_source.async_resolve_media( + hass, "media-source://tts/-stream-/not-a-valid-token", None + ) + @pytest.mark.parametrize( ("mock_provider", "mock_tts_entity"),