Add a test for Assist Pipeline streaming deltas to TTS (#144711)

* Add a test for Assist Pipeline streaming deltas to TTS

* Adjust tests to new TTS engine
This commit is contained in:
Paulus Schoutsen 2025-05-12 12:15:05 -04:00 committed by GitHub
parent d471de5645
commit 2266e97417
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 370 additions and 46 deletions

View File

@ -37,7 +37,7 @@ from tests.common import (
mock_platform,
)
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
from tests.components.tts.common import MockTTSProvider
from tests.components.tts.common import MockTTSEntity, MockTTSProvider
_TRANSCRIPT = "test transcript"
@ -68,6 +68,15 @@ async def mock_tts_provider() -> MockTTSProvider:
return provider
@pytest.fixture
def mock_tts_entity() -> MockTTSEntity:
    """Build the mock TTS entity used by the pipeline tests.

    The entity is created for language "en", then given the unique id
    "test_tts" and "en-US" as its only supported language.
    """
    tts_entity = MockTTSEntity("en")
    tts_entity._attr_unique_id = "test_tts"
    tts_entity._attr_supported_languages = ["en-US"]
    return tts_entity
@pytest.fixture
async def mock_stt_provider() -> MockSTTProvider:
"""Mock STT provider."""
@ -198,6 +207,7 @@ async def init_supporting_components(
mock_stt_provider: MockSTTProvider,
mock_stt_provider_entity: MockSTTProviderEntity,
mock_tts_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity,
mock_wake_word_provider_entity: MockWakeWordEntity,
mock_wake_word_provider_entity2: MockWakeWordEntity2,
config_flow_fixture,
@ -209,7 +219,7 @@ async def init_supporting_components(
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [Platform.STT, Platform.WAKE_WORD]
config_entry, [Platform.STT, Platform.TTS, Platform.WAKE_WORD]
)
return True
@ -230,6 +240,14 @@ async def init_supporting_components(
"""Set up test stt platform via config entry."""
async_add_entities([mock_stt_provider_entity])
async def async_setup_entry_tts_platform(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Add the mock TTS entity when the test tts platform is set up.

    `hass` and `config_entry` are accepted to match the platform setup
    signature but are not used; the only work done is handing the
    `mock_tts_entity` fixture instance to `async_add_entities`.
    """
    tts_entities = [mock_tts_entity]
    async_add_entities(tts_entities)
async def async_setup_entry_wake_word_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
@ -253,6 +271,7 @@ async def init_supporting_components(
"test.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
async_setup_entry=async_setup_entry_tts_platform,
),
)
mock_platform(

View File

@ -74,17 +74,17 @@
}),
dict({
'data': dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -395,17 +395,17 @@
}),
dict({
'data': dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',

View File

@ -1,4 +1,158 @@
# serializer version: 1
# name: test_chat_log_tts_streaming[to_stream_tts0]
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'test-agent',
'intent_input': 'Set a timer',
'language': 'en',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'role': 'assistant',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'hello,',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': ' ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'how',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': ' ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'are',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': ' ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'you',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': '?',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': True,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'hello, how are you?',
}),
}),
}),
}),
'processed_locally': False,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': dict({
'engine': 'tts.test',
'language': 'en_US',
'tts_input': 'hello, how are you?',
'voice': None,
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D',
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_pipeline_language_used_instead_of_conversation_language
list([
dict({

View File

@ -71,16 +71,16 @@
# ---
# name: test_audio_pipeline.5
dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
})
# ---
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -162,16 +162,16 @@
# ---
# name: test_audio_pipeline_debug.5
dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
})
# ---
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -265,16 +265,16 @@
# ---
# name: test_audio_pipeline_with_enhancements.5
dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
})
# ---
# name: test_audio_pipeline_with_enhancements.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -378,16 +378,16 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.7
dict({
'engine': 'test',
'language': 'en-US',
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
'voice': None,
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',

View File

@ -40,6 +40,7 @@ from . import MANY_LANGUAGES, process_events
from .conftest import (
MockSTTProvider,
MockSTTProviderEntity,
MockTTSEntity,
MockTTSProvider,
MockWakeWordEntity,
make_10ms_chunk,
@ -62,6 +63,12 @@ async def load_homeassistant(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture
async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None:
    """Mark the mock TTS entity as not enabled by default.

    Sets `_attr_entity_registry_enabled_default` to False on the shared
    `mock_tts_entity` fixture instance; nothing is returned.
    """
    # Mutates the fixture object in place, so any other fixture or test that
    # receives `mock_tts_entity` sees the same flag.
    mock_tts_entity._attr_entity_registry_enabled_default = False
@pytest.mark.usefixtures("init_components")
async def test_load_pipelines(hass: HomeAssistant) -> None:
"""Make sure that we can load/save data correctly."""
@ -283,6 +290,7 @@ async def test_migrate_pipeline_store(
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
"""Test async_create_default_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {})
@ -430,6 +438,7 @@ async def test_default_pipeline_no_stt_tts(
],
)
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,
@ -474,6 +483,7 @@ async def test_default_pipeline(
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
) -> None:
@ -504,6 +514,7 @@ async def test_default_pipeline_unsupported_stt_language(
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_tts_language(
hass: HomeAssistant, mock_tts_provider: MockTTSProvider
) -> None:
@ -825,7 +836,7 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
async def test_tts_audio_output(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
mock_tts_entity: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
mock_chat_session: chat_session.ChatSession,
@ -869,7 +880,7 @@ async def test_tts_audio_output(
== 1
)
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio:
await pipeline_input.execute()
for event in events:
@ -881,14 +892,14 @@ async def test_tts_audio_output(
# Ensure that no unsupported options were passed in
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
extra_options = set(options).difference(mock_tts_provider.supported_options)
extra_options = set(options).difference(mock_tts_entity.supported_options)
assert len(extra_options) == 0, extra_options
async def test_tts_wav_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
@ -920,7 +931,7 @@ async def test_tts_wav_preferred_format(
await pipeline_input.validate()
# Make the TTS entity support preferred format options
supported_options = list(mock_tts_provider.supported_options or [])
supported_options = list(mock_tts_entity.supported_options or [])
supported_options.extend(
[
tts.ATTR_PREFERRED_FORMAT,
@ -931,8 +942,8 @@ async def test_tts_wav_preferred_format(
)
with (
patch.object(mock_tts_provider, "_supported_options", supported_options),
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
patch.object(mock_tts_entity, "_supported_options", supported_options),
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
):
await pipeline_input.execute()
@ -955,7 +966,7 @@ async def test_tts_wav_preferred_format(
async def test_tts_dict_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
@ -992,7 +1003,7 @@ async def test_tts_dict_preferred_format(
await pipeline_input.validate()
# Make the TTS entity support preferred format options
supported_options = list(mock_tts_provider.supported_options or [])
supported_options = list(mock_tts_entity.supported_options or [])
supported_options.extend(
[
tts.ATTR_PREFERRED_FORMAT,
@ -1003,8 +1014,8 @@ async def test_tts_dict_preferred_format(
)
with (
patch.object(mock_tts_provider, "_supported_options", supported_options),
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
patch.object(mock_tts_entity, "_supported_options", supported_options),
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
):
await pipeline_input.execute()
@ -1545,3 +1556,143 @@ async def test_pipeline_language_used_instead_of_conversation_language(
mock_async_converse.call_args_list[0].kwargs.get("language")
== pipeline.language
)
@pytest.mark.parametrize(
    "to_stream_tts",
    [["hello,", " ", "how", " ", "are", " ", "you", "?"]],
)
async def test_chat_log_tts_streaming(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
    mock_tts_entity: MockTTSEntity,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_tts: list[str],
) -> None:
    """Check that streamed chat log deltas are forwarded to the TTS entity.

    A mocked conversation agent feeds the `to_stream_tts` chunks into the
    chat log as content deltas. The TTS entity's streaming method is replaced
    by one that records every message it reads and echoes it back as bytes.
    The test then asserts that the audio stream decodes to the joined chunks,
    that exactly one message was read by the TTS entity, and that the
    processed pipeline events match the stored snapshot.
    """
    expected_text = "".join(to_stream_tts)
    captured_events: list[assist_pipeline.PipelineEvent] = []

    # Point the preferred pipeline at the mocked conversation agent, then
    # re-read it so the run below uses the updated pipeline.
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    await assist_pipeline.pipeline.async_update_pipeline(
        hass,
        assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id),
        conversation_engine="test-agent",
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=captured_events.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=pipeline_run,
    )

    # Every message the TTS entity pulls from the request is recorded here.
    tts_messages = []

    async def async_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Echo each incoming TTS message back as encoded bytes."""

        async def echo_messages():
            async for message in request.message_gen:
                tts_messages.append(message)
                yield message.encode()

        return tts.TTSAudioResponse(extension="mp3", data_gen=echo_messages())

    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    async def mock_converse(
        hass: HomeAssistant,
        text: str,
        conversation_id: str | None,
        context: Context,
        language: str | None = None,
        agent_id: str | None = None,
        device_id: str | None = None,
        extra_system_prompt: str | None = None,
    ):
        """Stand in for the conversation agent and stream the test chunks."""
        conversation_input = conversation.ConversationInput(
            text=text,
            context=context,
            conversation_id=conversation_id,
            device_id=device_id,
            language=language,
            agent_id=agent_id,
            extra_system_prompt=extra_system_prompt,
        )

        async def stream_llm_response():
            # A role delta first, then one content delta per chunk.
            yield {"role": "assistant"}
            for piece in to_stream_tts:
                yield {"content": piece}

        with (
            chat_session.async_get_chat_session(hass, conversation_id) as session,
            conversation.async_get_chat_log(
                hass,
                session,
                conversation_input,
            ) as chat_log,
        ):
            # Drain the delta stream; the yielded content itself is not needed.
            async for _content in chat_log.async_add_delta_content_stream(
                agent_id, stream_llm_response()
            ):
                pass

            intent_response = intent.IntentResponse(language)
            intent_response.async_set_speech(expected_text)
            return conversation.ConversationResult(
                response=intent_response,
                conversation_id=chat_log.conversation_id,
                continue_conversation=chat_log.continue_conversation,
            )

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        mock_converse,
    ):
        await pipeline_input.execute()

    # The first recorded event carries the token of the TTS result stream.
    stream = tts.async_get_stream(
        hass, captured_events[0].data["tts_output"]["token"]
    )
    assert stream is not None

    decoded_chunks = []
    async for audio_chunk in stream.async_stream_result():
        decoded_chunks.append(audio_chunk.decode())
    tts_result = "".join(decoded_chunks)

    assert tts_result == expected_text
    assert len(tts_messages) == 1
    assert "".join(tts_messages) == expected_text

    assert process_events(captured_events) == snapshot

View File

@ -1153,9 +1153,9 @@ async def test_get_pipeline(
"name": "Home Assistant",
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
"tts_engine": "tts.test",
"tts_language": "en_US",
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
@ -1179,9 +1179,9 @@ async def test_get_pipeline(
# It found these defaults
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
"tts_engine": "tts.test",
"tts_language": "en_US",
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
@ -1266,9 +1266,9 @@ async def test_list_pipelines(
"name": "Home Assistant",
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
"tts_engine": "tts.test",
"tts_language": "en_US",
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,