Minor cleanup for pipeline tts stream test (#145146)

This commit is contained in:
Paulus Schoutsen 2025-05-19 05:58:58 -04:00 committed by GitHub
parent a1d6df6ce9
commit 919684e20a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 13 deletions

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_log_tts_streaming[to_stream_tts0] # name: test_chat_log_tts_streaming[to_stream_tts0-1]
list([ list([
dict({ dict({
'data': dict({ 'data': dict({

View File

@ -1559,8 +1559,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"to_stream_tts", ("to_stream_tts", "expected_chunks"),
[ [
(
[ [
"hello,", "hello,",
" ", " ",
@ -1570,7 +1571,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
" ", " ",
"you", "you",
"?", "?",
] ],
1,
),
], ],
) )
async def test_chat_log_tts_streaming( async def test_chat_log_tts_streaming(
@ -1582,6 +1585,7 @@ async def test_chat_log_tts_streaming(
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
to_stream_tts: list[str], to_stream_tts: list[str],
expected_chunks: int,
) -> None: ) -> None:
"""Test that chat log events are streamed to the TTS entity.""" """Test that chat log events are streamed to the TTS entity."""
events: list[assist_pipeline.PipelineEvent] = [] events: list[assist_pipeline.PipelineEvent] = []
@ -1625,6 +1629,7 @@ async def test_chat_log_tts_streaming(
) )
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
with patch( with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
@ -1692,7 +1697,7 @@ async def test_chat_log_tts_streaming(
streamed_text = "".join(to_stream_tts) streamed_text = "".join(to_stream_tts)
assert tts_result == streamed_text assert tts_result == streamed_text
assert len(received_tts) == 1 assert len(received_tts) == expected_chunks
assert "".join(received_tts) == streamed_text assert "".join(received_tts) == streamed_text
assert process_events(events) == snapshot assert process_events(events) == snapshot