From 4faa920318b0d7ccee97a5ab3f1d325c790bc563 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 11 May 2025 15:38:21 -0400 Subject: [PATCH] Move Assist Pipeline tests to right file (#144696) --- tests/components/assist_pipeline/__init__.py | 18 + .../assist_pipeline/snapshots/test_init.ambr | 201 ---- .../snapshots/test_pipeline.ambr | 202 ++++ tests/components/assist_pipeline/test_init.py | 847 +---------------- .../assist_pipeline/test_pipeline.py | 862 +++++++++++++++++- 5 files changed, 1079 insertions(+), 1051 deletions(-) create mode 100644 tests/components/assist_pipeline/snapshots/test_pipeline.ambr diff --git a/tests/components/assist_pipeline/__init__.py b/tests/components/assist_pipeline/__init__.py index dd0f80e52ad..cc11fcc6c82 100644 --- a/tests/components/assist_pipeline/__init__.py +++ b/tests/components/assist_pipeline/__init__.py @@ -1,5 +1,10 @@ """Tests for the Voice Assistant integration.""" +from dataclasses import asdict +from unittest.mock import ANY + +from homeassistant.components import assist_pipeline + MANY_LANGUAGES = [ "ar", "bg", @@ -54,3 +59,16 @@ MANY_LANGUAGES = [ "zh-hk", "zh-tw", ] + + +def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: + """Process events to remove dynamic values.""" + processed = [] + for event in events: + as_dict = asdict(event) + as_dict.pop("timestamp") + if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START: + as_dict["data"]["pipeline"] = ANY + processed.append(as_dict) + + return processed diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index f772f877d3a..81972191868 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -461,204 +461,3 @@ }), ]) # --- -# name: test_pipeline_language_used_instead_of_conversation_language - list([ - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 
'language': 'en', - 'pipeline': , - }), - 'type': , - }), - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'device_id': None, - 'engine': 'conversation.home_assistant', - 'intent_input': 'test input', - 'language': 'en', - 'prefer_local_intents': False, - }), - 'type': , - }), - dict({ - 'data': dict({ - 'intent_output': dict({ - 'continue_conversation': False, - 'conversation_id': , - 'response': dict({ - 'card': dict({ - }), - 'data': dict({ - 'failed': list([ - ]), - 'success': list([ - ]), - 'targets': list([ - ]), - }), - 'language': 'en', - 'response_type': 'action_done', - 'speech': dict({ - }), - }), - }), - 'processed_locally': True, - }), - 'type': , - }), - dict({ - 'data': None, - 'type': , - }), - ]) -# --- -# name: test_stt_language_used_instead_of_conversation_language - list([ - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'language': 'en', - 'pipeline': , - }), - 'type': , - }), - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'device_id': None, - 'engine': 'conversation.home_assistant', - 'intent_input': 'test input', - 'language': 'en-US', - 'prefer_local_intents': False, - }), - 'type': , - }), - dict({ - 'data': dict({ - 'intent_output': dict({ - 'continue_conversation': False, - 'conversation_id': , - 'response': dict({ - 'card': dict({ - }), - 'data': dict({ - 'failed': list([ - ]), - 'success': list([ - ]), - 'targets': list([ - ]), - }), - 'language': 'en', - 'response_type': 'action_done', - 'speech': dict({ - }), - }), - }), - 'processed_locally': True, - }), - 'type': , - }), - dict({ - 'data': None, - 'type': , - }), - ]) -# --- -# name: test_tts_language_used_instead_of_conversation_language - list([ - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'language': 'en', - 'pipeline': , - }), - 'type': , - }), - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'device_id': None, - 'engine': 'conversation.home_assistant', - 'intent_input': 'test input', - 'language': 'en-us', - 
'prefer_local_intents': False, - }), - 'type': , - }), - dict({ - 'data': dict({ - 'intent_output': dict({ - 'continue_conversation': False, - 'conversation_id': , - 'response': dict({ - 'card': dict({ - }), - 'data': dict({ - 'failed': list([ - ]), - 'success': list([ - ]), - 'targets': list([ - ]), - }), - 'language': 'en', - 'response_type': 'action_done', - 'speech': dict({ - }), - }), - }), - 'processed_locally': True, - }), - 'type': , - }), - dict({ - 'data': None, - 'type': , - }), - ]) -# --- -# name: test_wake_word_detection_aborted - list([ - dict({ - 'data': dict({ - 'conversation_id': 'mock-ulid', - 'language': 'en', - 'pipeline': , - 'tts_output': dict({ - 'mime_type': 'audio/mpeg', - 'token': 'mocked-token.mp3', - 'url': '/api/tts_proxy/mocked-token.mp3', - }), - }), - 'type': , - }), - dict({ - 'data': dict({ - 'entity_id': 'wake_word.test', - 'metadata': dict({ - 'bit_rate': , - 'channel': , - 'codec': , - 'format': , - 'sample_rate': , - }), - 'timeout': 0, - }), - 'type': , - }), - dict({ - 'data': dict({ - 'code': 'wake_word_detection_aborted', - 'message': '', - }), - 'type': , - }), - dict({ - 'data': None, - 'type': , - }), - ]) -# --- diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr new file mode 100644 index 00000000000..7c0ac254b6e --- /dev/null +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -0,0 +1,202 @@ +# serializer version: 1 +# name: test_pipeline_language_used_instead_of_conversation_language + list([ + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'device_id': None, + 'engine': 'conversation.home_assistant', + 'intent_input': 'test input', + 'language': 'en', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 
'continue_conversation': False, + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + }), + }), + }), + 'processed_locally': True, + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- +# name: test_stt_language_used_instead_of_conversation_language + list([ + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'device_id': None, + 'engine': 'conversation.home_assistant', + 'intent_input': 'test input', + 'language': 'en-US', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'continue_conversation': False, + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + }), + }), + }), + 'processed_locally': True, + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- +# name: test_tts_language_used_instead_of_conversation_language + list([ + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'device_id': None, + 'engine': 'conversation.home_assistant', + 'intent_input': 'test input', + 'language': 'en-us', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'continue_conversation': False, + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 
'response_type': 'action_done', + 'speech': dict({ + }), + }), + }), + 'processed_locally': True, + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- +# name: test_wake_word_detection_aborted + list([ + dict({ + 'data': dict({ + 'conversation_id': 'mock-ulid', + 'language': 'en', + 'pipeline': , + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'sample_rate': , + }), + 'timeout': 0, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'code': 'wake_word_detection_aborted', + 'message': '', + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 0e04d1f0cd2..0294f9953db 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -2,44 +2,35 @@ import asyncio from collections.abc import Generator -from dataclasses import asdict import itertools as it from pathlib import Path import tempfile -from unittest.mock import ANY, Mock, patch +from unittest.mock import Mock, patch import wave import hass_nabucasa import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components import ( - assist_pipeline, - conversation, - media_source, - stt, - tts, -) +from homeassistant.components import assist_pipeline, stt from homeassistant.components.assist_pipeline.const import ( BYTES_PER_CHUNK, CONF_DEBUG_RECORDING_DIR, DOMAIN, ) -from homeassistant.const import MATCH_ALL from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import chat_session, intent from homeassistant.setup import async_setup_component +from . 
import process_events from .conftest import ( BYTES_ONE_SECOND, MockSTTProvider, MockSTTProviderEntity, - MockTTSProvider, MockWakeWordEntity, make_10ms_chunk, ) -from tests.typing import ClientSessionGenerator, WebSocketGenerator +from tests.typing import WebSocketGenerator @pytest.fixture(autouse=True) @@ -58,19 +49,6 @@ def mock_tts_token() -> Generator[None]: yield -def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: - """Process events to remove dynamic values.""" - processed = [] - for event in events: - as_dict = asdict(event) - as_dict.pop("timestamp") - if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START: - as_dict["data"]["pipeline"] = ANY - processed.append(as_dict) - - return processed - - async def test_pipeline_from_audio_stream_auto( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, @@ -677,823 +655,6 @@ async def test_pipeline_saved_audio_empty_queue( ) -async def test_wake_word_detection_aborted( - hass: HomeAssistant, - mock_stt_provider: MockSTTProvider, - mock_wake_word_provider_entity: MockWakeWordEntity, - init_components, - pipeline_data: assist_pipeline.pipeline.PipelineData, - mock_chat_session: chat_session.ChatSession, - snapshot: SnapshotAssertion, -) -> None: - """Test creating a pipeline from an audio stream with wake word.""" - - events: list[assist_pipeline.PipelineEvent] = [] - - async def audio_data(): - yield make_10ms_chunk(b"silence!") - yield make_10ms_chunk(b"wake word!") - yield make_10ms_chunk(b"part1") - yield make_10ms_chunk(b"part2") - yield b"" - - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - session=mock_chat_session, - device_id=None, - stt_metadata=stt.SpeechMetadata( - language="", - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - 
bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=audio_data(), - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.WAKE_WORD, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=events.append, - tts_audio_output=None, - wake_word_settings=assist_pipeline.WakeWordSettings( - audio_seconds_to_buffer=1.5 - ), - audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), - ), - ) - await pipeline_input.validate() - - updates = pipeline.to_json() - updates.pop("id") - await pipeline_store.async_update_item( - pipeline_id, - updates, - ) - await pipeline_input.execute() - - assert process_events(events) == snapshot - - -def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: - """Test that pipeline run equality uses unique id.""" - - def event_callback(event): - pass - - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass) - run_1 = assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.STT, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=event_callback, - ) - run_2 = assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.STT, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=event_callback, - ) - - assert run_1 == run_1 # noqa: PLR0124 - assert run_1 != run_2 - assert run_1 != 1234 - - -async def test_tts_audio_output( - hass: HomeAssistant, - hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, - init_components, - pipeline_data: assist_pipeline.pipeline.PipelineData, - mock_chat_session: chat_session.ChatSession, - snapshot: SnapshotAssertion, -) -> None: - """Test using tts_audio_output with wav sets options correctly.""" - client 
= await hass_client() - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - events: list[assist_pipeline.PipelineEvent] = [] - - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - tts_input="This is a test.", - session=mock_chat_session, - device_id=None, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.TTS, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=events.append, - tts_audio_output="wav", - ), - ) - await pipeline_input.validate() - - # Verify TTS audio settings - assert pipeline_input.run.tts_stream.options is not None - assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" - assert ( - pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) - == 16000 - ) - assert ( - pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) - == 1 - ) - - with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio: - await pipeline_input.execute() - - for event in events: - if event.type == assist_pipeline.PipelineEventType.TTS_END: - # We must fetch the media URL to trigger the TTS - assert event.data - await client.get(event.data["tts_output"]["url"]) - - # Ensure that no unsupported options were passed in - assert mock_get_tts_audio.called - options = mock_get_tts_audio.call_args_list[0].kwargs["options"] - extra_options = set(options).difference(mock_tts_provider.supported_options) - assert len(extra_options) == 0, extra_options - - -async def test_tts_wav_preferred_format( - hass: HomeAssistant, - hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, - init_components, - mock_chat_session: chat_session.ChatSession, - pipeline_data: 
assist_pipeline.pipeline.PipelineData, -) -> None: - """Test that preferred format options are given to the TTS system if supported.""" - client = await hass_client() - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - events: list[assist_pipeline.PipelineEvent] = [] - - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - tts_input="This is a test.", - session=mock_chat_session, - device_id=None, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.TTS, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=events.append, - tts_audio_output="wav", - ), - ) - await pipeline_input.validate() - - # Make the TTS provider support preferred format options - supported_options = list(mock_tts_provider.supported_options or []) - supported_options.extend( - [ - tts.ATTR_PREFERRED_FORMAT, - tts.ATTR_PREFERRED_SAMPLE_RATE, - tts.ATTR_PREFERRED_SAMPLE_CHANNELS, - tts.ATTR_PREFERRED_SAMPLE_BYTES, - ] - ) - - with ( - patch.object(mock_tts_provider, "_supported_options", supported_options), - patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, - ): - await pipeline_input.execute() - - for event in events: - if event.type == assist_pipeline.PipelineEventType.TTS_END: - # We must fetch the media URL to trigger the TTS - assert event.data - await client.get(event.data["tts_output"]["url"]) - - assert mock_get_tts_audio.called - options = mock_get_tts_audio.call_args_list[0].kwargs["options"] - - # We should have received preferred format options in get_tts_audio - assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" - assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000 - assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1 - assert 
int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 - - -async def test_tts_dict_preferred_format( - hass: HomeAssistant, - hass_client: ClientSessionGenerator, - mock_tts_provider: MockTTSProvider, - init_components, - mock_chat_session: chat_session.ChatSession, - pipeline_data: assist_pipeline.pipeline.PipelineData, -) -> None: - """Test that preferred format options are given to the TTS system if supported.""" - client = await hass_client() - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - events: list[assist_pipeline.PipelineEvent] = [] - - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - tts_input="This is a test.", - session=mock_chat_session, - device_id=None, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.TTS, - end_stage=assist_pipeline.PipelineStage.TTS, - event_callback=events.append, - tts_audio_output={ - tts.ATTR_PREFERRED_FORMAT: "flac", - tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, - tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, - tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, - }, - ), - ) - await pipeline_input.validate() - - # Make the TTS provider support preferred format options - supported_options = list(mock_tts_provider.supported_options or []) - supported_options.extend( - [ - tts.ATTR_PREFERRED_FORMAT, - tts.ATTR_PREFERRED_SAMPLE_RATE, - tts.ATTR_PREFERRED_SAMPLE_CHANNELS, - tts.ATTR_PREFERRED_SAMPLE_BYTES, - ] - ) - - with ( - patch.object(mock_tts_provider, "_supported_options", supported_options), - patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, - ): - await pipeline_input.execute() - - for event in events: - if event.type == assist_pipeline.PipelineEventType.TTS_END: - # We must fetch the media URL to trigger the TTS - assert 
event.data - await client.get(event.data["tts_output"]["url"]) - - assert mock_get_tts_audio.called - options = mock_get_tts_audio.call_args_list[0].kwargs["options"] - - # We should have received preferred format options in get_tts_audio - assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac" - assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000 - assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2 - assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 - - -async def test_sentence_trigger_overrides_conversation_agent( - hass: HomeAssistant, - init_components, - mock_chat_session: chat_session.ChatSession, - pipeline_data: assist_pipeline.pipeline.PipelineData, -) -> None: - """Test that sentence triggers are checked before a non-default conversation agent.""" - assert await async_setup_component( - hass, - "automation", - { - "automation": { - "trigger": { - "platform": "conversation", - "command": [ - "test trigger sentence", - ], - }, - "action": { - "set_conversation_response": "test trigger response", - }, - } - }, - ) - - events: list[assist_pipeline.PipelineEvent] = [] - - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="test trigger sentence", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - intent_agent="test-agent", # not the default agent - ), - ) - - # Ensure prepare succeeds - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), - ): - await pipeline_input.validate() - - with patch( - 
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" - ) as mock_async_converse: - await pipeline_input.execute() - - # Sentence trigger should have been handled - mock_async_converse.assert_not_called() - - # Verify sentence trigger response - intent_end_event = next( - ( - e - for e in events - if e.type == assist_pipeline.PipelineEventType.INTENT_END - ), - None, - ) - assert (intent_end_event is not None) and intent_end_event.data - assert ( - intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ - "speech" - ] - == "test trigger response" - ) - - -async def test_prefer_local_intents( - hass: HomeAssistant, - init_components, - mock_chat_session: chat_session.ChatSession, - pipeline_data: assist_pipeline.pipeline.PipelineData, -) -> None: - """Test that the default agent is checked first when local intents are preferred.""" - events: list[assist_pipeline.PipelineEvent] = [] - - # Reuse custom sentences in test config - class OrderBeerIntentHandler(intent.IntentHandler): - intent_type = "OrderBeer" - - async def async_handle( - self, intent_obj: intent.Intent - ) -> intent.IntentResponse: - response = intent_obj.create_response() - response.async_set_speech("Order confirmed") - return response - - handler = OrderBeerIntentHandler() - intent.async_register(hass, handler) - - # Fake a test agent and prefer local intents - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - await assist_pipeline.pipeline.async_update_pipeline( - hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True - ) - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="I'd like to order a stout please", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - 
pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - - # Ensure prepare succeeds - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), - ): - await pipeline_input.validate() - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" - ) as mock_async_converse: - await pipeline_input.execute() - - # Test agent should not have been called - mock_async_converse.assert_not_called() - - # Verify local intent response - intent_end_event = next( - ( - e - for e in events - if e.type == assist_pipeline.PipelineEventType.INTENT_END - ), - None, - ) - assert (intent_end_event is not None) and intent_end_event.data - assert ( - intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ - "speech" - ] - == "Order confirmed" - ) - - -async def test_intent_continue_conversation( - hass: HomeAssistant, - init_components, - mock_chat_session: chat_session.ChatSession, - pipeline_data: assist_pipeline.pipeline.PipelineData, -) -> None: - """Test that a conversation agent flagging continue conversation gets response.""" - events: list[assist_pipeline.PipelineEvent] = [] - - # Fake a test agent and prefer local intents - pipeline_store = pipeline_data.pipeline_store - pipeline_id = pipeline_store.async_get_preferred_item() - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - await assist_pipeline.pipeline.async_update_pipeline( - hass, pipeline, conversation_engine="test-agent" - ) - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="Set a timer", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - 
start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - - # Ensure prepare succeeds - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), - ): - await pipeline_input.validate() - - response = intent.IntentResponse("en") - response.async_set_speech("For how long?") - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", - return_value=conversation.ConversationResult( - response=response, - conversation_id=mock_chat_session.conversation_id, - continue_conversation=True, - ), - ) as mock_async_converse: - await pipeline_input.execute() - - mock_async_converse.assert_called() - - results = [ - event.data - for event in events - if event.type - in ( - assist_pipeline.PipelineEventType.INTENT_START, - assist_pipeline.PipelineEventType.INTENT_END, - ) - ] - assert results[1]["intent_output"]["continue_conversation"] is True - - # Change conversation agent to default one and register sentence trigger that should not be called - await assist_pipeline.pipeline.async_update_pipeline( - hass, pipeline, conversation_engine=None - ) - pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) - assert await async_setup_component( - hass, - "automation", - { - "automation": { - "trigger": { - "platform": "conversation", - "command": ["Hello"], - }, - "action": { - "set_conversation_response": "test trigger response", - }, - } - }, - ) - - # Because we did continue conversation, it should respond to the test agent again. 
- events.clear() - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="Hello", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - - # Ensure prepare succeeds - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), - ) as mock_prepare: - await pipeline_input.validate() - - # It requested test agent even if that was not default agent. - assert mock_prepare.mock_calls[0][1][1] == "test-agent" - - response = intent.IntentResponse("en") - response.async_set_speech("Timer set for 20 minutes") - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", - return_value=conversation.ConversationResult( - response=response, - conversation_id=mock_chat_session.conversation_id, - ), - ) as mock_async_converse: - await pipeline_input.execute() - - mock_async_converse.assert_called() - - # Snapshot will show it was still handled by the test agent and not default agent - results = [ - event.data - for event in events - if event.type - in ( - assist_pipeline.PipelineEventType.INTENT_START, - assist_pipeline.PipelineEventType.INTENT_END, - ) - ] - assert results[0]["engine"] == "test-agent" - assert results[1]["intent_output"]["continue_conversation"] is False - - -async def test_stt_language_used_instead_of_conversation_language( - hass: HomeAssistant, - hass_ws_client: WebSocketGenerator, - init_components, - mock_chat_session: chat_session.ChatSession, - snapshot: SnapshotAssertion, -) -> None: - """Test that the STT language is used first when the conversation language is '*' (all languages).""" - client = await hass_ws_client(hass) - - events: list[assist_pipeline.PipelineEvent] = [] 
- - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/create", - "conversation_engine": "homeassistant", - "conversation_language": MATCH_ALL, - "language": "en", - "name": "test_name", - "stt_engine": "test", - "stt_language": "en-US", - "tts_engine": "test", - "tts_language": "en-US", - "tts_voice": "Arnold Schwarzenegger", - "wake_word_entity": None, - "wake_word_id": None, - } - ) - msg = await client.receive_json() - assert msg["success"] - pipeline_id = msg["result"]["id"] - pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="test input", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - await pipeline_input.validate() - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", - return_value=conversation.ConversationResult( - intent.IntentResponse(pipeline.language) - ), - ) as mock_async_converse: - await pipeline_input.execute() - - # Check intent start event - assert process_events(events) == snapshot - intent_start: assist_pipeline.PipelineEvent | None = None - for event in events: - if event.type == assist_pipeline.PipelineEventType.INTENT_START: - intent_start = event - break - - assert intent_start is not None - - # STT language (en-US) should be used instead of '*' - assert intent_start.data.get("language") == pipeline.stt_language - - # Check input to async_converse - mock_async_converse.assert_called_once() - assert ( - mock_async_converse.call_args_list[0].kwargs.get("language") - == pipeline.stt_language - ) - - -async def test_tts_language_used_instead_of_conversation_language( - hass: HomeAssistant, - hass_ws_client: WebSocketGenerator, - init_components, - mock_chat_session: 
chat_session.ChatSession, - snapshot: SnapshotAssertion, -) -> None: - """Test that the TTS language is used after STT when the conversation language is '*' (all languages).""" - client = await hass_ws_client(hass) - - events: list[assist_pipeline.PipelineEvent] = [] - - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/create", - "conversation_engine": "homeassistant", - "conversation_language": MATCH_ALL, - "language": "en", - "name": "test_name", - "stt_engine": None, - "stt_language": None, - "tts_engine": None, - "tts_language": "en-us", - "tts_voice": "Arnold Schwarzenegger", - "wake_word_entity": None, - "wake_word_id": None, - } - ) - msg = await client.receive_json() - assert msg["success"] - pipeline_id = msg["result"]["id"] - pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="test input", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - await pipeline_input.validate() - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", - return_value=conversation.ConversationResult( - intent.IntentResponse(pipeline.language) - ), - ) as mock_async_converse: - await pipeline_input.execute() - - # Check intent start event - assert process_events(events) == snapshot - intent_start: assist_pipeline.PipelineEvent | None = None - for event in events: - if event.type == assist_pipeline.PipelineEventType.INTENT_START: - intent_start = event - break - - assert intent_start is not None - - # STT language (en-US) should be used instead of '*' - assert intent_start.data.get("language") == pipeline.tts_language - - # Check input to async_converse - mock_async_converse.assert_called_once() - assert ( - 
mock_async_converse.call_args_list[0].kwargs.get("language") - == pipeline.tts_language - ) - - -async def test_pipeline_language_used_instead_of_conversation_language( - hass: HomeAssistant, - hass_ws_client: WebSocketGenerator, - init_components, - mock_chat_session: chat_session.ChatSession, - snapshot: SnapshotAssertion, -) -> None: - """Test that the pipeline language is used last when the conversation language is '*' (all languages).""" - client = await hass_ws_client(hass) - - events: list[assist_pipeline.PipelineEvent] = [] - - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/create", - "conversation_engine": "homeassistant", - "conversation_language": MATCH_ALL, - "language": "en", - "name": "test_name", - "stt_engine": None, - "stt_language": None, - "tts_engine": None, - "tts_language": None, - "tts_voice": None, - "wake_word_entity": None, - "wake_word_id": None, - } - ) - msg = await client.receive_json() - assert msg["success"] - pipeline_id = msg["result"]["id"] - pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) - - pipeline_input = assist_pipeline.pipeline.PipelineInput( - intent_input="test input", - session=mock_chat_session, - run=assist_pipeline.pipeline.PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - start_stage=assist_pipeline.PipelineStage.INTENT, - end_stage=assist_pipeline.PipelineStage.INTENT, - event_callback=events.append, - ), - ) - await pipeline_input.validate() - - with patch( - "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", - return_value=conversation.ConversationResult( - intent.IntentResponse(pipeline.language) - ), - ) as mock_async_converse: - await pipeline_input.execute() - - # Check intent start event - assert process_events(events) == snapshot - intent_start: assist_pipeline.PipelineEvent | None = None - for event in events: - if event.type == assist_pipeline.PipelineEventType.INTENT_START: - intent_start = event - break - - assert 
intent_start is not None - - # STT language (en-US) should be used instead of '*' - assert intent_start.data.get("language") == pipeline.language - - # Check input to async_converse - mock_async_converse.assert_called_once() - assert ( - mock_async_converse.call_args_list[0].kwargs.get("language") - == pipeline.language - ) - - async def test_pipeline_from_audio_stream_with_cloud_auth_fail( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index d67a0fd1726..4f15853b296 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1,13 +1,20 @@ """Websocket tests for Voice Assistant integration.""" -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Generator from typing import Any -from unittest.mock import ANY, patch +from unittest.mock import ANY, Mock, patch from hassil.recognize import Intent, IntentData, RecognizeResult import pytest +from syrupy.assertion import SnapshotAssertion -from homeassistant.components import conversation +from homeassistant.components import ( + assist_pipeline, + conversation, + media_source, + stt, + tts, +) from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, @@ -24,14 +31,22 @@ from homeassistant.components.assist_pipeline.pipeline import ( async_migrate_engine, async_update_pipeline, ) -from homeassistant.core import HomeAssistant -from homeassistant.helpers import intent +from homeassistant.const import MATCH_ALL +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import chat_session, intent from homeassistant.setup import async_setup_component -from . import MANY_LANGUAGES -from .conftest import MockSTTProviderEntity, MockTTSProvider +from . 
import MANY_LANGUAGES, process_events +from .conftest import ( + MockSTTProvider, + MockSTTProviderEntity, + MockTTSProvider, + MockWakeWordEntity, + make_10ms_chunk, +) from tests.common import flush_store +from tests.typing import ClientSessionGenerator, WebSocketGenerator @pytest.fixture(autouse=True) @@ -119,6 +134,22 @@ async def test_load_pipelines(hass: HomeAssistant) -> None: assert store1.async_get_preferred_item() == store2.async_get_preferred_item() +@pytest.fixture(autouse=True) +def mock_chat_session_id() -> Generator[Mock]: + """Mock the conversation ID of chat sessions.""" + with patch( + "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid" + ) as mock_ulid_now: + yield mock_ulid_now + + +@pytest.fixture(autouse=True) +def mock_tts_token() -> Generator[None]: + """Mock the TTS token for URLs.""" + with patch("secrets.token_urlsafe", return_value="mocked-token"): + yield + + async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: @@ -697,3 +728,820 @@ def test_fallback_intent_filter() -> None: ) is False ) + + +async def test_wake_word_detection_aborted( + hass: HomeAssistant, + mock_stt_provider: MockSTTProvider, + mock_wake_word_provider_entity: MockWakeWordEntity, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, + mock_chat_session: chat_session.ChatSession, + snapshot: SnapshotAssertion, +) -> None: + """Test wake word stream is first detected, then aborted.""" + + events: list[assist_pipeline.PipelineEvent] = [] + + async def audio_data(): + yield make_10ms_chunk(b"silence!") + yield make_10ms_chunk(b"wake word!") + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") + yield b"" + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + 
session=mock_chat_session, + device_id=None, + stt_metadata=stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_data(), + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.WAKE_WORD, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + tts_audio_output=None, + wake_word_settings=assist_pipeline.WakeWordSettings( + audio_seconds_to_buffer=1.5 + ), + audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), + ), + ) + await pipeline_input.validate() + + updates = pipeline.to_json() + updates.pop("id") + await pipeline_store.async_update_item( + pipeline_id, + updates, + ) + await pipeline_input.execute() + + assert process_events(events) == snapshot + + +def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: + """Test that pipeline run equality uses unique id.""" + + def event_callback(event): + pass + + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass) + run_1 = assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.STT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=event_callback, + ) + run_2 = assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.STT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=event_callback, + ) + + assert run_1 == run_1 # noqa: PLR0124 + assert run_1 != run_2 + assert run_1 != 1234 + + +async def test_tts_audio_output( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_tts_provider: MockTTSProvider, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, + 
mock_chat_session: chat_session.ChatSession, + snapshot: SnapshotAssertion, +) -> None: + """Test using tts_audio_output with wav sets options correctly.""" + client = await hass_client() + assert await async_setup_component(hass, media_source.DOMAIN, {}) + + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + tts_input="This is a test.", + session=mock_chat_session, + device_id=None, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.TTS, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + tts_audio_output="wav", + ), + ) + await pipeline_input.validate() + + # Verify TTS audio settings + assert pipeline_input.run.tts_stream.options is not None + assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" + assert ( + pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) + == 16000 + ) + assert ( + pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) + == 1 + ) + + with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio: + await pipeline_input.execute() + + for event in events: + if event.type == assist_pipeline.PipelineEventType.TTS_END: + # We must fetch the media URL to trigger the TTS + assert event.data + await client.get(event.data["tts_output"]["url"]) + + # Ensure that no unsupported options were passed in + assert mock_get_tts_audio.called + options = mock_get_tts_audio.call_args_list[0].kwargs["options"] + extra_options = set(options).difference(mock_tts_provider.supported_options) + assert len(extra_options) == 0, extra_options + + +async def test_tts_wav_preferred_format( + hass: HomeAssistant, + hass_client: 
ClientSessionGenerator, + mock_tts_provider: MockTTSProvider, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that preferred format options are given to the TTS system if supported.""" + client = await hass_client() + assert await async_setup_component(hass, media_source.DOMAIN, {}) + + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + tts_input="This is a test.", + session=mock_chat_session, + device_id=None, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.TTS, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + tts_audio_output="wav", + ), + ) + await pipeline_input.validate() + + # Make the TTS provider support preferred format options + supported_options = list(mock_tts_provider.supported_options or []) + supported_options.extend( + [ + tts.ATTR_PREFERRED_FORMAT, + tts.ATTR_PREFERRED_SAMPLE_RATE, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS, + tts.ATTR_PREFERRED_SAMPLE_BYTES, + ] + ) + + with ( + patch.object(mock_tts_provider, "_supported_options", supported_options), + patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, + ): + await pipeline_input.execute() + + for event in events: + if event.type == assist_pipeline.PipelineEventType.TTS_END: + # We must fetch the media URL to trigger the TTS + assert event.data + await client.get(event.data["tts_output"]["url"]) + + assert mock_get_tts_audio.called + options = mock_get_tts_audio.call_args_list[0].kwargs["options"] + + # We should have received preferred format options in get_tts_audio + assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" + 
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 + + +async def test_tts_dict_preferred_format( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_tts_provider: MockTTSProvider, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that preferred format options are given to the TTS system if supported.""" + client = await hass_client() + assert await async_setup_component(hass, media_source.DOMAIN, {}) + + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + tts_input="This is a test.", + session=mock_chat_session, + device_id=None, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.TTS, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + tts_audio_output={ + tts.ATTR_PREFERRED_FORMAT: "flac", + tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + }, + ), + ) + await pipeline_input.validate() + + # Make the TTS provider support preferred format options + supported_options = list(mock_tts_provider.supported_options or []) + supported_options.extend( + [ + tts.ATTR_PREFERRED_FORMAT, + tts.ATTR_PREFERRED_SAMPLE_RATE, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS, + tts.ATTR_PREFERRED_SAMPLE_BYTES, + ] + ) + + with ( + patch.object(mock_tts_provider, "_supported_options", supported_options), + patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, + ): + await pipeline_input.execute() + + for 
event in events: + if event.type == assist_pipeline.PipelineEventType.TTS_END: + # We must fetch the media URL to trigger the TTS + assert event.data + await client.get(event.data["tts_output"]["url"]) + + assert mock_get_tts_audio.called + options = mock_get_tts_audio.call_args_list[0].kwargs["options"] + + # We should have received preferred format options in get_tts_audio + assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac" + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 + + +async def test_sentence_trigger_overrides_conversation_agent( + hass: HomeAssistant, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that sentence triggers are checked before a non-default conversation agent.""" + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "conversation", + "command": [ + "test trigger sentence", + ], + }, + "action": { + "set_conversation_response": "test trigger response", + }, + } + }, + ) + + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test trigger sentence", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + intent_agent="test-agent", # not the default agent + ), + ) + + # Ensure prepare succeeds + with patch( + 
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ): + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" + ) as mock_async_converse: + await pipeline_input.execute() + + # Sentence trigger should have been handled + mock_async_converse.assert_not_called() + + # Verify sentence trigger response + intent_end_event = next( + ( + e + for e in events + if e.type == assist_pipeline.PipelineEventType.INTENT_END + ), + None, + ) + assert (intent_end_event is not None) and intent_end_event.data + assert ( + intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ + "speech" + ] + == "test trigger response" + ) + + +async def test_prefer_local_intents( + hass: HomeAssistant, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that the default agent is checked first when local intents are preferred.""" + events: list[assist_pipeline.PipelineEvent] = [] + + # Reuse custom sentences in test config + class OrderBeerIntentHandler(intent.IntentHandler): + intent_type = "OrderBeer" + + async def async_handle( + self, intent_obj: intent.Intent + ) -> intent.IntentResponse: + response = intent_obj.create_response() + response.async_set_speech("Order confirmed") + return response + + handler = OrderBeerIntentHandler() + intent.async_register(hass, handler) + + # Fake a test agent and prefer local intents + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, 
pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="I'd like to order a stout please", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + + # Ensure prepare succeeds + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ): + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" + ) as mock_async_converse: + await pipeline_input.execute() + + # Test agent should not have been called + mock_async_converse.assert_not_called() + + # Verify local intent response + intent_end_event = next( + ( + e + for e in events + if e.type == assist_pipeline.PipelineEventType.INTENT_END + ), + None, + ) + assert (intent_end_event is not None) and intent_end_event.data + assert ( + intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ + "speech" + ] + == "Order confirmed" + ) + + +async def test_intent_continue_conversation( + hass: HomeAssistant, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that a conversation agent flagging continue conversation gets response.""" + events: list[assist_pipeline.PipelineEvent] = [] + + # Use a fake test agent as the conversation engine + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine="test-agent" + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, 
pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="Set a timer", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + + # Ensure prepare succeeds + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ): + await pipeline_input.validate() + + response = intent.IntentResponse("en") + response.async_set_speech("For how long?") + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + response=response, + conversation_id=mock_chat_session.conversation_id, + continue_conversation=True, + ), + ) as mock_async_converse: + await pipeline_input.execute() + + mock_async_converse.assert_called() + + results = [ + event.data + for event in events + if event.type + in ( + assist_pipeline.PipelineEventType.INTENT_START, + assist_pipeline.PipelineEventType.INTENT_END, + ) + ] + assert results[1]["intent_output"]["continue_conversation"] is True + + # Change conversation agent to default one and register sentence trigger that should not be called + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine=None + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "conversation", + "command": ["Hello"], + }, + "action": { + "set_conversation_response": "test trigger response", + }, + } + }, + ) + + # Because we did continue conversation, it should respond to the test agent again. 
+ events.clear() + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="Hello", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + + # Ensure prepare succeeds + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ) as mock_prepare: + await pipeline_input.validate() + + # It requested test agent even if that was not default agent. + assert mock_prepare.mock_calls[0][1][1] == "test-agent" + + response = intent.IntentResponse("en") + response.async_set_speech("Timer set for 20 minutes") + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + response=response, + conversation_id=mock_chat_session.conversation_id, + ), + ) as mock_async_converse: + await pipeline_input.execute() + + mock_async_converse.assert_called() + + # Events should show it was still handled by the test agent and not default agent + results = [ + event.data + for event in events + if event.type + in ( + assist_pipeline.PipelineEventType.INTENT_START, + assist_pipeline.PipelineEventType.INTENT_END, + ) + ] + assert results[0]["engine"] == "test-agent" + assert results[1]["intent_output"]["continue_conversation"] is False + + +async def test_stt_language_used_instead_of_conversation_language( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_chat_session: chat_session.ChatSession, + snapshot: SnapshotAssertion, +) -> None: + """Test that the STT language is used first when the conversation language is '*' (all languages).""" + client = await hass_ws_client(hass) + + events: list[assist_pipeline.PipelineEvent] = [] 
+ + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": MATCH_ALL, + "language": "en", + "name": "test_name", + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test input", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + intent.IntentResponse(pipeline.language) + ), + ) as mock_async_converse: + await pipeline_input.execute() + + # Check intent start event + assert process_events(events) == snapshot + intent_start: assist_pipeline.PipelineEvent | None = None + for event in events: + if event.type == assist_pipeline.PipelineEventType.INTENT_START: + intent_start = event + break + + assert intent_start is not None + + # STT language (en-US) should be used instead of '*' + assert intent_start.data.get("language") == pipeline.stt_language + + # Check input to async_converse + mock_async_converse.assert_called_once() + assert ( + mock_async_converse.call_args_list[0].kwargs.get("language") + == pipeline.stt_language + ) + + +async def test_tts_language_used_instead_of_conversation_language( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_chat_session: 
chat_session.ChatSession, + snapshot: SnapshotAssertion, +) -> None: + """Test that the TTS language is used after STT when the conversation language is '*' (all languages).""" + client = await hass_ws_client(hass) + + events: list[assist_pipeline.PipelineEvent] = [] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": MATCH_ALL, + "language": "en", + "name": "test_name", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": "en-us", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test input", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + intent.IntentResponse(pipeline.language) + ), + ) as mock_async_converse: + await pipeline_input.execute() + + # Check intent start event + assert process_events(events) == snapshot + intent_start: assist_pipeline.PipelineEvent | None = None + for event in events: + if event.type == assist_pipeline.PipelineEventType.INTENT_START: + intent_start = event + break + + assert intent_start is not None + + # TTS language (en-us) should be used instead of '*' + assert intent_start.data.get("language") == pipeline.tts_language + + # Check input to async_converse + mock_async_converse.assert_called_once() + assert ( + 
mock_async_converse.call_args_list[0].kwargs.get("language") + == pipeline.tts_language + ) + + +async def test_pipeline_language_used_instead_of_conversation_language( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_chat_session: chat_session.ChatSession, + snapshot: SnapshotAssertion, +) -> None: + """Test that the pipeline language is used last when the conversation language is '*' (all languages).""" + client = await hass_ws_client(hass) + + events: list[assist_pipeline.PipelineEvent] = [] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": MATCH_ALL, + "language": "en", + "name": "test_name", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test input", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + intent.IntentResponse(pipeline.language) + ), + ) as mock_async_converse: + await pipeline_input.execute() + + # Check intent start event + assert process_events(events) == snapshot + intent_start: assist_pipeline.PipelineEvent | None = None + for event in events: + if event.type == assist_pipeline.PipelineEventType.INTENT_START: + intent_start = event + break + + assert 
intent_start is not None + + # Pipeline language (en) should be used instead of '*' + assert intent_start.data.get("language") == pipeline.language + + # Check input to async_converse + mock_async_converse.assert_called_once() + assert ( + mock_async_converse.call_args_list[0].kwargs.get("language") + == pipeline.language + )