From 2266e9741773d2338f1eb523e48b613b86698f5c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 12 May 2025 12:15:05 -0400 Subject: [PATCH] Add a test for Assist Pipeline streaming deltas to TTS (#144711) * Add a test for Assist Pipeline streaming deltas to TTS * Adjust tests to new TTS engine --- tests/components/assist_pipeline/conftest.py | 23 ++- .../assist_pipeline/snapshots/test_init.ambr | 16 +- .../snapshots/test_pipeline.ambr | 154 ++++++++++++++++ .../snapshots/test_websocket.ambr | 32 ++-- .../assist_pipeline/test_pipeline.py | 173 ++++++++++++++++-- .../assist_pipeline/test_websocket.py | 18 +- 6 files changed, 370 insertions(+), 46 deletions(-) diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index a0549f27f05..e20452a1f93 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -37,7 +37,7 @@ from tests.common import ( mock_platform, ) from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity -from tests.components.tts.common import MockTTSProvider +from tests.components.tts.common import MockTTSEntity, MockTTSProvider _TRANSCRIPT = "test transcript" @@ -68,6 +68,15 @@ async def mock_tts_provider() -> MockTTSProvider: return provider +@pytest.fixture +def mock_tts_entity() -> MockTTSEntity: + """Test TTS entity.""" + entity = MockTTSEntity("en") + entity._attr_unique_id = "test_tts" + entity._attr_supported_languages = ["en-US"] + return entity + + @pytest.fixture async def mock_stt_provider() -> MockSTTProvider: """Mock STT provider.""" @@ -198,6 +207,7 @@ async def init_supporting_components( mock_stt_provider: MockSTTProvider, mock_stt_provider_entity: MockSTTProviderEntity, mock_tts_provider: MockTTSProvider, + mock_tts_entity: MockTTSEntity, mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity2: MockWakeWordEntity2, config_flow_fixture, @@ -209,7 +219,7 @@ async def init_supporting_components( ) -> bool: """Set up test config entry.""" await hass.config_entries.async_forward_entry_setups( - config_entry, [Platform.STT, Platform.WAKE_WORD] + config_entry, [Platform.STT, Platform.TTS, Platform.WAKE_WORD] ) return True @@ -230,6 +240,14 @@ async def init_supporting_components( """Set up test stt platform via config entry.""" async_add_entities([mock_stt_provider_entity]) + async def async_setup_entry_tts_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, + ) -> None: + """Set up test tts platform via config entry.""" + async_add_entities([mock_tts_entity]) + async def async_setup_entry_wake_word_platform( hass: HomeAssistant, config_entry: ConfigEntry, @@ -253,6 +271,7 @@ async def init_supporting_components( "test.tts", MockTTSPlatform( async_get_engine=AsyncMock(return_value=mock_tts_provider), + async_setup_entry=async_setup_entry_tts_platform, ), ) mock_platform( diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 81972191868..816430f58d0 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -74,17 +74,17 @@ }), dict({ 'data': dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }), 'type': , }), dict({ 'data': dict({ 'tts_output': 
dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -395,17 +395,17 @@ }), dict({ 'data': dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }), 'type': , }), dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr index 7c0ac254b6e..717823fe4e4 100644 --- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -1,4 +1,158 @@ # serializer version: 1 +# name: test_chat_log_tts_streaming[to_stream_tts0] + list([ + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'language': 'en', + 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'device_id': None, + 'engine': 'test-agent', + 'intent_input': 'Set a timer', + 'language': 'en', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'role': 'assistant', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': 'hello,', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': ' ', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': 'how', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': ' ', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': 'are', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': ' ', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': 'you', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'chat_log_delta': dict({ + 'content': '?', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'continue_conversation': True, + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'hello, how are you?', + }), + }), + }), + }), + 'processed_locally': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'tts.test', + 'language': 'en_US', + 'tts_input': 'hello, how are you?', + 'voice': None, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'media_id': 
'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D', + 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- # name: test_pipeline_language_used_instead_of_conversation_language list([ dict({ diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 57ae0095236..41bdba9f3cd 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -71,16 +71,16 @@ # --- # name: test_audio_pipeline.5 dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }) # --- # name: test_audio_pipeline.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -162,16 +162,16 @@ # --- # name: test_audio_pipeline_debug.5 dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }) # --- # name: test_audio_pipeline_debug.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -265,16 +265,16 @@ # --- # name: test_audio_pipeline_with_enhancements.5 dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }) # --- # name: test_audio_pipeline_with_enhancements.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', @@ -378,16 +378,16 @@ # --- # name: test_audio_pipeline_with_wake_word_no_timeout.7 dict({ - 'engine': 'test', - 'language': 'en-US', + 'engine': 'tts.test', + 'language': 'en_US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': 'james_earl_jones', + 'voice': None, }) # --- # name: test_audio_pipeline_with_wake_word_no_timeout.8 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", + 'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D", 'mime_type': 'audio/mpeg', 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', diff --git 
a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 4f15853b296..e318862a2f2 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -40,6 +40,7 @@ from . import MANY_LANGUAGES, process_events from .conftest import ( MockSTTProvider, MockSTTProviderEntity, + MockTTSEntity, MockTTSProvider, MockWakeWordEntity, make_10ms_chunk, @@ -62,6 +63,12 @@ async def load_homeassistant(hass: HomeAssistant) -> None: assert await async_setup_component(hass, "homeassistant", {}) +@pytest.fixture +async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None: + """Disable the TTS entity.""" + mock_tts_entity._attr_entity_registry_enabled_default = False + + @pytest.mark.usefixtures("init_components") async def test_load_pipelines(hass: HomeAssistant) -> None: """Make sure that we can load/save data correctly.""" @@ -283,6 +290,7 @@ async def test_migrate_pipeline_store( @pytest.mark.usefixtures("init_supporting_components") +@pytest.mark.usefixtures("disable_tts_entity") async def test_create_default_pipeline(hass: HomeAssistant) -> None: """Test async_create_default_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) @@ -430,6 +438,7 @@ async def test_default_pipeline_no_stt_tts( ], ) @pytest.mark.usefixtures("init_supporting_components") +@pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, @@ -474,6 +483,7 @@ async def test_default_pipeline( @pytest.mark.usefixtures("init_supporting_components") +@pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline_unsupported_stt_language( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity ) -> None: @@ -504,6 +514,7 @@ async def test_default_pipeline_unsupported_stt_language( @pytest.mark.usefixtures("init_supporting_components") +@pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline_unsupported_tts_language( hass: HomeAssistant, mock_tts_provider: MockTTSProvider ) -> None: @@ -825,7 +836,7 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: async def test_tts_audio_output( hass: HomeAssistant, hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, + mock_tts_entity: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, mock_chat_session: chat_session.ChatSession, @@ -869,7 +880,7 @@ async def test_tts_audio_output( == 1 ) - with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio: + with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio: await pipeline_input.execute() for event in events: @@ -881,14 +892,14 @@ async def test_tts_audio_output( # Ensure that no unsupported options were passed in assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] - extra_options = set(options).difference(mock_tts_provider.supported_options) + extra_options = set(options).difference(mock_tts_entity.supported_options) assert len(extra_options) == 0, extra_options async def test_tts_wav_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, + mock_tts_entity: MockTTSEntity, init_components, mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, @@ -920,7 +931,7 @@ async def 
test_tts_wav_preferred_format( await pipeline_input.validate() # Make the TTS provider support preferred format options - supported_options = list(mock_tts_provider.supported_options or []) + supported_options = list(mock_tts_entity.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, @@ -931,8 +942,8 @@ async def test_tts_wav_preferred_format( ) with ( - patch.object(mock_tts_provider, "_supported_options", supported_options), - patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, + patch.object(mock_tts_entity, "_supported_options", supported_options), + patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio, ): await pipeline_input.execute() @@ -955,7 +966,7 @@ async def test_tts_wav_preferred_format( async def test_tts_dict_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, + mock_tts_entity: MockTTSEntity, init_components, mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, @@ -992,7 +1003,7 @@ async def test_tts_dict_preferred_format( await pipeline_input.validate() # Make the TTS provider support preferred format options - supported_options = list(mock_tts_provider.supported_options or []) + supported_options = list(mock_tts_entity.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, @@ -1003,8 +1014,8 @@ async def test_tts_dict_preferred_format( ) with ( - patch.object(mock_tts_provider, "_supported_options", supported_options), - patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, + patch.object(mock_tts_entity, "_supported_options", supported_options), + patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio, ): await pipeline_input.execute() @@ -1545,3 +1556,143 @@ async def test_pipeline_language_used_instead_of_conversation_language( mock_async_converse.call_args_list[0].kwargs.get("language") == pipeline.language ) + + +@pytest.mark.parametrize( + "to_stream_tts", + [ + [ + "hello,", + " ", + "how", + " ", + "are", + " ", + "you", + "?", + ] + ], +) +async def test_chat_log_tts_streaming( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_chat_session: chat_session.ChatSession, + snapshot: SnapshotAssertion, + mock_tts_entity: MockTTSEntity, + pipeline_data: assist_pipeline.pipeline.PipelineData, + to_stream_tts: list[str], +) -> None: + """Test that chat log events are streamed to the TTS entity.""" + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine="test-agent" + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="Set a timer", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + ), + ) + + received_tts = [] + + async def async_stream_tts_audio( + request: tts.TTSAudioRequest, + ) -> tts.TTSAudioResponse: + """Mock stream TTS audio.""" + + async def gen_data(): + async for msg in request.message_gen: + received_tts.append(msg) + yield 
msg.encode() + + return tts.TTSAudioResponse( + extension="mp3", + data_gen=gen_data(), + ) + + mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ): + await pipeline_input.validate() + + async def mock_converse( + hass: HomeAssistant, + text: str, + conversation_id: str | None, + context: Context, + language: str | None = None, + agent_id: str | None = None, + device_id: str | None = None, + extra_system_prompt: str | None = None, + ): + """Mock converse.""" + conversation_input = conversation.ConversationInput( + text=text, + context=context, + conversation_id=conversation_id, + device_id=device_id, + language=language, + agent_id=agent_id, + extra_system_prompt=extra_system_prompt, + ) + + async def stream_llm_response(): + yield {"role": "assistant"} + for chunk in to_stream_tts: + yield {"content": chunk} + + with ( + chat_session.async_get_chat_session(hass, conversation_id) as session, + conversation.async_get_chat_log( + hass, + session, + conversation_input, + ) as chat_log, + ): + async for _content in chat_log.async_add_delta_content_stream( + agent_id, stream_llm_response() + ): + pass + intent_response = intent.IntentResponse(language) + intent_response.async_set_speech("".join(to_stream_tts)) + return conversation.ConversationResult( + response=intent_response, + conversation_id=chat_log.conversation_id, + continue_conversation=chat_log.continue_conversation, + ) + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + mock_converse, + ): + await pipeline_input.execute() + + stream = tts.async_get_stream(hass, events[0].data["tts_output"]["token"]) + assert stream is not None + tts_result = "".join( + [chunk.decode() async for chunk in stream.async_stream_result()] + ) + + streamed_text = "".join(to_stream_tts) + assert tts_result == streamed_text + assert len(received_tts) == 1 + assert "".join(received_tts) == streamed_text + + assert process_events(events) == snapshot diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 060c0dce660..bf9818f2a5f 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -1153,9 +1153,9 @@ async def test_get_pipeline( "name": "Home Assistant", "stt_engine": "stt.mock_stt", "stt_language": "en-US", - "tts_engine": "test", - "tts_language": "en-US", - "tts_voice": "james_earl_jones", + "tts_engine": "tts.test", + "tts_language": "en_US", + "tts_voice": None, "wake_word_entity": None, "wake_word_id": None, "prefer_local_intents": False, @@ -1179,9 +1179,9 @@ async def test_get_pipeline( # It found these defaults "stt_engine": "stt.mock_stt", "stt_language": "en-US", - "tts_engine": "test", - "tts_language": "en-US", - "tts_voice": "james_earl_jones", + "tts_engine": "tts.test", + "tts_language": "en_US", + "tts_voice": None, "wake_word_entity": None, "wake_word_id": None, "prefer_local_intents": False, @@ -1266,9 +1266,9 @@ async def test_list_pipelines( "name": "Home Assistant", "stt_engine": "stt.mock_stt", "stt_language": "en-US", - "tts_engine": "test", - "tts_language": "en-US", - "tts_voice": "james_earl_jones", + "tts_engine": "tts.test", + "tts_language": "en_US", + "tts_voice": None, "wake_word_entity": None, "wake_word_id": None, "prefer_local_intents": False,
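
Note on the streaming hook exercised by test_chat_log_tts_streaming (not part of the patch): the test stubs async_stream_tts_audio on the mock entity so that text deltas from the chat log are consumed as they arrive and echoed back as audio bytes. Below is a minimal sketch of a TTS entity implementing the same hook. Only async_stream_tts_audio, TTSAudioRequest.message_gen, TTSAudioResponse(extension=..., data_gen=...), and _attr_supported_languages are taken from the patch above; the _synthesize_chunk helper is a hypothetical placeholder for a real synthesis call, not an existing API.

    from collections.abc import AsyncGenerator

    from homeassistant.components import tts


    class StreamingTTSSketch(tts.TextToSpeechEntity):
        """Sketch of a TTS entity that synthesizes audio per streamed text delta."""

        _attr_supported_languages = ["en-US"]

        async def async_stream_tts_audio(
            self, request: tts.TTSAudioRequest
        ) -> tts.TTSAudioResponse:
            """Yield audio chunks as message deltas arrive from the chat log."""

            async def data_gen() -> AsyncGenerator[bytes]:
                async for message in request.message_gen:
                    # _synthesize_chunk is a hypothetical helper standing in for a
                    # real streaming synthesis call; the mock in the test above
                    # simply does `yield msg.encode()`.
                    yield await self._synthesize_chunk(message)

            return tts.TTSAudioResponse(extension="mp3", data_gen=data_gen())

        async def _synthesize_chunk(self, text: str) -> bytes:
            """Hypothetical per-chunk synthesis; replace with an engine call."""
            return text.encode()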