diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a028fa638df..42bb2d4ced8 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -649,6 +649,7 @@ class PipelineRun: data["runner_data"] = self.runner_data if self.tts_stream: data["tts_output"] = { + "token": self.tts_stream.token, "url": self.tts_stream.url, "mime_type": self.tts_stream.content_type, } @@ -1295,6 +1296,7 @@ class PipelineRun: tts_output = { "media_id": tts_media_id, + "token": self.tts_stream.token, "url": self.tts_stream.url, "mime_type": self.tts_stream.content_type, } diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 31a92c62258..6fc25e32091 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -182,6 +182,12 @@ def async_create_stream( ) +@callback +def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None: + """Return a result stream given a token.""" + return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token) + + async def async_get_media_source_audio( hass: HomeAssistant, media_source_id: str, diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 2375d48fcf9..f772f877d3a 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -8,6 +8,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -85,6 +86,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -105,6 +107,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -182,6 +185,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -202,6 +206,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -279,6 +284,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -299,6 +305,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -400,6 +407,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }), @@ -420,6 +428,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }), @@ -620,6 +629,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index d937b5396d1..57ae0095236 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -10,6 +10,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -81,6 +82,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -99,6 +101,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -170,6 +173,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -200,6 +204,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -271,6 +276,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -289,6 +295,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -382,6 +389,7 @@ 'tts_output': dict({ 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -400,6 +408,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), }) @@ -607,6 +616,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -660,6 +670,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -675,6 +686,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -690,6 +702,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -705,6 +718,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -720,6 +734,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -853,6 +868,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -868,6 +884,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -924,6 +941,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -939,6 +957,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -998,6 +1017,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) @@ -1013,6 +1033,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), }) diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 921cab4cba2..9ae83cb2bb5 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -14,9 +14,11 @@ import voluptuous as vol from homeassistant.components import media_source from homeassistant.components.tts import ( CONF_LANG, + DATA_TTS_MANAGER, DOMAIN as TTS_DOMAIN, PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA, Provider, + ResultStream, TextToSpeechEntity, TtsAudioType, Voice, @@ -263,3 +265,26 @@ async def mock_config_entry_setup( await hass.async_block_till_done() return config_entry + + +class MockResultStream(ResultStream): + """Mock result stream.""" + + def __init__(self, hass: HomeAssistant, extension: str, data: bytes) -> None: + """Initialize the result stream.""" + super().__init__( + token="test-token", + extension=extension, + content_type=f"audio/mock-{extension}", + engine="test-engine", + use_file_cache=True, + language="en", + options={}, + _manager=hass.data[DATA_TTS_MANAGER], + ) + hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self + self._mock_data = data + + async def async_stream_result(self): + """Stream the result.""" + yield self._mock_data diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 1b9692cc70c..8bdd17cf3e9 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -28,6 +28,7 @@ from homeassistant.util import dt as dt_util from .common import ( DEFAULT_LANG, TEST_DOMAIN, + MockResultStream, MockTTS, MockTTSEntity, MockTTSProvider, @@ -1829,3 +1830,19 @@ async def test_default_engine_prefer_cloud_entity( provider_engine = tts.async_resolve_engine(hass, "test") assert provider_engine == "test" assert tts.async_default_engine(hass) == "tts.cloud_tts_entity" + + +async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None: + """Test creating streams.""" + await mock_config_entry_setup(hass, mock_tts_entity) + stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) + assert stream.language == mock_tts_entity.default_language + assert stream.options == (mock_tts_entity.default_options or {}) + assert tts.async_get_stream(hass, stream.token) is stream + + data = b"beer" + stream2 = MockResultStream(hass, "wav", data) + assert tts.async_get_stream(hass, stream2.token) is stream2 + assert stream2.extension == "wav" + result_data = b"".join([chunk async for chunk in stream2.async_stream_result()]) + assert result_data == data