Generate and keep conversation id for Wyoming satellite (#118835)

This commit is contained in:
Michael Hansen 2024-06-21 06:24:53 -05:00 committed by GitHub
parent 955685e116
commit 18767154df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 0 deletions

View File

@ -3,7 +3,9 @@
import asyncio
import io
import logging
import time
from typing import Final
from uuid import uuid4
import wave
from typing_extensions import AsyncGenerator
@ -38,6 +40,7 @@ _RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -73,6 +76,9 @@ class WyomingSatellite:
self._pipeline_id: str | None = None
self._muted_changed_event = asyncio.Event()
self._conversation_id: str | None = None
self._conversation_id_time: float | None = None
self.device.set_is_muted_listener(self._muted_changed)
self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed)
@ -365,6 +371,19 @@ class WyomingSatellite:
start_stage,
end_stage,
)
# Reset conversation id, if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = str(uuid4())
# Update timeout
self._conversation_id_time = time.monotonic()
self._is_pipeline_running = True
self._pipeline_ended_event.clear()
self.config_entry.async_create_background_task(
@ -393,6 +412,7 @@ class WyomingSatellite:
),
device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
conversation_id=self._conversation_id,
),
name="wyoming satellite pipeline",
)

View File

@ -1285,3 +1285,104 @@ async def test_timers(hass: HomeAssistant) -> None:
timer_finished = mock_client.timer_finished
assert timer_finished is not None
assert timer_finished.id == timer_started.id
async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
"""Test that the same conversation id is used until timeout."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE,
end_stage=PipelineStage.TTS,
restart_on_end=True,
).event(),
]
pipeline_kwargs: dict[str, Any] = {}
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
None
)
run_pipeline_called = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_kwargs, pipeline_event_callback
pipeline_kwargs = kwargs
pipeline_event_callback = event_callback
run_pipeline_called.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# A conversation id should have been generated
conversation_id = pipeline_kwargs.get("conversation_id")
assert conversation_id
# Reset and run again
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be the same conversation id
assert pipeline_kwargs.get("conversation_id") == conversation_id
# Reset and run again, but this time "time out"
satellite._conversation_id_time = None
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be a different conversation id
new_conversation_id = pipeline_kwargs.get("conversation_id")
assert new_conversation_id
assert new_conversation_id != conversation_id