mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Add a test for Assist Pipeline streaming deltas to TTS (#144711)
* Add a test for Assist Pipeline streaming deltas to TTS * Adjust tests to new TTS engine
This commit is contained in:
parent
d471de5645
commit
2266e97417
@ -37,7 +37,7 @@ from tests.common import (
|
|||||||
mock_platform,
|
mock_platform,
|
||||||
)
|
)
|
||||||
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
||||||
from tests.components.tts.common import MockTTSProvider
|
from tests.components.tts.common import MockTTSEntity, MockTTSProvider
|
||||||
|
|
||||||
_TRANSCRIPT = "test transcript"
|
_TRANSCRIPT = "test transcript"
|
||||||
|
|
||||||
@ -68,6 +68,15 @@ async def mock_tts_provider() -> MockTTSProvider:
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tts_entity() -> MockTTSEntity:
|
||||||
|
"""Test TTS entity."""
|
||||||
|
entity = MockTTSEntity("en")
|
||||||
|
entity._attr_unique_id = "test_tts"
|
||||||
|
entity._attr_supported_languages = ["en-US"]
|
||||||
|
return entity
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_stt_provider() -> MockSTTProvider:
|
async def mock_stt_provider() -> MockSTTProvider:
|
||||||
"""Mock STT provider."""
|
"""Mock STT provider."""
|
||||||
@ -198,6 +207,7 @@ async def init_supporting_components(
|
|||||||
mock_stt_provider: MockSTTProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
mock_tts_entity: MockTTSEntity,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
||||||
config_flow_fixture,
|
config_flow_fixture,
|
||||||
@ -209,7 +219,7 @@ async def init_supporting_components(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Set up test config entry."""
|
"""Set up test config entry."""
|
||||||
await hass.config_entries.async_forward_entry_setups(
|
await hass.config_entries.async_forward_entry_setups(
|
||||||
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
config_entry, [Platform.STT, Platform.TTS, Platform.WAKE_WORD]
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -230,6 +240,14 @@ async def init_supporting_components(
|
|||||||
"""Set up test stt platform via config entry."""
|
"""Set up test stt platform via config entry."""
|
||||||
async_add_entities([mock_stt_provider_entity])
|
async_add_entities([mock_stt_provider_entity])
|
||||||
|
|
||||||
|
async def async_setup_entry_tts_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up test tts platform via config entry."""
|
||||||
|
async_add_entities([mock_tts_entity])
|
||||||
|
|
||||||
async def async_setup_entry_wake_word_platform(
|
async def async_setup_entry_wake_word_platform(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
@ -253,6 +271,7 @@ async def init_supporting_components(
|
|||||||
"test.tts",
|
"test.tts",
|
||||||
MockTTSPlatform(
|
MockTTSPlatform(
|
||||||
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||||
|
async_setup_entry=async_setup_entry_tts_platform,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
mock_platform(
|
mock_platform(
|
||||||
|
@ -74,17 +74,17 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@ -395,17 +395,17 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
|
@ -1,4 +1,158 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_chat_log_tts_streaming[to_stream_tts0]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'tts_output': dict({
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'test-agent',
|
||||||
|
'intent_input': 'Set a timer',
|
||||||
|
'language': 'en',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'role': 'assistant',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'hello,',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': ' ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'how',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': ' ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'are',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': ' ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '?',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': True,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': 'hello, how are you?',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'tts.test',
|
||||||
|
'language': 'en_US',
|
||||||
|
'tts_input': 'hello, how are you?',
|
||||||
|
'voice': None,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D',
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
@ -71,16 +71,16 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.5
|
# name: test_audio_pipeline.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.6
|
# name: test_audio_pipeline.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@ -162,16 +162,16 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug.5
|
# name: test_audio_pipeline_debug.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug.6
|
# name: test_audio_pipeline_debug.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@ -265,16 +265,16 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_enhancements.5
|
# name: test_audio_pipeline_with_enhancements.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_enhancements.6
|
# name: test_audio_pipeline_with_enhancements.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
@ -378,16 +378,16 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word_no_timeout.7
|
# name: test_audio_pipeline_with_wake_word_no_timeout.7
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'tts.test',
|
||||||
'language': 'en-US',
|
'language': 'en_US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': 'james_earl_jones',
|
'voice': None,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
|
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
|
@ -40,6 +40,7 @@ from . import MANY_LANGUAGES, process_events
|
|||||||
from .conftest import (
|
from .conftest import (
|
||||||
MockSTTProvider,
|
MockSTTProvider,
|
||||||
MockSTTProviderEntity,
|
MockSTTProviderEntity,
|
||||||
|
MockTTSEntity,
|
||||||
MockTTSProvider,
|
MockTTSProvider,
|
||||||
MockWakeWordEntity,
|
MockWakeWordEntity,
|
||||||
make_10ms_chunk,
|
make_10ms_chunk,
|
||||||
@ -62,6 +63,12 @@ async def load_homeassistant(hass: HomeAssistant) -> None:
|
|||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None:
|
||||||
|
"""Disable the TTS entity."""
|
||||||
|
mock_tts_entity._attr_entity_registry_enabled_default = False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_components")
|
@pytest.mark.usefixtures("init_components")
|
||||||
async def test_load_pipelines(hass: HomeAssistant) -> None:
|
async def test_load_pipelines(hass: HomeAssistant) -> None:
|
||||||
"""Make sure that we can load/save data correctly."""
|
"""Make sure that we can load/save data correctly."""
|
||||||
@ -283,6 +290,7 @@ async def test_migrate_pipeline_store(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
|
@pytest.mark.usefixtures("disable_tts_entity")
|
||||||
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
|
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
|
||||||
"""Test async_create_default_pipeline."""
|
"""Test async_create_default_pipeline."""
|
||||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
@ -430,6 +438,7 @@ async def test_default_pipeline_no_stt_tts(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
|
@pytest.mark.usefixtures("disable_tts_entity")
|
||||||
async def test_default_pipeline(
|
async def test_default_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
@ -474,6 +483,7 @@ async def test_default_pipeline(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
|
@pytest.mark.usefixtures("disable_tts_entity")
|
||||||
async def test_default_pipeline_unsupported_stt_language(
|
async def test_default_pipeline_unsupported_stt_language(
|
||||||
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
|
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -504,6 +514,7 @@ async def test_default_pipeline_unsupported_stt_language(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
|
@pytest.mark.usefixtures("disable_tts_entity")
|
||||||
async def test_default_pipeline_unsupported_tts_language(
|
async def test_default_pipeline_unsupported_tts_language(
|
||||||
hass: HomeAssistant, mock_tts_provider: MockTTSProvider
|
hass: HomeAssistant, mock_tts_provider: MockTTSProvider
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -825,7 +836,7 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
|||||||
async def test_tts_audio_output(
|
async def test_tts_audio_output(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_entity: MockTTSProvider,
|
||||||
init_components,
|
init_components,
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
mock_chat_session: chat_session.ChatSession,
|
mock_chat_session: chat_session.ChatSession,
|
||||||
@ -869,7 +880,7 @@ async def test_tts_audio_output(
|
|||||||
== 1
|
== 1
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio:
|
||||||
await pipeline_input.execute()
|
await pipeline_input.execute()
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
@ -881,14 +892,14 @@ async def test_tts_audio_output(
|
|||||||
# Ensure that no unsupported options were passed in
|
# Ensure that no unsupported options were passed in
|
||||||
assert mock_get_tts_audio.called
|
assert mock_get_tts_audio.called
|
||||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
extra_options = set(options).difference(mock_tts_entity.supported_options)
|
||||||
assert len(extra_options) == 0, extra_options
|
assert len(extra_options) == 0, extra_options
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_wav_preferred_format(
|
async def test_tts_wav_preferred_format(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_entity: MockTTSEntity,
|
||||||
init_components,
|
init_components,
|
||||||
mock_chat_session: chat_session.ChatSession,
|
mock_chat_session: chat_session.ChatSession,
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
@ -920,7 +931,7 @@ async def test_tts_wav_preferred_format(
|
|||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
# Make the TTS provider support preferred format options
|
# Make the TTS provider support preferred format options
|
||||||
supported_options = list(mock_tts_provider.supported_options or [])
|
supported_options = list(mock_tts_entity.supported_options or [])
|
||||||
supported_options.extend(
|
supported_options.extend(
|
||||||
[
|
[
|
||||||
tts.ATTR_PREFERRED_FORMAT,
|
tts.ATTR_PREFERRED_FORMAT,
|
||||||
@ -931,8 +942,8 @@ async def test_tts_wav_preferred_format(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
patch.object(mock_tts_entity, "_supported_options", supported_options),
|
||||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
|
||||||
):
|
):
|
||||||
await pipeline_input.execute()
|
await pipeline_input.execute()
|
||||||
|
|
||||||
@ -955,7 +966,7 @@ async def test_tts_wav_preferred_format(
|
|||||||
async def test_tts_dict_preferred_format(
|
async def test_tts_dict_preferred_format(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_entity: MockTTSEntity,
|
||||||
init_components,
|
init_components,
|
||||||
mock_chat_session: chat_session.ChatSession,
|
mock_chat_session: chat_session.ChatSession,
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
@ -992,7 +1003,7 @@ async def test_tts_dict_preferred_format(
|
|||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
# Make the TTS provider support preferred format options
|
# Make the TTS provider support preferred format options
|
||||||
supported_options = list(mock_tts_provider.supported_options or [])
|
supported_options = list(mock_tts_entity.supported_options or [])
|
||||||
supported_options.extend(
|
supported_options.extend(
|
||||||
[
|
[
|
||||||
tts.ATTR_PREFERRED_FORMAT,
|
tts.ATTR_PREFERRED_FORMAT,
|
||||||
@ -1003,8 +1014,8 @@ async def test_tts_dict_preferred_format(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
patch.object(mock_tts_entity, "_supported_options", supported_options),
|
||||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
|
||||||
):
|
):
|
||||||
await pipeline_input.execute()
|
await pipeline_input.execute()
|
||||||
|
|
||||||
@ -1545,3 +1556,143 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
|||||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
== pipeline.language
|
== pipeline.language
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"to_stream_tts",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
"hello,",
|
||||||
|
" ",
|
||||||
|
"how",
|
||||||
|
" ",
|
||||||
|
"are",
|
||||||
|
" ",
|
||||||
|
"you",
|
||||||
|
"?",
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_chat_log_tts_streaming(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
mock_tts_entity: MockTTSEntity,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
to_stream_tts: list[str],
|
||||||
|
) -> None:
|
||||||
|
"""Test that chat log events are streamed to the TTS entity."""
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
await assist_pipeline.pipeline.async_update_pipeline(
|
||||||
|
hass, pipeline, conversation_engine="test-agent"
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="Set a timer",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
received_tts = []
|
||||||
|
|
||||||
|
async def async_stream_tts_audio(
|
||||||
|
request: tts.TTSAudioRequest,
|
||||||
|
) -> tts.TTSAudioResponse:
|
||||||
|
"""Mock stream TTS audio."""
|
||||||
|
|
||||||
|
async def gen_data():
|
||||||
|
async for msg in request.message_gen:
|
||||||
|
received_tts.append(msg)
|
||||||
|
yield msg.encode()
|
||||||
|
|
||||||
|
return tts.TTSAudioResponse(
|
||||||
|
extension="mp3",
|
||||||
|
data_gen=gen_data(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
):
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
async def mock_converse(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
text: str,
|
||||||
|
conversation_id: str | None,
|
||||||
|
context: Context,
|
||||||
|
language: str | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
device_id: str | None = None,
|
||||||
|
extra_system_prompt: str | None = None,
|
||||||
|
):
|
||||||
|
"""Mock converse."""
|
||||||
|
conversation_input = conversation.ConversationInput(
|
||||||
|
text=text,
|
||||||
|
context=context,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
|
language=language,
|
||||||
|
agent_id=agent_id,
|
||||||
|
extra_system_prompt=extra_system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stream_llm_response():
|
||||||
|
yield {"role": "assistant"}
|
||||||
|
for chunk in to_stream_tts:
|
||||||
|
yield {"content": chunk}
|
||||||
|
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
|
conversation.async_get_chat_log(
|
||||||
|
hass,
|
||||||
|
session,
|
||||||
|
conversation_input,
|
||||||
|
) as chat_log,
|
||||||
|
):
|
||||||
|
async for _content in chat_log.async_add_delta_content_stream(
|
||||||
|
agent_id, stream_llm_response()
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
intent_response = intent.IntentResponse(language)
|
||||||
|
intent_response.async_set_speech("".join(to_stream_tts))
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response,
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
continue_conversation=chat_log.continue_conversation,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
mock_converse,
|
||||||
|
):
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
stream = tts.async_get_stream(hass, events[0].data["tts_output"]["token"])
|
||||||
|
assert stream is not None
|
||||||
|
tts_result = "".join(
|
||||||
|
[chunk.decode() async for chunk in stream.async_stream_result()]
|
||||||
|
)
|
||||||
|
|
||||||
|
streamed_text = "".join(to_stream_tts)
|
||||||
|
assert tts_result == streamed_text
|
||||||
|
assert len(received_tts) == 1
|
||||||
|
assert "".join(received_tts) == streamed_text
|
||||||
|
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
@ -1153,9 +1153,9 @@ async def test_get_pipeline(
|
|||||||
"name": "Home Assistant",
|
"name": "Home Assistant",
|
||||||
"stt_engine": "stt.mock_stt",
|
"stt_engine": "stt.mock_stt",
|
||||||
"stt_language": "en-US",
|
"stt_language": "en-US",
|
||||||
"tts_engine": "test",
|
"tts_engine": "tts.test",
|
||||||
"tts_language": "en-US",
|
"tts_language": "en_US",
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": None,
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
"prefer_local_intents": False,
|
"prefer_local_intents": False,
|
||||||
@ -1179,9 +1179,9 @@ async def test_get_pipeline(
|
|||||||
# It found these defaults
|
# It found these defaults
|
||||||
"stt_engine": "stt.mock_stt",
|
"stt_engine": "stt.mock_stt",
|
||||||
"stt_language": "en-US",
|
"stt_language": "en-US",
|
||||||
"tts_engine": "test",
|
"tts_engine": "tts.test",
|
||||||
"tts_language": "en-US",
|
"tts_language": "en_US",
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": None,
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
"prefer_local_intents": False,
|
"prefer_local_intents": False,
|
||||||
@ -1266,9 +1266,9 @@ async def test_list_pipelines(
|
|||||||
"name": "Home Assistant",
|
"name": "Home Assistant",
|
||||||
"stt_engine": "stt.mock_stt",
|
"stt_engine": "stt.mock_stt",
|
||||||
"stt_language": "en-US",
|
"stt_language": "en-US",
|
||||||
"tts_engine": "test",
|
"tts_engine": "tts.test",
|
||||||
"tts_language": "en-US",
|
"tts_language": "en_US",
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": None,
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
"prefer_local_intents": False,
|
"prefer_local_intents": False,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user