From 18767154df5acc9d86fb0f6c078fffe92d1dcf07 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 21 Jun 2024 06:24:53 -0500 Subject: [PATCH] Generate and keep conversation id for Wyoming satellite (#118835) --- homeassistant/components/wyoming/satellite.py | 20 ++++ tests/components/wyoming/test_satellite.py | 101 ++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index 41ca2887d88..5af0c54abad 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -3,7 +3,9 @@ import asyncio import io import logging +import time from typing import Final +from uuid import uuid4 import wave from typing_extensions import AsyncGenerator @@ -38,6 +40,7 @@ _RESTART_SECONDS: Final = 3 _PING_TIMEOUT: Final = 5 _PING_SEND_DELAY: Final = 2 _PIPELINE_FINISH_TIMEOUT: Final = 1 +_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes # Wyoming stage -> Assist stage _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { @@ -73,6 +76,9 @@ class WyomingSatellite: self._pipeline_id: str | None = None self._muted_changed_event = asyncio.Event() + self._conversation_id: str | None = None + self._conversation_id_time: float | None = None + self.device.set_is_muted_listener(self._muted_changed) self.device.set_pipeline_listener(self._pipeline_changed) self.device.set_audio_settings_listener(self._audio_settings_changed) @@ -365,6 +371,19 @@ class WyomingSatellite: start_stage, end_stage, ) + + # Reset conversation id, if necessary + if (self._conversation_id_time is None) or ( + (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC + ): + self._conversation_id = None + + if self._conversation_id is None: + self._conversation_id = str(uuid4()) + + # Update timeout + self._conversation_id_time = time.monotonic() + self._is_pipeline_running = True self._pipeline_ended_event.clear() self.config_entry.async_create_background_task( @@ -393,6 +412,7 @@ class WyomingSatellite: ), device_id=self.device.device_id, wake_word_phrase=wake_word_phrase, + conversation_id=self._conversation_id, ), name="wyoming satellite pipeline", ) diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 4d39607158e..1a291153ad0 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -1285,3 +1285,104 @@ async def test_timers(hass: HomeAssistant) -> None: timer_finished = mock_client.timer_finished assert timer_finished is not None assert timer_finished.id == timer_started.id + + +async def test_satellite_conversation_id(hass: HomeAssistant) -> None: + """Test that the same conversation id is used until timeout.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, + end_stage=PipelineStage.TTS, + restart_on_end=True, + ).event(), + ] + + pipeline_kwargs: dict[str, Any] = {} + pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = ( + None + ) + run_pipeline_called = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_kwargs, pipeline_event_callback + pipeline_kwargs = kwargs + pipeline_event_callback = event_callback + + run_pipeline_called.set() + + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + return_value=("wav", get_test_wav()), + ), + patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0), + ): + entry = await setup_config_entry(hass) + satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][ + entry.entry_id + ].satellite + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + assert pipeline_event_callback is not None + + # A conversation id should have been generated + conversation_id = pipeline_kwargs.get("conversation_id") + assert conversation_id + + # Reset and run again + run_pipeline_called.clear() + pipeline_kwargs.clear() + + pipeline_event_callback( + assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) + ) + + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + # Should be the same conversation id + assert pipeline_kwargs.get("conversation_id") == conversation_id + + # Reset and run again, but this time "time out" + satellite._conversation_id_time = None + run_pipeline_called.clear() + pipeline_kwargs.clear() + + pipeline_event_callback( + assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) + ) + + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + # Should be a different conversation id + new_conversation_id = pipeline_kwargs.get("conversation_id") + assert new_conversation_id + assert new_conversation_id != conversation_id