mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 23:27:37 +00:00
Add a test for Assist Pipeline streaming deltas to TTS (#144711)
* Add a test for Assist Pipeline streaming deltas to TTS * Adjust tests to new TTS engine
This commit is contained in:
parent
d471de5645
commit
2266e97417
@ -37,7 +37,7 @@ from tests.common import (
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
||||
from tests.components.tts.common import MockTTSProvider
|
||||
from tests.components.tts.common import MockTTSEntity, MockTTSProvider
|
||||
|
||||
_TRANSCRIPT = "test transcript"
|
||||
|
||||
@ -68,6 +68,15 @@ async def mock_tts_provider() -> MockTTSProvider:
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tts_entity() -> MockTTSEntity:
    """Provide the mock TTS entity shared by the pipeline tests.

    The entity is constructed with "en" and then given a fixed unique id and
    an explicit list of supported languages.
    """
    tts_entity = MockTTSEntity("en")
    tts_entity._attr_supported_languages = ["en-US"]
    tts_entity._attr_unique_id = "test_tts"
    return tts_entity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_stt_provider() -> MockSTTProvider:
|
||||
"""Mock STT provider."""
|
||||
@ -198,6 +207,7 @@ async def init_supporting_components(
|
||||
mock_stt_provider: MockSTTProvider,
|
||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
||||
config_flow_fixture,
|
||||
@ -209,7 +219,7 @@ async def init_supporting_components(
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
||||
config_entry, [Platform.STT, Platform.TTS, Platform.WAKE_WORD]
|
||||
)
|
||||
return True
|
||||
|
||||
@ -230,6 +240,14 @@ async def init_supporting_components(
|
||||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_stt_provider_entity])
|
||||
|
||||
async def async_setup_entry_tts_platform(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Add the mock TTS entity when the test tts platform is set up."""
    # mock_tts_entity is taken from the enclosing fixture's scope.
    tts_entities = [mock_tts_entity]
    async_add_entities(tts_entities)
|
||||
|
||||
async def async_setup_entry_wake_word_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
@ -253,6 +271,7 @@ async def init_supporting_components(
|
||||
"test.tts",
|
||||
MockTTSPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||
async_setup_entry=async_setup_entry_tts_platform,
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
|
@ -74,17 +74,17 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -395,17 +395,17 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
|
@ -1,4 +1,158 @@
|
||||
# serializer version: 1
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts0]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'test-agent',
|
||||
'intent_input': 'Set a timer',
|
||||
'language': 'en',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'role': 'assistant',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'hello,',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': ' ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'how',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': ' ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'are',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': ' ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '?',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': True,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'hello, how are you?',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': 'hello, how are you?',
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
|
@ -71,16 +71,16 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -162,16 +162,16 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -265,16 +265,16 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -378,16 +378,16 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.7
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
'voice': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
|
@ -40,6 +40,7 @@ from . import MANY_LANGUAGES, process_events
|
||||
from .conftest import (
|
||||
MockSTTProvider,
|
||||
MockSTTProviderEntity,
|
||||
MockTTSEntity,
|
||||
MockTTSProvider,
|
||||
MockWakeWordEntity,
|
||||
make_10ms_chunk,
|
||||
@ -62,6 +63,12 @@ async def load_homeassistant(hass: HomeAssistant) -> None:
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
@pytest.fixture
async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None:
    """Flag the mock TTS entity as not enabled by default in the registry."""
    # NOTE(review): presumably lets tests using this fixture keep exercising
    # the legacy TTS provider instead of the entity -- confirm.
    mock_tts_entity._attr_entity_registry_enabled_default = False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
|
||||
async def test_load_pipelines(hass: HomeAssistant) -> None:
|
||||
"""Make sure that we can load/save data correctly."""
|
||||
@ -283,6 +290,7 @@ async def test_migrate_pipeline_store(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
@pytest.mark.usefixtures("disable_tts_entity")
|
||||
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test async_create_default_pipeline."""
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
@ -430,6 +438,7 @@ async def test_default_pipeline_no_stt_tts(
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
@pytest.mark.usefixtures("disable_tts_entity")
|
||||
async def test_default_pipeline(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||
@ -474,6 +483,7 @@ async def test_default_pipeline(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
@pytest.mark.usefixtures("disable_tts_entity")
|
||||
async def test_default_pipeline_unsupported_stt_language(
|
||||
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
|
||||
) -> None:
|
||||
@ -504,6 +514,7 @@ async def test_default_pipeline_unsupported_stt_language(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
@pytest.mark.usefixtures("disable_tts_entity")
|
||||
async def test_default_pipeline_unsupported_tts_language(
|
||||
hass: HomeAssistant, mock_tts_provider: MockTTSProvider
|
||||
) -> None:
|
||||
@ -825,7 +836,7 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
||||
async def test_tts_audio_output(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
mock_tts_entity: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
@ -869,7 +880,7 @@ async def test_tts_audio_output(
|
||||
== 1
|
||||
)
|
||||
|
||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||
with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio:
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
@ -881,14 +892,14 @@ async def test_tts_audio_output(
|
||||
# Ensure that no unsupported options were passed in
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
||||
extra_options = set(options).difference(mock_tts_entity.supported_options)
|
||||
assert len(extra_options) == 0, extra_options
|
||||
|
||||
|
||||
async def test_tts_wav_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
@ -920,7 +931,7 @@ async def test_tts_wav_preferred_format(
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS entity support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options = list(mock_tts_entity.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
@ -931,8 +942,8 @@ async def test_tts_wav_preferred_format(
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
patch.object(mock_tts_entity, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
@ -955,7 +966,7 @@ async def test_tts_wav_preferred_format(
|
||||
async def test_tts_dict_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
@ -992,7 +1003,7 @@ async def test_tts_dict_preferred_format(
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS entity support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options = list(mock_tts_entity.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
@ -1003,8 +1014,8 @@ async def test_tts_dict_preferred_format(
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
patch.object(mock_tts_entity, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
@ -1545,3 +1556,143 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.language
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "to_stream_tts",
    [["hello,", " ", "how", " ", "are", " ", "you", "?"]],
)
async def test_chat_log_tts_streaming(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
    mock_tts_entity: MockTTSEntity,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_tts: list[str],
) -> None:
    """Verify that streamed chat log deltas end up at the TTS entity.

    A mocked conversation agent feeds the parametrized text chunks into the
    chat log as deltas. The mocked TTS entity echoes every message it is
    given, so the audio read back from the TTS stream must equal the
    concatenated chunks.
    """
    expected_text = "".join(to_stream_tts)
    pipeline_events: list[assist_pipeline.PipelineEvent] = []
    messages_seen_by_tts = []

    # Switch the preferred pipeline over to the mocked conversation agent.
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    await assist_pipeline.pipeline.async_update_pipeline(
        hass,
        assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id),
        conversation_engine="test-agent",
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=pipeline_events.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=pipeline_run,
    )

    async def echo_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Record each incoming message and return it encoded as audio."""

        async def echo_messages():
            async for message in request.message_gen:
                messages_seen_by_tts.append(message)
                yield message.encode()

        return tts.TTSAudioResponse(extension="mp3", data_gen=echo_messages())

    mock_tts_entity.async_stream_tts_audio = echo_stream_tts_audio

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    async def fake_converse(
        hass: HomeAssistant,
        text: str,
        conversation_id: str | None,
        context: Context,
        language: str | None = None,
        agent_id: str | None = None,
        device_id: str | None = None,
        extra_system_prompt: str | None = None,
    ):
        """Stand in for the agent and stream the chunks into the chat log."""
        user_input = conversation.ConversationInput(
            text=text,
            context=context,
            conversation_id=conversation_id,
            device_id=device_id,
            language=language,
            agent_id=agent_id,
            extra_system_prompt=extra_system_prompt,
        )

        async def delta_stream():
            # First the role marker, then one content delta per chunk.
            yield {"role": "assistant"}
            for piece in to_stream_tts:
                yield {"content": piece}

        with (
            chat_session.async_get_chat_session(hass, conversation_id) as session,
            conversation.async_get_chat_log(hass, session, user_input) as chat_log,
        ):
            # Drain the stream; only the side effect on the chat log matters.
            async for _content in chat_log.async_add_delta_content_stream(
                agent_id, delta_stream()
            ):
                pass

            response = intent.IntentResponse(language)
            response.async_set_speech(expected_text)
            return conversation.ConversationResult(
                response=response,
                conversation_id=chat_log.conversation_id,
                continue_conversation=chat_log.continue_conversation,
            )

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        fake_converse,
    ):
        await pipeline_input.execute()

    # The first event carries the token of the TTS stream for this run.
    token = pipeline_events[0].data["tts_output"]["token"]
    result_stream = tts.async_get_stream(hass, token)
    assert result_stream is not None

    decoded_chunks = [
        chunk.decode() async for chunk in result_stream.async_stream_result()
    ]
    spoken_text = "".join(decoded_chunks)

    assert spoken_text == expected_text
    # The entity must have been handed the response as exactly one message.
    assert len(messages_seen_by_tts) == 1
    assert "".join(messages_seen_by_tts) == expected_text

    assert process_events(pipeline_events) == snapshot
|
||||
|
@ -1153,9 +1153,9 @@ async def test_get_pipeline(
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
"tts_engine": "tts.test",
|
||||
"tts_language": "en_US",
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
"prefer_local_intents": False,
|
||||
@ -1179,9 +1179,9 @@ async def test_get_pipeline(
|
||||
# It found these defaults
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
"tts_engine": "tts.test",
|
||||
"tts_language": "en_US",
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
"prefer_local_intents": False,
|
||||
@ -1266,9 +1266,9 @@ async def test_list_pipelines(
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
"tts_engine": "tts.test",
|
||||
"tts_language": "en_US",
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
"prefer_local_intents": False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user