"""Test Voice Assistant init.""" import asyncio from dataclasses import asdict import itertools as it from pathlib import Path import tempfile from unittest.mock import ANY, patch import wave import hass_nabucasa import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import ( assist_pipeline, conversation, media_source, stt, tts, ) from homeassistant.components.assist_pipeline.const import ( BYTES_PER_CHUNK, CONF_DEBUG_RECORDING_DIR, DOMAIN, ) from homeassistant.const import MATCH_ALL from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import intent from homeassistant.setup import async_setup_component from .conftest import ( BYTES_ONE_SECOND, MockSTTProvider, MockSTTProviderEntity, MockTTSProvider, MockWakeWordEntity, make_10ms_chunk, ) from tests.typing import ClientSessionGenerator, WebSocketGenerator def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: """Process events to remove dynamic values.""" processed = [] for event in events: as_dict = asdict(event) as_dict.pop("timestamp") if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START: as_dict["data"]["pipeline"] = ANY processed.append(as_dict) return processed async def test_pipeline_from_audio_stream_auto( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream. In this test, no pipeline is specified. 
""" events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" with patch( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token" ): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert process_events(events) == snapshot assert len(mock_stt_provider_entity.received) == 2 assert mock_stt_provider_entity.received[0].startswith(b"part1") assert mock_stt_provider_entity.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_legacy( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream. In this test, a pipeline using a legacy stt engine is used. 
""" client = await hass_ws_client(hass) events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline using an stt entity await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "homeassistant", "conversation_language": "en-US", "language": "en", "name": "test_name", "stt_engine": "test", "stt_language": "en-US", "tts_engine": "test", "tts_language": "en-US", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": None, "wake_word_id": None, } ) msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] with patch( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token" ): # Use the created pipeline await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="en-UK", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), pipeline_id=pipeline_id, audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert process_events(events) == snapshot assert len(mock_stt_provider.received) == 2 assert mock_stt_provider.received[0].startswith(b"part1") assert mock_stt_provider.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_entity( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, mock_stt_provider_entity: MockSTTProviderEntity, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream. In this test, a pipeline using am stt entity is used. 
""" client = await hass_ws_client(hass) events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline using an stt entity await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "homeassistant", "conversation_language": "en-US", "language": "en", "name": "test_name", "stt_engine": mock_stt_provider_entity.entity_id, "stt_language": "en-US", "tts_engine": "test", "tts_language": "en-US", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": None, "wake_word_id": None, } ) msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] with patch( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token" ): # Use the created pipeline await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="en-UK", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), pipeline_id=pipeline_id, audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert process_events(events) == snapshot assert len(mock_stt_provider_entity.received) == 2 assert mock_stt_provider_entity.received[0].startswith(b"part1") assert mock_stt_provider_entity.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_no_stt( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream. 
In this test, the pipeline does not support stt """ client = await hass_ws_client(hass) events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline without stt support await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "homeassistant", "conversation_language": "en-US", "language": "en", "name": "test_name", "stt_engine": None, "stt_language": None, "tts_engine": "test", "tts_language": "en-AU", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": None, "wake_word_id": None, } ) msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] # Try to use the created pipeline with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="en-UK", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), pipeline_id=pipeline_id, audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert not events async def test_pipeline_from_audio_stream_unknown_pipeline( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream. In this test, the pipeline does not exist. 
""" events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" # Try to use the created pipeline with pytest.raises(assist_pipeline.PipelineNotFound): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="en-UK", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), pipeline_id="blah", ) assert not events async def test_pipeline_from_audio_stream_wake_word( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, mock_wake_word_provider_entity: MockWakeWordEntity, init_components, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream with wake word.""" events: list[assist_pipeline.PipelineEvent] = [] # [0, 1, ...] wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND)) # [0, 2, ...] 
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND)) samples_per_chunk = 160 # 10ms @ 16Khz bytes_per_chunk = samples_per_chunk * 2 # 16-bit async def audio_data(): # 1 second in chunks i = 0 while i < len(wake_chunk_1): yield wake_chunk_1[i : i + bytes_per_chunk] i += bytes_per_chunk # 1 second in chunks i = 0 while i < len(wake_chunk_2): yield wake_chunk_2[i : i + bytes_per_chunk] i += bytes_per_chunk for header in (b"wake word!", b"part1", b"part2"): yield make_10ms_chunk(header) yield b"" with patch( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token" ): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), start_stage=assist_pipeline.PipelineStage.WAKE_WORD, wake_word_settings=assist_pipeline.WakeWordSettings( audio_seconds_to_buffer=1.5 ), audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert process_events(events) == snapshot # 1. Half of wake_chunk_1 + all wake_chunk_2 # 2. queued audio (from mock wake word entity) # 3. part1 # 4. 
part2 assert len(mock_stt_provider_entity.received) > 3 first_chunk = bytes( [c_byte for c in mock_stt_provider_entity.received[:-3] for c_byte in c] ) assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2 assert mock_stt_provider_entity.received[-3] == b"queued audio" assert mock_stt_provider_entity.received[-2].startswith(b"part1") assert mock_stt_provider_entity.received[-1].startswith(b"part2") async def test_pipeline_save_audio( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, ) -> None: """Test saving audio during a pipeline run.""" with tempfile.TemporaryDirectory() as temp_dir_str: # Enable audio recording to temporary directory temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, DOMAIN, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) pipeline = assist_pipeline.async_get_pipeline(hass) events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"wake word") # queued audio yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=events.append, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), pipeline_id=pipeline.id, start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.STT, audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) pipeline_dirs = list(temp_dir.iterdir()) # Only one pipeline run # // assert len(pipeline_dirs) == 1 assert pipeline_dirs[0].is_dir() assert pipeline_dirs[0].name == pipeline.name # Wake and stt files run_dirs = list(pipeline_dirs[0].iterdir()) assert 
run_dirs[0].is_dir() run_files = list(run_dirs[0].iterdir()) assert len(run_files) == 2 wake_file = run_files[0] if "wake" in run_files[0].name else run_files[1] stt_file = run_files[0] if "stt" in run_files[0].name else run_files[1] assert wake_file != stt_file # Verify wake file with wave.open(str(wake_file), "rb") as wake_wav: wake_data = wake_wav.readframes(wake_wav.getnframes()) assert wake_data.startswith(b"wake word") # Verify stt file with wave.open(str(stt_file), "rb") as stt_wav: stt_data = stt_wav.readframes(stt_wav.getnframes()) assert stt_data.startswith(b"queued audio") stt_data = stt_data[len(b"queued audio") :] assert stt_data.startswith(b"part1") stt_data = stt_data[BYTES_PER_CHUNK:] assert stt_data.startswith(b"part2") async def test_pipeline_saved_audio_with_device_id( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, ) -> None: """Test that saved audio directory uses device id.""" device_id = "test-device-id" with tempfile.TemporaryDirectory() as temp_dir_str: # Enable audio recording to temporary directory temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, DOMAIN, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): if event.type == "run-end": # Verify that saved audio directory is named after device id device_dirs = list(temp_dir.iterdir()) assert device_dirs[0].name == device_id async def audio_data(): yield b"not used" # Force a timeout during wake word detection with patch.object( mock_wake_word_provider_entity, "async_process_audio_stream", side_effect=assist_pipeline.error.WakeWordTimeoutError( code="timeout", message="timeout" ), ): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=event_callback, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, 
bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.STT, device_id=device_id, ) async def test_pipeline_saved_audio_write_error( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, ) -> None: """Test that saved audio thread closes WAV file even if there's a write error.""" with tempfile.TemporaryDirectory() as temp_dir_str: # Enable audio recording to temporary directory temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, DOMAIN, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): if event.type == "run-end": # Verify WAV file exists, but contains no data pipeline_dirs = list(temp_dir.iterdir()) run_dirs = list(pipeline_dirs[0].iterdir()) wav_path = next(run_dirs[0].iterdir()) with wave.open(str(wav_path), "rb") as wav_file: assert wav_file.getnframes() == 0 async def audio_data(): yield b"not used" # Force a timeout during wake word detection with patch("wave.Wave_write.writeframes", raises=RuntimeError()): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=event_callback, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.STT, ) async def test_pipeline_saved_audio_empty_queue( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: 
SnapshotAssertion, ) -> None: """Test that saved audio thread closes WAV file even if there's an empty queue.""" with tempfile.TemporaryDirectory() as temp_dir_str: # Enable audio recording to temporary directory temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, DOMAIN, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): if event.type == "run-end": # Verify WAV file exists, but contains no data pipeline_dirs = list(temp_dir.iterdir()) run_dirs = list(pipeline_dirs[0].iterdir()) wav_path = next(run_dirs[0].iterdir()) with wave.open(str(wav_path), "rb") as wav_file: assert wav_file.getnframes() == 0 async def audio_data(): # Force timeout in _pipeline_debug_recording_thread_proc await asyncio.sleep(1) yield b"not used" # Wrap original function to time out immediately _pipeline_debug_recording_thread_proc = ( assist_pipeline.pipeline._pipeline_debug_recording_thread_proc ) def proc_wrapper(run_recording_dir, queue): _pipeline_debug_recording_thread_proc( run_recording_dir, queue, message_timeout=0 ) with patch( "homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc", proc_wrapper, ): await assist_pipeline.async_pipeline_from_audio_stream( hass, context=Context(), event_callback=event_callback, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.STT, ) async def test_wake_word_detection_aborted( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream 
with wake word.""" events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"silence!") yield make_10ms_chunk(b"wake word!") yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( conversation_id=None, device_id=None, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output=None, wake_word_settings=assist_pipeline.WakeWordSettings( audio_seconds_to_buffer=1.5 ), audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ), ) await pipeline_input.validate() updates = pipeline.to_json() updates.pop("id") await pipeline_store.async_update_item( pipeline_id, updates, ) await pipeline_input.execute() assert process_events(events) == snapshot def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: """Test that pipeline run equality uses unique id.""" def event_callback(event): pass pipeline = assist_pipeline.pipeline.async_get_pipeline(hass) run_1 = assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.STT, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=event_callback, ) run_2 = assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.STT, 
end_stage=assist_pipeline.PipelineStage.TTS, event_callback=event_callback, ) assert run_1 == run_1 # noqa: PLR0124 assert run_1 != run_2 assert run_1 != 1234 async def test_tts_audio_output( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, snapshot: SnapshotAssertion, ) -> None: """Test using tts_audio_output with wav sets options correctly.""" client = await hass_client() assert await async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", conversation_id=None, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output="wav", ), ) await pipeline_input.validate() # Verify TTS audio settings assert pipeline_input.run.tts_options is not None assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000 assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1 with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio: await pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data media_id = event.data["tts_output"]["media_id"] resolved = await media_source.async_resolve_media(hass, media_id, None) await client.get(resolved.url) # Ensure that no unsupported options were passed in assert mock_get_tts_audio.called options 
= mock_get_tts_audio.call_args_list[0].kwargs["options"] extra_options = set(options).difference(mock_tts_provider.supported_options) assert len(extra_options) == 0, extra_options async def test_tts_wav_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" client = await hass_client() assert await async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", conversation_id=None, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output="wav", ), ) await pipeline_input.validate() # Make the TTS provider support preferred format options supported_options = list(mock_tts_provider.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, tts.ATTR_PREFERRED_SAMPLE_RATE, tts.ATTR_PREFERRED_SAMPLE_CHANNELS, tts.ATTR_PREFERRED_SAMPLE_BYTES, ] ) with ( patch.object(mock_tts_provider, "_supported_options", supported_options), patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, ): await pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data media_id = event.data["tts_output"]["media_id"] resolved = await media_source.async_resolve_media(hass, media_id, None) await client.get(resolved.url) assert mock_get_tts_audio.called 
options = mock_get_tts_audio.call_args_list[0].kwargs["options"] # We should have received preferred format options in get_tts_audio assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 async def test_tts_dict_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" client = await hass_client() assert await async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", conversation_id=None, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output={ tts.ATTR_PREFERRED_FORMAT: "flac", tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, }, ), ) await pipeline_input.validate() # Make the TTS provider support preferred format options supported_options = list(mock_tts_provider.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, tts.ATTR_PREFERRED_SAMPLE_RATE, tts.ATTR_PREFERRED_SAMPLE_CHANNELS, tts.ATTR_PREFERRED_SAMPLE_BYTES, ] ) with ( patch.object(mock_tts_provider, "_supported_options", supported_options), patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, ): await 
pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data media_id = event.data["tts_output"]["media_id"] resolved = await media_source.async_resolve_media(hass, media_id, None) await client.get(resolved.url) assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] # We should have received preferred format options in get_tts_audio assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac" assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 async def test_sentence_trigger_overrides_conversation_agent( hass: HomeAssistant, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that sentence triggers are checked before a non-default conversation agent.""" assert await async_setup_component( hass, "automation", { "automation": { "trigger": { "platform": "conversation", "command": [ "test trigger sentence", ], }, "action": { "set_conversation_response": "test trigger response", }, } }, ) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test trigger sentence", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.INTENT, end_stage=assist_pipeline.PipelineStage.INTENT, event_callback=events.append, intent_agent="test-agent", # not the default agent ), ) # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", return_value=conversation.AgentInfo(id="test-agent", 
name="Test Agent"), ): await pipeline_input.validate() with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" ) as mock_async_converse: await pipeline_input.execute() # Sentence trigger should have been handled mock_async_converse.assert_not_called() # Verify sentence trigger response intent_end_event = next( ( e for e in events if e.type == assist_pipeline.PipelineEventType.INTENT_END ), None, ) assert (intent_end_event is not None) and intent_end_event.data assert ( intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ "speech" ] == "test trigger response" ) async def test_prefer_local_intents( hass: HomeAssistant, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that the default agent is checked first when local intents are preferred.""" events: list[assist_pipeline.PipelineEvent] = [] # Reuse custom sentences in test config class OrderBeerIntentHandler(intent.IntentHandler): intent_type = "OrderBeer" async def async_handle( self, intent_obj: intent.Intent ) -> intent.IntentResponse: response = intent_obj.create_response() response.async_set_speech("Order confirmed") return response handler = OrderBeerIntentHandler() intent.async_register(hass, handler) # Fake a test agent and prefer local intents pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) await assist_pipeline.pipeline.async_update_pipeline( hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True ) pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="I'd like to order a stout please", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.INTENT, end_stage=assist_pipeline.PipelineStage.INTENT, 
event_callback=events.append, ), ) # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), ): await pipeline_input.validate() with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" ) as mock_async_converse: await pipeline_input.execute() # Test agent should not have been called mock_async_converse.assert_not_called() # Verify local intent response intent_end_event = next( ( e for e in events if e.type == assist_pipeline.PipelineEventType.INTENT_END ), None, ) assert (intent_end_event is not None) and intent_end_event.data assert ( intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ "speech" ] == "Order confirmed" ) async def test_stt_language_used_instead_of_conversation_language( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, snapshot: SnapshotAssertion, ) -> None: """Test that the STT language is used first when the conversation language is '*' (all languages).""" client = await hass_ws_client(hass) events: list[assist_pipeline.PipelineEvent] = [] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "homeassistant", "conversation_language": MATCH_ALL, "language": "en", "name": "test_name", "stt_engine": "test", "stt_language": "en-US", "tts_engine": "test", "tts_language": "en-US", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": None, "wake_word_id": None, } ) msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.INTENT, end_stage=assist_pipeline.PipelineStage.INTENT, 
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # STT language (en-US) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.stt_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.stt_language
    )


async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # No STT language configured, so the TTS language is next in precedence
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": "en-us",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # TTS language (en-us) should be used instead of '*'
    # (comment previously said "STT language" — the assertion checks tts_language)
    assert intent_start.data.get("language") == pipeline.tts_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.tts_language
    )


async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # Neither STT nor TTS language configured, so the pipeline language wins
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # Pipeline language should be used instead of '*'
    # (comment previously said "STT language" — the assertion checks pipeline.language)
    assert intent_start.data.get("language") == pipeline.language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.language
    )


async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
    hass: HomeAssistant,
    mock_stt_provider_entity: MockSTTProviderEntity,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test creating a pipeline from an audio stream but the cloud authentication fails."""
    events: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        # Single dummy chunk — the patched STT raises before audio matters
        yield b"audio"

    # Force the STT stage to fail with a cloud authentication error
    with patch.object(
        mock_stt_provider_entity,
        "async_process_audio_stream",
        side_effect=hass_nabucasa.auth.Unauthenticated,
    ):
        await assist_pipeline.async_pipeline_from_audio_stream(
            hass,
            context=Context(),
            event_callback=events.append,
            stt_metadata=stt.SpeechMetadata(
                language="",
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            ),
            stt_stream=audio_data(),
            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
        )

    assert process_events(events) == snapshot
    assert len(events) == 4  # run start, stt start, error, run end
    assert events[2].type == assist_pipeline.PipelineEventType.ERROR
    assert events[2].data["code"] == "cloud-auth-failed"