Generate and keep conversation id for Wyoming satellite (#118835)

This commit is contained in:
Michael Hansen 2024-06-21 06:24:53 -05:00 committed by GitHub
parent 955685e116
commit 18767154df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 0 deletions

View File

@ -3,7 +3,9 @@
import asyncio import asyncio
import io import io
import logging import logging
import time
from typing import Final from typing import Final
from uuid import uuid4
import wave import wave
from typing_extensions import AsyncGenerator from typing_extensions import AsyncGenerator
@ -38,6 +40,7 @@ _RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5 _PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2 _PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1 _PIPELINE_FINISH_TIMEOUT: Final = 1
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
# Wyoming stage -> Assist stage # Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -73,6 +76,9 @@ class WyomingSatellite:
self._pipeline_id: str | None = None self._pipeline_id: str | None = None
self._muted_changed_event = asyncio.Event() self._muted_changed_event = asyncio.Event()
self._conversation_id: str | None = None
self._conversation_id_time: float | None = None
self.device.set_is_muted_listener(self._muted_changed) self.device.set_is_muted_listener(self._muted_changed)
self.device.set_pipeline_listener(self._pipeline_changed) self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed) self.device.set_audio_settings_listener(self._audio_settings_changed)
@ -365,6 +371,19 @@ class WyomingSatellite:
start_stage, start_stage,
end_stage, end_stage,
) )
# Reset conversation id, if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = str(uuid4())
# Update timeout
self._conversation_id_time = time.monotonic()
self._is_pipeline_running = True self._is_pipeline_running = True
self._pipeline_ended_event.clear() self._pipeline_ended_event.clear()
self.config_entry.async_create_background_task( self.config_entry.async_create_background_task(
@ -393,6 +412,7 @@ class WyomingSatellite:
), ),
device_id=self.device.device_id, device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase, wake_word_phrase=wake_word_phrase,
conversation_id=self._conversation_id,
), ),
name="wyoming satellite pipeline", name="wyoming satellite pipeline",
) )

View File

@ -1285,3 +1285,104 @@ async def test_timers(hass: HomeAssistant) -> None:
timer_finished = mock_client.timer_finished timer_finished = mock_client.timer_finished
assert timer_finished is not None assert timer_finished is not None
assert timer_finished.id == timer_started.id assert timer_finished.id == timer_started.id
async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
"""Test that the same conversation id is used until timeout."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE,
end_stage=PipelineStage.TTS,
restart_on_end=True,
).event(),
]
pipeline_kwargs: dict[str, Any] = {}
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
None
)
run_pipeline_called = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_kwargs, pipeline_event_callback
pipeline_kwargs = kwargs
pipeline_event_callback = event_callback
run_pipeline_called.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# A conversation id should have been generated
conversation_id = pipeline_kwargs.get("conversation_id")
assert conversation_id
# Reset and run again
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be the same conversation id
assert pipeline_kwargs.get("conversation_id") == conversation_id
# Reset and run again, but this time "time out"
satellite._conversation_id_time = None
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be a different conversation id
new_conversation_id = pipeline_kwargs.get("conversation_id")
assert new_conversation_id
assert new_conversation_id != conversation_id