From 2cce1b024e69186498f2b25d8f63d3a708258b0f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 1 Mar 2025 15:43:00 -0500 Subject: [PATCH] Migrate Assist Pipeline to use TTS stream (#139542) * Migrate Pipeline to use TTS stream * Fix tests --- .../components/assist_pipeline/pipeline.py | 63 ++++----- homeassistant/components/tts/__init__.py | 35 +++-- .../assist_pipeline/snapshots/test_init.ambr | 24 ++++ .../snapshots/test_websocket.ambr | 90 +++++++++---- tests/components/assist_pipeline/test_init.py | 42 +++--- .../assist_pipeline/test_websocket.py | 121 +++++------------- tests/components/tts/test_init.py | 23 ---- 7 files changed, 196 insertions(+), 202 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 038874d1966..a028fa638df 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -19,14 +19,7 @@ import wave import hass_nabucasa import voluptuous as vol -from homeassistant.components import ( - conversation, - media_source, - stt, - tts, - wake_word, - websocket_api, -) +from homeassistant.components import conversation, stt, tts, wake_word, websocket_api from homeassistant.components.tts import ( generate_media_source_id as tts_generate_media_source_id, ) @@ -569,8 +562,7 @@ class PipelineRun: id: str = field(default_factory=ulid_util.ulid_now) stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False) - tts_engine: str = field(init=False, repr=False) - tts_options: dict | None = field(init=False, default=None) + tts_stream: tts.ResultStream | None = field(init=False, default=None) wake_word_entity_id: str | None = field(init=False, default=None, repr=False) wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False) @@ -648,13 +640,18 @@ class PipelineRun: self._device_id = device_id self._start_debug_recording_thread() - data = { + data: dict[str, Any] = { "pipeline": self.pipeline.id, "language": self.language, "conversation_id": conversation_id, } if self.runner_data is not None: data["runner_data"] = self.runner_data + if self.tts_stream: + data["tts_output"] = { + "url": self.tts_stream.url, + "mime_type": self.tts_stream.content_type, + } self.process_event(PipelineEvent(PipelineEventType.RUN_START, data)) @@ -1246,36 +1243,31 @@ class PipelineRun: tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH try: - options_supported = await tts.async_support_options( - self.hass, - engine, - self.pipeline.tts_language, - tts_options, + self.tts_stream = tts.async_create_stream( + hass=self.hass, + engine=engine, + language=self.pipeline.tts_language, + options=tts_options, ) except HomeAssistantError as err: - raise TextToSpeechError( - code="tts-not-supported", - message=f"Text-to-speech engine '{engine}' not found", - ) from err - if not options_supported: raise TextToSpeechError( code="tts-not-supported", message=( f"Text-to-speech engine {engine} " - f"does not support language {self.pipeline.tts_language} or options {tts_options}" + f"does not support language {self.pipeline.tts_language} or options {tts_options}:" + f" {err}" ), - ) - - self.tts_engine = engine - self.tts_options = tts_options + ) from err async def text_to_speech(self, tts_input: str) -> None: """Run text-to-speech portion of pipeline.""" + assert self.tts_stream is not None + self.process_event( PipelineEvent( PipelineEventType.TTS_START, { - "engine": self.tts_engine, + "engine": self.tts_stream.engine, "language": self.pipeline.tts_language, "voice": self.pipeline.tts_voice, "tts_input": tts_input, @@ -1288,14 +1280,9 @@ class PipelineRun: tts_media_id = tts_generate_media_source_id( self.hass, tts_input, - engine=self.tts_engine, - language=self.pipeline.tts_language, - options=self.tts_options, - ) - tts_media = await media_source.async_resolve_media( - self.hass, - tts_media_id, - None, + engine=self.tts_stream.engine, + language=self.tts_stream.language, + options=self.tts_stream.options, ) except Exception as src_error: _LOGGER.exception("Unexpected error during text-to-speech") @@ -1304,10 +1291,12 @@ class PipelineRun: message="Unexpected error during text-to-speech", ) from src_error - _LOGGER.debug("TTS result %s", tts_media) + self.tts_stream.async_set_message(tts_input) + tts_output = { "media_id": tts_media_id, - **asdict(tts_media), + "url": self.tts_stream.url, + "mime_type": self.tts_stream.content_type, } self.process_event( diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 32c4ba20670..98ce76cafde 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -79,13 +79,13 @@ __all__ = [ "PLATFORM_SCHEMA", "PLATFORM_SCHEMA_BASE", "Provider", + "ResultStream", "SampleFormat", "TextToSpeechEntity", "TtsAudioType", "Voice", "async_default_engine", "async_get_media_source_audio", - "async_support_options", "generate_media_source_id", ] @@ -167,22 +167,19 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None: return async_default_engine(hass) -async def async_support_options( +@callback +def async_create_stream( hass: HomeAssistant, engine: str, language: str | None = None, options: dict | None = None, -) -> bool: - """Return if an engine supports options.""" - if (engine_instance := get_engine_instance(hass, engine)) is None: - raise HomeAssistantError(f"Provider {engine} not found") - - try: - hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options) - except HomeAssistantError: - return False - - return True +) -> ResultStream: + """Create a streaming URL where the rendered TTS can be retrieved.""" + return hass.data[DATA_TTS_MANAGER].async_create_result_stream( + engine=engine, + language=language, + options=options, + ) async def async_get_media_source_audio( @@ -407,6 +404,18 @@ class ResultStream: """Set cache key for message to be streamed.""" self._result_cache_key.set_result(cache_key) + @callback + def async_set_message(self, message: str) -> None: + """Set message to be generated.""" + cache_key = self._manager.async_cache_message_in_memory( + engine=self.engine, + message=message, + use_file_cache=self.use_file_cache, + language=self.language, + options=self.options, + ) + self._result_cache_key.set_result(cache_key) + async def async_stream_result(self) -> AsyncGenerator[bytes]: """Get the stream of this result.""" cache_key = await self._result_cache_key diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index f5e5f813db6..2375d48fcf9 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -6,6 +6,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }), 'type': , }), @@ -99,6 +103,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }), 'type': , }), @@ -192,6 +200,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }), 'type': , }), @@ -285,6 +297,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }), 'type': , }), @@ -402,6 +418,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }), 'type': , }), @@ -598,6 +618,10 @@ 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }), 'type': , }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 509f2072509..d937b5396d1 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -8,6 +8,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }) # --- # name: test_audio_pipeline.1 @@ -93,6 +97,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }) # --- # name: test_audio_pipeline_debug.1 @@ -190,6 +198,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }) # --- # name: test_audio_pipeline_with_enhancements.1 @@ -275,6 +287,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }) # --- # name: test_audio_pipeline_with_wake_word_no_timeout.1 @@ -382,6 +398,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/test_token.mp3', + }), }) # --- # name: test_audio_pipeline_with_wake_word_timeout.1 @@ -585,6 +605,10 @@ 'stt_binary_handler_id': None, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_pipeline_empty_tts_output.1 @@ -634,6 +658,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_stt_cooldown_different_ids.1 @@ -645,6 +673,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_stt_cooldown_same_id @@ -656,6 +688,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_stt_cooldown_same_id.1 @@ -667,6 +703,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_stt_stream_failed @@ -678,6 +718,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_stt_stream_failed.1 @@ -798,28 +842,6 @@ 'message': 'Timeout running pipeline', }) # --- -# name: test_tts_failed - dict({ - 'conversation_id': 'mock-ulid', - 'language': 'en', - 'pipeline': , - 'runner_data': dict({ - 'stt_binary_handler_id': None, - 'timeout': 300, - }), - }) -# --- -# name: test_tts_failed.1 - dict({ - 'engine': 'test', - 'language': 'en-US', - 'tts_input': 'Lights are on.', - 'voice': 'james_earl_jones', - }) -# --- -# name: test_tts_failed.2 - None -# --- # name: test_wake_word_cooldown_different_entities dict({ 'conversation_id': 'mock-ulid', @@ -829,6 +851,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_different_entities.1 @@ -840,6 +866,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_different_entities.2 @@ -892,6 +922,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_different_ids.1 @@ -903,6 +937,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_different_ids.2 @@ -958,6 +996,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_same_id.1 @@ -969,6 +1011,10 @@ 'stt_binary_handler_id': 1, 'timeout': 300, }), + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), }) # --- # name: test_wake_word_cooldown_same_id.2 diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index e983e4a96e3..0e04d1f0cd2 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -43,13 +43,21 @@ from tests.typing import ClientSessionGenerator, WebSocketGenerator @pytest.fixture(autouse=True) -def mock_ulid() -> Generator[Mock]: - """Mock the ulid of chat sessions.""" - with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: - mock_ulid_now.return_value = "mock-ulid" +def mock_chat_session_id() -> Generator[Mock]: + """Mock the conversation ID of chat sessions.""" + with patch( + "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid" + ) as mock_ulid_now: yield mock_ulid_now +@pytest.fixture(autouse=True) +def mock_tts_token() -> Generator[None]: + """Mock the TTS token for URLs.""" + with patch("secrets.token_urlsafe", return_value="mocked-token"): + yield + + def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: """Process events to remove dynamic values.""" processed = [] @@ -797,10 +805,16 @@ async def test_tts_audio_output( await pipeline_input.validate() # Verify TTS audio settings - assert pipeline_input.run.tts_options is not None - assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" - assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000 - assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1 + assert pipeline_input.run.tts_stream.options is not None + assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" + assert ( + pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) + == 16000 + ) + assert ( + pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) + == 1 + ) with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio: await pipeline_input.execute() @@ -809,9 +823,7 @@ async def test_tts_audio_output( if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data - media_id = event.data["tts_output"]["media_id"] - resolved = await media_source.async_resolve_media(hass, media_id, None) - await client.get(resolved.url) + await client.get(event.data["tts_output"]["url"]) # Ensure that no unsupported options were passed in assert mock_get_tts_audio.called @@ -875,9 +887,7 @@ async def test_tts_wav_preferred_format( if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data - media_id = event.data["tts_output"]["media_id"] - resolved = await media_source.async_resolve_media(hass, media_id, None) - await client.get(resolved.url) + await client.get(event.data["tts_output"]["url"]) assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] @@ -949,9 +959,7 @@ async def test_tts_dict_preferred_format( if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data - media_id = event.data["tts_output"]["media_id"] - resolved = await media_source.async_resolve_media(hass, media_id, None) - await client.get(resolved.url) + await client.get(event.data["tts_output"]["url"]) assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index f856bbe7f61..060c0dce660 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -20,6 +20,8 @@ from homeassistant.components.assist_pipeline.pipeline import ( DeviceAudioQueue, Pipeline, PipelineData, + async_get_pipelines, + async_update_pipeline, ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -38,13 +40,21 @@ from tests.typing import WebSocketGenerator @pytest.fixture(autouse=True) -def mock_ulid() -> Generator[Mock]: - """Mock the ulid of chat sessions.""" - with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: - mock_ulid_now.return_value = "mock-ulid" +def mock_chat_session_id() -> Generator[Mock]: + """Mock the conversation ID of chat sessions.""" + with patch( + "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid" + ) as mock_ulid_now: yield mock_ulid_now +@pytest.fixture(autouse=True) +def mock_tts_token() -> Generator[None]: + """Mock the TTS token for URLs.""" + with patch("secrets.token_urlsafe", return_value="mocked-token"): + yield + + @pytest.mark.parametrize( "extra_msg", [ @@ -825,74 +835,6 @@ async def test_stt_stream_failed( assert msg["result"] == {"events": events} -async def test_tts_failed( - hass: HomeAssistant, - hass_ws_client: WebSocketGenerator, - init_components, - snapshot: SnapshotAssertion, -) -> None: - """Test pipeline run with text-to-speech error.""" - events = [] - client = await hass_ws_client(hass) - - with patch( - "homeassistant.components.media_source.async_resolve_media", - side_effect=RuntimeError, - ): - await client.send_json_auto_id( - { - "type": "assist_pipeline/run", - "start_stage": "tts", - "end_stage": "tts", - "input": {"text": "Lights are on."}, - } - ) - - # result - msg = await client.receive_json() - assert msg["success"] - - # run start - msg = await client.receive_json() - assert msg["event"]["type"] == "run-start" - msg["event"]["data"]["pipeline"] = ANY - assert msg["event"]["data"] == snapshot - events.append(msg["event"]) - - # tts start - msg = await client.receive_json() - assert msg["event"]["type"] == "tts-start" - assert msg["event"]["data"] == snapshot - events.append(msg["event"]) - - # tts error - msg = await client.receive_json() - assert msg["event"]["type"] == "error" - assert msg["event"]["data"]["code"] == "tts-failed" - events.append(msg["event"]) - - # run end - msg = await client.receive_json() - assert msg["event"]["type"] == "run-end" - assert msg["event"]["data"] == snapshot - events.append(msg["event"]) - - pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_debug)[0] - pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0] - - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline_debug/get", - "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, - } - ) - msg = await client.receive_json() - assert msg["success"] - assert msg["result"] == {"events": events} - - async def test_tts_provider_missing( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, @@ -903,23 +845,22 @@ async def test_tts_provider_missing( """Test pipeline run with text-to-speech error.""" client = await hass_ws_client(hass) - with patch( - "homeassistant.components.tts.async_support_options", - side_effect=HomeAssistantError, - ): - await client.send_json_auto_id( - { - "type": "assist_pipeline/run", - "start_stage": "tts", - "end_stage": "tts", - "input": {"text": "Lights are on."}, - } - ) + pipelines = async_get_pipelines(hass) + await async_update_pipeline(hass, pipelines[0], tts_engine="unavailable") - # result - msg = await client.receive_json() - assert not msg["success"] - assert msg["error"]["code"] == "tts-not-supported" + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "tts", + "end_stage": "tts", + "input": {"text": "Lights are on."}, + } + ) + + # result + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "tts-not-supported" async def test_tts_provider_bad_options( @@ -933,8 +874,8 @@ async def test_tts_provider_bad_options( client = await hass_ws_client(hass) with patch( - "homeassistant.components.tts.async_support_options", - return_value=False, + "homeassistant.components.tts.SpeechManager.process_options", + side_effect=HomeAssistantError("Language not supported"), ): await client.send_json_auto_id( { diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 8dece920907..1b9692cc70c 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1376,29 +1376,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None assert tts.async_resolve_engine(hass, None) is None -@pytest.mark.parametrize( - ("setup", "engine_id"), - [ - ("mock_setup", "test"), - ("mock_config_entry_setup", "tts.test"), - ], - indirect=["setup"], -) -async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None: - """Test supporting options.""" - assert await tts.async_support_options(hass, engine_id, "en_US") is True - assert await tts.async_support_options(hass, engine_id, "nl") is False - assert ( - await tts.async_support_options( - hass, engine_id, "en_US", {"invalid_option": "yo"} - ) - is False - ) - - with pytest.raises(HomeAssistantError): - await tts.async_support_options(hass, "non-existing") - - async def test_legacy_fetching_in_async( hass: HomeAssistant, hass_client: ClientSessionGenerator ) -> None: