"""Websocket tests for Voice Assistant integration.""" from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import ANY, AsyncMock, Mock, patch from freezegun import freeze_time from hassil.recognize import Intent, IntentData, RecognizeResult import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import ( assist_pipeline, conversation, media_player, media_source, stt, tts, ) from homeassistant.components.assist_pipeline.const import ACKNOWLEDGE_PATH, DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, STORAGE_VERSION_MINOR, Pipeline, PipelineData, PipelineEventType, PipelineStorageCollection, PipelineStore, _async_local_fallback_intent_filter, async_create_default_pipeline, async_get_pipeline, async_get_pipelines, async_update_pipeline, ) from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import ( area_registry as ar, chat_session, device_registry as dr, entity_registry as er, intent, llm, ) from homeassistant.setup import async_setup_component from . 
import MANY_LANGUAGES, process_events from .conftest import ( MockSTTProvider, MockSTTProviderEntity, MockTTSEntity, MockTTSProvider, MockWakeWordEntity, make_10ms_chunk, ) from tests.common import MockConfigEntry, async_mock_service, flush_store from tests.typing import ClientSessionGenerator, WebSocketGenerator @pytest.fixture(autouse=True) async def delay_save_fixture() -> AsyncGenerator[None]: """Load the homeassistant integration.""" with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0): yield @pytest.fixture(autouse=True) async def load_homeassistant(hass: HomeAssistant) -> None: """Load the homeassistant integration.""" assert await async_setup_component(hass, "homeassistant", {}) @pytest.fixture async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None: """Disable the TTS entity.""" mock_tts_entity._attr_entity_registry_enabled_default = False @pytest.mark.usefixtures("init_components") async def test_load_pipelines(hass: HomeAssistant) -> None: """Make sure that we can load/save data correctly.""" pipelines = [ { "conversation_engine": "conversation_engine_1", "conversation_language": "language_1", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": "wakeword_entity_1", "wake_word_id": "wakeword_id_1", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_1", "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", "wake_word_entity": "wakeword_entity_2", "wake_word_id": "wakeword_id_2", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, 
"tts_language": None, "tts_voice": None, "wake_word_entity": "wakeword_entity_3", "wake_word_id": "wakeword_id_3", }, ] pipeline_data: PipelineData = hass.data[DOMAIN] store1 = pipeline_data.pipeline_store pipeline_ids = [ (await store1.async_create_item(pipeline)).id for pipeline in pipelines ] assert len(store1.data) == 4 # 3 manually created plus a default pipeline assert store1.async_get_preferred_item() == list(store1.data)[0] await store1.async_delete_item(pipeline_ids[1]) assert len(store1.data) == 3 store2 = PipelineStorageCollection( PipelineStore( hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR ) ) await flush_store(store1.store) await store2.async_load() assert len(store2.data) == 3 assert store1.data is not store2.data assert store1.data == store2.data assert store1.async_get_preferred_item() == store2.async_get_preferred_item() @pytest.fixture(autouse=True) def mock_chat_session_id() -> Generator[Mock]: """Mock the conversation ID of chat sessions.""" with patch( "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid" ) as mock_ulid_now: yield mock_ulid_now @pytest.fixture(autouse=True) def mock_tts_token() -> Generator[None]: """Mock the TTS token for URLs.""" with patch("secrets.token_urlsafe", return_value="mocked-token"): yield async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines on start.""" id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY" hass_storage[STORAGE_KEY] = { "version": STORAGE_VERSION, "minor_version": STORAGE_VERSION_MINOR, "key": "assist_pipeline.pipelines", "data": { "items": [ { "conversation_engine": conversation.HOME_ASSISTANT_AGENT, "conversation_language": "language_1", "id": id_1, "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": 
"wakeword_entity_1", "wake_word_id": "wakeword_id_1", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_2", "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", "wake_word_entity": "wakeword_entity_2", "wake_word_id": "wakeword_id_2", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, "tts_language": None, "tts_voice": None, "wake_word_entity": "wakeword_entity_3", "wake_word_id": "wakeword_id_3", }, ], "preferred_item": id_1, }, } assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == id_1 assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT async def test_migrate_pipeline_store( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines from an older version.""" hass_storage[STORAGE_KEY] = { "version": 1, "minor_version": 1, "key": "assist_pipeline.pipelines", "data": { "items": [ { "conversation_engine": "conversation_engine_1", "conversation_language": "language_1", "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_2", "tts_engine": 
"tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, "tts_language": None, "tts_voice": None, }, ], "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", }, } assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" @pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("disable_tts_entity") async def test_create_default_pipeline(hass: HomeAssistant) -> None: """Test async_create_default_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 assert ( await async_create_default_pipeline( hass, stt_engine_id="bla", tts_engine_id="bla", pipeline_name="Bla pipeline", ) is None ) assert await async_create_default_pipeline( hass, stt_engine_id="test", tts_engine_id="test", pipeline_name="Test pipeline", ) == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Test pipeline", stt_engine="test", stt_language="en-US", tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", wake_word_entity=None, wake_word_id=None, ) async def test_get_pipeline(hass: HomeAssistant) -> None: """Test async_get_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Test we get the preferred pipeline if none is specified pipeline = async_get_pipeline(hass, None) assert pipeline.id == 
store.async_get_preferred_item() # Test getting a specific pipeline assert pipeline is async_get_pipeline(hass, pipeline.id) async def test_get_pipelines(hass: HomeAssistant) -> None: """Test async_get_pipelines.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 pipelines = async_get_pipelines(hass) assert list(pipelines) == [ Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) ] @pytest.mark.parametrize( ("ha_language", "ha_country", "conv_language", "pipeline_language"), [ ("en", None, "en", "en"), ("de", "de", "de", "de"), ("de", "ch", "de-CH", "de"), ("en", "us", "en", "en"), ("en", "uk", "en", "en"), ("pt", "pt", "pt", "pt"), ("pt", "br", "pt-BR", "pt"), ], ) async def test_default_pipeline_no_stt_tts( hass: HomeAssistant, ha_language: str, ha_country: str | None, conv_language: str, pipeline_language: str, ) -> None: """Test async_get_pipeline.""" hass.config.country = ha_country hass.config.language = ha_language assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) @pytest.mark.parametrize( ( "ha_language", "ha_country", "conv_language", "pipeline_language", "stt_language", "tts_language", ), [ ("en", None, "en", "en", 
"en", "en"), ("de", "de", "de", "de", "de", "de"), ("de", "ch", "de-CH", "de", "de-CH", "de-CH"), ("en", "us", "en", "en", "en", "en"), ("en", "uk", "en", "en", "en", "en"), ("pt", "pt", "pt", "pt", "pt", "pt"), ("pt", "br", "pt-BR", "pt", "pt-br", "pt-br"), ], ) @pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity, mock_tts_provider: MockTTSProvider, ha_language: str, ha_country: str | None, conv_language: str, pipeline_language: str, stt_language: str, tts_language: str, ) -> None: """Test async_get_pipeline.""" hass.config.country = ha_country hass.config.language = ha_language with ( patch.object(mock_stt_provider_entity, "_supported_languages", MANY_LANGUAGES), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES), ): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, name="Home Assistant", stt_engine="stt.mock_stt", stt_language=stt_language, tts_engine="test", tts_language=tts_language, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) @pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline_unsupported_stt_language( hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity ) -> None: """Test async_get_pipeline.""" with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store 
assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", wake_word_entity=None, wake_word_id=None, ) @pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("disable_tts_entity") async def test_default_pipeline_unsupported_tts_language( hass: HomeAssistant, mock_tts_provider: MockTTSProvider ) -> None: """Test async_get_pipeline.""" with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", name="Home Assistant", stt_engine="stt.mock_stt", stt_language="en-US", tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) async def test_update_pipeline( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test async_update_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) assert pipelines == [ Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) ] pipeline = pipelines[0] await async_update_pipeline( hass, pipeline, conversation_engine="homeassistant_1", conversation_language="de", language="de", 
name="Home Assistant 1", stt_engine="stt.test_1", stt_language="de", tts_engine="test_1", tts_language="de", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) pipeline = pipelines[0] assert pipelines == [ Pipeline( conversation_engine="homeassistant_1", conversation_language="de", id=pipeline.id, language="de", name="Home Assistant 1", stt_engine="stt.test_1", stt_language="de", tts_engine="test_1", tts_language="de", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) ] assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { "conversation_engine": "homeassistant_1", "conversation_language": "de", "id": pipeline.id, "language": "de", "name": "Home Assistant 1", "stt_engine": "stt.test_1", "stt_language": "de", "tts_engine": "test_1", "tts_language": "de", "tts_voice": "test_voice", "wake_word_entity": "wake_work.test_1", "wake_word_id": "wake_word_id_1", "prefer_local_intents": False, } await async_update_pipeline( hass, pipeline, stt_engine="stt.test_2", stt_language="en", tts_engine="test_2", tts_language="en", ) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) assert pipelines == [ Pipeline( conversation_engine="homeassistant_1", conversation_language="de", id=pipeline.id, language="de", name="Home Assistant 1", stt_engine="stt.test_2", stt_language="en", tts_engine="test_2", tts_language="en", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) ] assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { "conversation_engine": "homeassistant_1", "conversation_language": "de", "id": pipeline.id, "language": "de", "name": "Home Assistant 1", "stt_engine": "stt.test_2", "stt_language": "en", "tts_engine": "test_2", "tts_language": "en", 
"tts_voice": "test_voice", "wake_word_entity": "wake_work.test_1", "wake_word_id": "wake_word_id_1", "prefer_local_intents": False, } def test_fallback_intent_filter() -> None: """Test that we filter the right things.""" assert ( _async_local_fallback_intent_filter( RecognizeResult( intent=Intent(intent.INTENT_GET_STATE), intent_data=IntentData([]), entities={}, entities_list=[], ) ) is True ) assert ( _async_local_fallback_intent_filter( RecognizeResult( intent=Intent(media_player.INTENT_MEDIA_SEARCH_AND_PLAY), intent_data=IntentData([]), entities={}, entities_list=[], ) ) is True ) assert ( _async_local_fallback_intent_filter( RecognizeResult( intent=Intent(intent.INTENT_NEVERMIND), intent_data=IntentData([]), entities={}, entities_list=[], ) ) is False ) assert ( _async_local_fallback_intent_filter( RecognizeResult( intent=Intent(intent.INTENT_TURN_ON), intent_data=IntentData([]), entities={}, entities_list=[], ) ) is False ) async def test_wake_word_detection_aborted( hass: HomeAssistant, mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test wake word stream is first detected, then aborted.""" events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): yield make_10ms_chunk(b"silence!") yield make_10ms_chunk(b"wake word!") yield make_10ms_chunk(b"part1") yield make_10ms_chunk(b"part2") yield b"" pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( session=mock_chat_session, device_id=None, stt_metadata=stt.SpeechMetadata( language="", format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, 
channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output=None, wake_word_settings=assist_pipeline.WakeWordSettings( audio_seconds_to_buffer=1.5 ), audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ), ) await pipeline_input.validate() updates = pipeline.to_json() updates.pop("id") await pipeline_store.async_update_item( pipeline_id, updates, ) await pipeline_input.execute() assert process_events(events) == snapshot def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: """Test that pipeline run equality uses unique id.""" def event_callback(event): pass pipeline = assist_pipeline.pipeline.async_get_pipeline(hass) run_1 = assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.STT, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=event_callback, ) run_2 = assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.STT, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=event_callback, ) assert run_1 == run_1 # noqa: PLR0124 assert run_1 != run_2 assert run_1 != 1234 async def test_tts_audio_output( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test using tts_audio_output with wav sets options correctly.""" client = await hass_client() assert await async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = 
pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output="wav", ), ) await pipeline_input.validate() # Verify TTS audio settings assert pipeline_input.run.tts_stream.options is not None assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" assert ( pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000 ) assert ( pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1 ) with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio: await pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data await client.get(event.data["tts_output"]["url"]) # Ensure that no unsupported options were passed in assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] extra_options = set(options).difference(mock_tts_entity.supported_options) assert len(extra_options) == 0, extra_options async def test_tts_wav_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MockTTSEntity, init_components, mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" client = await hass_client() assert await async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = 
pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output="wav", ), ) await pipeline_input.validate() # Make the TTS provider support preferred format options supported_options = list(mock_tts_entity.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, tts.ATTR_PREFERRED_SAMPLE_RATE, tts.ATTR_PREFERRED_SAMPLE_CHANNELS, tts.ATTR_PREFERRED_SAMPLE_BYTES, ] ) with ( patch.object(mock_tts_entity, "_supported_options", supported_options), patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio, ): await pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data await client.get(event.data["tts_output"]["url"]) assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] # We should have received preferred format options in get_tts_audio assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 async def test_tts_dict_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MockTTSEntity, init_components, mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" client = await hass_client() assert await 
async_setup_component(hass, media_source.DOMAIN, {}) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS, event_callback=events.append, tts_audio_output={ tts.ATTR_PREFERRED_FORMAT: "flac", tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, }, ), ) await pipeline_input.validate() # Make the TTS provider support preferred format options supported_options = list(mock_tts_entity.supported_options or []) supported_options.extend( [ tts.ATTR_PREFERRED_FORMAT, tts.ATTR_PREFERRED_SAMPLE_RATE, tts.ATTR_PREFERRED_SAMPLE_CHANNELS, tts.ATTR_PREFERRED_SAMPLE_BYTES, ] ) with ( patch.object(mock_tts_entity, "_supported_options", supported_options), patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio, ): await pipeline_input.execute() for event in events: if event.type == assist_pipeline.PipelineEventType.TTS_END: # We must fetch the media URL to trigger the TTS assert event.data await client.get(event.data["tts_output"]["url"]) assert mock_get_tts_audio.called options = mock_get_tts_audio.call_args_list[0].kwargs["options"] # We should have received preferred format options in get_tts_audio assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac" assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2 assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 async def test_sentence_trigger_overrides_conversation_agent( hass: HomeAssistant, 
init_components, mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that sentence triggers are checked before a non-default conversation agent.""" assert await async_setup_component( hass, "automation", { "automation": { "trigger": { "platform": "conversation", "command": [ "test trigger sentence", ], }, "action": { "set_conversation_response": "test trigger response", }, } }, ) events: list[assist_pipeline.PipelineEvent] = [] pipeline_store = pipeline_data.pipeline_store pipeline_id = pipeline_store.async_get_preferred_item() pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test trigger sentence", session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), pipeline=pipeline, start_stage=assist_pipeline.PipelineStage.INTENT, end_stage=assist_pipeline.PipelineStage.INTENT, event_callback=events.append, intent_agent="test-agent", # not the default agent ), ) # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", return_value=conversation.AgentInfo( id="test-agent", name="Test Agent", supports_streaming=False, ), ): await pipeline_input.validate() with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse" ) as mock_async_converse: await pipeline_input.execute() # Sentence trigger should have been handled mock_async_converse.assert_not_called() # Verify sentence trigger response intent_end_event = next( ( e for e in events if e.type == assist_pipeline.PipelineEventType.INTENT_END ), None, ) assert (intent_end_event is not None) and intent_end_event.data assert intent_end_event.data["processed_locally"] is True assert ( intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ "speech" ] == "test trigger response" ) async def test_prefer_local_intents( hass: 
HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that the default agent is checked first when local intents are preferred."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Reuse custom sentences in test config
    class OrderBeerIntentHandler(intent.IntentHandler):
        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            response = intent_obj.create_response()
            response.async_set_speech("Order confirmed")
            return response

    handler = OrderBeerIntentHandler()
    intent.async_register(hass, handler)

    # Fake a test agent and prefer local intents
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
    )
    # Re-fetch: async_update_pipeline stores a new (frozen) Pipeline object.
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="I'd like to order a stout please",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        ),
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # Test agent should not have been called
        mock_async_converse.assert_not_called()

    # Verify local intent response
    intent_end_event = next(
        (
            e
            for e in events
            if e.type == assist_pipeline.PipelineEventType.INTENT_END
        ),
        None,
    )
    assert (intent_end_event is not None) and intent_end_event.data
    assert intent_end_event.data["processed_locally"] is True
    assert (
        intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
            "speech"
        ]
        == "Order confirmed"
    )


async def test_intent_continue_conversation(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a conversation agent flagging continue conversation gets response."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Point the preferred pipeline at a fake test agent
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        ),
    ):
        await pipeline_input.validate()

    # The agent answers with a follow-up question and flags the conversation
    # as continuing.
    response = intent.IntentResponse("en")
    response.async_set_speech("For how long?")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
            continue_conversation=True,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()
        mock_async_converse.assert_called()

    results = [
        event.data
        for event in events
        if event.type
        in (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
    ]
    assert results[1]["intent_output"]["continue_conversation"] is True

    # Change conversation agent to default one and register sentence trigger that should not be called
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine=None
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": ["Hello"],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    # Because we did continue conversation, it should respond to the test agent again.
    events.clear()
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Hello",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        ),
    ) as mock_prepare:
        await pipeline_input.validate()

    # It requested test agent even if that was not default agent.
    assert mock_prepare.mock_calls[0][1][1] == "test-agent"

    response = intent.IntentResponse("en")
    response.async_set_speech("Timer set for 20 minutes")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()
        mock_async_converse.assert_called()

    # Snapshot will show it was still handled by the test agent and not default agent
    results = [
        event.data
        for event in events
        if event.type
        in (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
    ]
    assert results[0]["engine"] == "test-agent"
    assert results[1]["intent_output"]["continue_conversation"] is False


async def test_stt_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the STT language is used first when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": "test",
            "stt_language": "en-US",
            "tts_engine": "test",
            "tts_language": "en-US",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # STT language (en-US) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.stt_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.stt_language
    )


async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # No STT engine configured, so TTS language is next in precedence.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": "en-us",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # TTS language (en-us) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.tts_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.tts_language
    )


async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # Neither STT nor TTS configured, so the pipeline language is the fallback.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # Pipeline language (en) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.language
    )


@pytest.mark.parametrize(
    ("to_stream_deltas", "expected_chunks", "chunk_text"),
    [
        # Size below STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello,",
                    " ",
                    "how",
                    " ",
                    "are",
                    " ",
                    "you",
                    "?",
                ],
            ),
            # We always stream when possible, so 1 chunk via streaming method
            1,
            "hello, how are you?",
        ),
        # Size above STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ". ",
                    "What ",
                    "about ",
                    "you",
                    "?",
                    "!",
                ],
            ),
            # We are streamed. First 15 chunks are grouped into 1 chunk
            # and the rest are streamed
            3,
            "hello, how are you? I'm doing well, thank you. What about you?!",
        ),
        # Stream a bit, then a tool call, then stream some more
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                ],
                {
                    "tool_calls": [
                        llm.ToolInput(
                            tool_name="test_tool",
                            tool_args={},
                            id="test_tool_id",
                        )
                    ],
                },
                [
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ".",
                ],
            ),
            # 1 chunk before tool call, then 7 after
            8,
            "hello, how are you? I'm doing well, thank you.",
        ),
    ],
)
@freeze_time("2025-10-31 12:00:00")
async def test_chat_log_tts_streaming(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
    mock_tts_entity: MockTTSEntity,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_deltas: tuple[dict | list[str]],
    expected_chunks: int,
    chunk_text: str,
) -> None:
    """Test that chat log events are streamed to the TTS entity."""
    # Flatten only the text deltas (dict entries are tool calls, not text).
    text_deltas = [
        delta
        for deltas in to_stream_deltas
        if isinstance(deltas, list)
        for delta in deltas
    ]

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
        ),
    )

    # Collects the message chunks that reach the TTS entity's streaming input.
    received_tts = []

    async def async_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Mock stream TTS audio."""

        async def gen_data():
            async for msg in request.message_gen:
                received_tts.append(msg)
                yield msg.encode()

        return tts.TTSAudioResponse(
            extension="mp3",
            data_gen=gen_data(),
        )

    async def async_get_tts_audio(
        message: str,
        language: str,
        options: dict[str, Any] | None = None,
    ) -> tts.TtsAudioType:
        """Mock get TTS audio."""
        return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))

    mock_tts_entity.async_get_tts_audio = async_get_tts_audio
    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=True,
        ),
    ):
        await pipeline_input.validate()

    async def mock_converse(
        hass: HomeAssistant,
        text: str,
        conversation_id: str | None,
        context: Context,
        language: str | None = None,
        agent_id: str | None = None,
        device_id: str | None = None,
        satellite_id: str | None = None,
        extra_system_prompt: str | None = None,
    ):
        """Mock converse."""
        conversation_input = conversation.ConversationInput(
            text=text,
            context=context,
            conversation_id=conversation_id,
            device_id=device_id,
            satellite_id=satellite_id,
            language=language,
            agent_id=agent_id,
            extra_system_prompt=extra_system_prompt,
        )

        async def stream_llm_response():
            # Dict entries pass through as-is (tool calls); each list of text
            # chunks is prefixed with an assistant role delta.
            for deltas in to_stream_deltas:
                if isinstance(deltas, dict):
                    yield deltas
                else:
                    yield {"role": "assistant"}
                    for chunk in deltas:
                        yield {"content": chunk}

        with (
            chat_session.async_get_chat_session(hass, conversation_id) as session,
            conversation.async_get_chat_log(
                hass,
                session,
                conversation_input,
            ) as chat_log,
        ):
            await chat_log.async_provide_llm_data(
                conversation_input.as_llm_context("test"),
                user_llm_hass_api="assist",
                user_llm_prompt=None,
                user_extra_system_prompt=conversation_input.extra_system_prompt,
            )
            # Drain the delta stream so the chat log forwards it to TTS.
            async for _content in chat_log.async_add_delta_content_stream(
                agent_id, stream_llm_response()
            ):
                pass

            intent_response = intent.IntentResponse(language)
            intent_response.async_set_speech("".join(to_stream_deltas[-1]))
            return conversation.ConversationResult(
                response=intent_response,
                conversation_id=chat_log.conversation_id,
                continue_conversation=chat_log.continue_conversation,
            )

    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema({})
    mock_tool.async_call.return_value = "Test response"

    with (
        patch(
            "homeassistant.helpers.llm.AssistAPI._async_get_tools",
            return_value=[mock_tool],
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
            mock_converse,
        ),
    ):
        await pipeline_input.execute()

    # The full TTS result must contain all streamed text.
    stream = tts.async_get_stream(hass, events[0].data["tts_output"]["token"])
    assert stream is not None
    tts_result = "".join(
        [chunk.decode() async for chunk in stream.async_stream_result()]
    )

    streamed_text = "".join(text_deltas)
    assert tts_result == streamed_text
    assert len(received_tts) == expected_chunks
    assert "".join(received_tts) == chunk_text

    assert process_events(events) == snapshot


@pytest.mark.parametrize(("use_satellite_entity"), [True, False])
async def test_acknowledge(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    entity_registry: er.EntityRegistry,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    use_satellite_entity: bool,
) -> None:
    """Test that acknowledge sound is played when targets are in the same area."""
    area_1 = area_registry.async_get_or_create("area_1")

    light_1 = entity_registry.async_get_or_create("light", "demo", "1234")
    hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"})
    light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id)

    light_2 = entity_registry.async_get_or_create("light", "demo", "5678")
    hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"})
    light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id)

    # Satellite is represented both as an entity and as a device; the
    # parametrized flag selects which one is passed to the pipeline.
    entry = MockConfigEntry()
    entry.add_to_hass(hass)
    satellite = entity_registry.async_get_or_create("assist_satellite", "test", "1234")
    entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
    satellite_device = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(satellite_device.id, area_id=area_1.id)

    events: list[assist_pipeline.PipelineEvent] = []
    turn_on = async_mock_service(hass, "light", "turn_on")

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    async def _run(text: str) -> None:
        # Run one INTENT->TTS pipeline for the given input text.
        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input=text,
            session=mock_chat_session,
            satellite_id=satellite.entity_id if use_satellite_entity else None,
            device_id=satellite_device.id if not use_satellite_entity else None,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )
        await pipeline_input.validate()
        await pipeline_input.execute()

    with patch(
        "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech"
    ) as text_to_speech:

        def _reset() -> None:
            events.clear()
            text_to_speech.reset_mock()
            turn_on.clear()

        # 1. All targets in same area
        await _run("turn on the lights")

        # Acknowledgment sound should be played (same area)
        text_to_speech.assert_called_once()
        assert (
            text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH
        )
        assert len(turn_on) == 2

        # 2. One light in a different area
        area_2 = area_registry.async_get_or_create("area_2")
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=area_2.id
        )

        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should be not played (different area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # Restore
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=area_1.id
        )

        # 3. Remove satellite device area
        entity_registry.async_update_entity(satellite.entity_id, area_id=None)
        device_registry.async_update_device(satellite_device.id, area_id=None)
        _reset()
        await _run("turn on light 1")

        # Acknowledgment sound should be not played (no satellite area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # Restore
        entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
        device_registry.async_update_device(satellite_device.id, area_id=area_1.id)

        # 4. Check device area instead of entity area
        light_device = device_registry.async_get_or_create(
            config_entry_id=entry.entry_id,
            connections=set(),
            identifiers={("demo", "id-5678")},
        )
        device_registry.async_update_device(light_device.id, area_id=area_1.id)
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=None, device_id=light_device.id
        )
        _reset()
        await _run("turn on the lights")

        # Acknowledgment sound should be played (same area)
        text_to_speech.assert_called_once()
        assert (
            text_to_speech.call_args.kwargs["override_media_path"] == ACKNOWLEDGE_PATH
        )
        assert len(turn_on) == 2

        # 5. Move device to different area
        device_registry.async_update_device(light_device.id, area_id=area_2.id)
        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should be not played (different device area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # 6. No device or area
        light_2 = entity_registry.async_update_entity(
            light_2.entity_id, area_id=None, device_id=None
        )
        _reset()
        await _run("turn on light 2")

        # Acknowledgment sound should be not played (no area)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

        # 7. Not in entity registry
        hass.states.async_set("light.light_3", "off", {ATTR_FRIENDLY_NAME: "light 3"})
        _reset()
        await _run("turn on light 3")

        # Acknowledgment sound should be not played (not in entity registry)
        text_to_speech.assert_called_once()
        assert text_to_speech.call_args.kwargs.get("override_media_path") is None
        assert len(turn_on) == 1

    # Check TTS event
    events.clear()
    await _run("turn on light 1")
    has_acknowledge_override: bool | None = None
    for event in events:
        if event.type == PipelineEventType.TTS_START:
            assert event.data
            has_acknowledge_override = event.data["acknowledge_override"]
            break

    assert has_acknowledge_override


async def test_acknowledge_other_agents(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    entity_registry: er.EntityRegistry,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
) -> None:
    """Test that acknowledge sound is only played when intents are processed locally for other agents."""
    area_1 = area_registry.async_get_or_create("area_1")

    light_1 = entity_registry.async_get_or_create("light", "demo", "1234")
    hass.states.async_set(light_1.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 1"})
    light_1 = entity_registry.async_update_entity(light_1.entity_id, area_id=area_1.id)

    light_2 = entity_registry.async_get_or_create("light", "demo", "5678")
    hass.states.async_set(light_2.entity_id, "off", {ATTR_FRIENDLY_NAME: "light 2"})
    light_2 = entity_registry.async_update_entity(light_2.entity_id, area_id=area_1.id)

    entry = MockConfigEntry()
    entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(satellite.id, area_id=area_1.id)

    events: list[assist_pipeline.PipelineEvent] = []
    async_mock_service(hass, "light", "turn_on")

    # Pipeline uses a non-default agent but prefers local intents.
    pipeline_store = pipeline_data.pipeline_store
    pipeline = await pipeline_store.async_create_item(
        {
            "name": "Test 1",
            "language": "en-US",
            "conversation_engine": "test agent",
            "conversation_language": "en-US",
            "tts_engine": "test tts",
            "tts_language": "en-US",
            "tts_voice": "test voice",
            "stt_engine": "test stt",
            "stt_language": "en-US",
            "wake_word_entity": None,
            "wake_word_id": None,
            "prefer_local_intents": True,
        }
    )

    with (
        patch(
            "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
            return_value=conversation.AgentInfo(
                id="test-agent",
                name="Test Agent",
                supports_streaming=False,
            ),
        ),
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun.prepare_text_to_speech"
        ),
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun.text_to_speech"
        ) as text_to_speech,
        patch(
            "homeassistant.components.conversation.async_converse", return_value=None
        ) as async_converse,
        patch(
            "homeassistant.components.assist_pipeline.PipelineRun._get_all_targets_in_satellite_area"
        ) as get_all_targets_in_satellite_area,
    ):
        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input="turn on the lights",
            session=mock_chat_session,
            device_id=satellite.id,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )
        await pipeline_input.validate()
        await pipeline_input.execute()

        # Processed locally
        async_converse.assert_not_called()

        # Not processed locally
        text_to_speech.reset_mock()
        get_all_targets_in_satellite_area.reset_mock()

        pipeline_input = assist_pipeline.pipeline.PipelineInput(
            intent_input="not processed locally",
            session=mock_chat_session,
            device_id=satellite.id,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.TTS,
                event_callback=events.append,
            ),
        )
        await pipeline_input.validate()
        await pipeline_input.execute()

        # The acknowledgment should not have even been checked for because the
        # default agent didn't handle the intent.
        text_to_speech.assert_not_called()
        async_converse.assert_called_once()
        get_all_targets_in_satellite_area.assert_not_called()