From 919684e20a2d9417f2dc0d6a8f6da48eb10fa9a2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 19 May 2025 05:58:58 -0400 Subject: [PATCH] Minor cleanup for pipeline tts stream test (#145146) --- .../snapshots/test_pipeline.ambr | 2 +- .../assist_pipeline/test_pipeline.py | 29 +++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr index 717823fe4e4..bbe08a2adbe 100644 --- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_chat_log_tts_streaming[to_stream_tts0] +# name: test_chat_log_tts_streaming[to_stream_tts0-1] list([ dict({ 'data': dict({ diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index e318862a2f2..abf6572afc9 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1559,18 +1559,21 @@ async def test_pipeline_language_used_instead_of_conversation_language( @pytest.mark.parametrize( - "to_stream_tts", + ("to_stream_tts", "expected_chunks"), [ - [ - "hello,", - " ", - "how", - " ", - "are", - " ", - "you", - "?", - ] + ( + [ + "hello,", + " ", + "how", + " ", + "are", + " ", + "you", + "?", + ], + 1, + ), ], ) async def test_chat_log_tts_streaming( @@ -1582,6 +1585,7 @@ async def test_chat_log_tts_streaming( mock_tts_entity: MockTTSEntity, pipeline_data: assist_pipeline.pipeline.PipelineData, to_stream_tts: list[str], + expected_chunks: int, ) -> None: """Test that chat log events are streamed to the TTS entity.""" events: list[assist_pipeline.PipelineEvent] = [] @@ -1625,6 +1629,7 @@ async def test_chat_log_tts_streaming( ) mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio + mock_tts_entity.async_supports_streaming_input = Mock(return_value=True) with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", @@ -1692,7 +1697,7 @@ async def test_chat_log_tts_streaming( streamed_text = "".join(to_stream_tts) assert tts_result == streamed_text - assert len(received_tts) == 1 + assert len(received_tts) == expected_chunks assert "".join(received_tts) == streamed_text assert process_events(events) == snapshot