mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Generate and keep conversation id for Wyoming satellite (#118835)
This commit is contained in:
parent
955685e116
commit
18767154df
@ -3,7 +3,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
from uuid import uuid4
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
from typing_extensions import AsyncGenerator
|
from typing_extensions import AsyncGenerator
|
||||||
@ -38,6 +40,7 @@ _RESTART_SECONDS: Final = 3
|
|||||||
_PING_TIMEOUT: Final = 5
|
_PING_TIMEOUT: Final = 5
|
||||||
_PING_SEND_DELAY: Final = 2
|
_PING_SEND_DELAY: Final = 2
|
||||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||||
|
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||||
|
|
||||||
# Wyoming stage -> Assist stage
|
# Wyoming stage -> Assist stage
|
||||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||||
@ -73,6 +76,9 @@ class WyomingSatellite:
|
|||||||
self._pipeline_id: str | None = None
|
self._pipeline_id: str | None = None
|
||||||
self._muted_changed_event = asyncio.Event()
|
self._muted_changed_event = asyncio.Event()
|
||||||
|
|
||||||
|
self._conversation_id: str | None = None
|
||||||
|
self._conversation_id_time: float | None = None
|
||||||
|
|
||||||
self.device.set_is_muted_listener(self._muted_changed)
|
self.device.set_is_muted_listener(self._muted_changed)
|
||||||
self.device.set_pipeline_listener(self._pipeline_changed)
|
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||||
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
||||||
@ -365,6 +371,19 @@ class WyomingSatellite:
|
|||||||
start_stage,
|
start_stage,
|
||||||
end_stage,
|
end_stage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reset conversation id, if necessary
|
||||||
|
if (self._conversation_id_time is None) or (
|
||||||
|
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||||
|
):
|
||||||
|
self._conversation_id = None
|
||||||
|
|
||||||
|
if self._conversation_id is None:
|
||||||
|
self._conversation_id = str(uuid4())
|
||||||
|
|
||||||
|
# Update timeout
|
||||||
|
self._conversation_id_time = time.monotonic()
|
||||||
|
|
||||||
self._is_pipeline_running = True
|
self._is_pipeline_running = True
|
||||||
self._pipeline_ended_event.clear()
|
self._pipeline_ended_event.clear()
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
@ -393,6 +412,7 @@ class WyomingSatellite:
|
|||||||
),
|
),
|
||||||
device_id=self.device.device_id,
|
device_id=self.device.device_id,
|
||||||
wake_word_phrase=wake_word_phrase,
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
conversation_id=self._conversation_id,
|
||||||
),
|
),
|
||||||
name="wyoming satellite pipeline",
|
name="wyoming satellite pipeline",
|
||||||
)
|
)
|
||||||
|
@ -1285,3 +1285,104 @@ async def test_timers(hass: HomeAssistant) -> None:
|
|||||||
timer_finished = mock_client.timer_finished
|
timer_finished = mock_client.timer_finished
|
||||||
assert timer_finished is not None
|
assert timer_finished is not None
|
||||||
assert timer_finished.id == timer_started.id
|
assert timer_finished.id == timer_started.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that the same conversation id is used until timeout."""
|
||||||
|
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||||
|
|
||||||
|
events = [
|
||||||
|
RunPipeline(
|
||||||
|
start_stage=PipelineStage.WAKE,
|
||||||
|
end_stage=PipelineStage.TTS,
|
||||||
|
restart_on_end=True,
|
||||||
|
).event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
pipeline_kwargs: dict[str, Any] = {}
|
||||||
|
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
run_pipeline_called = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context,
|
||||||
|
event_callback,
|
||||||
|
stt_metadata,
|
||||||
|
stt_stream,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||||
|
pipeline_kwargs = kwargs
|
||||||
|
pipeline_event_callback = event_callback
|
||||||
|
|
||||||
|
run_pipeline_called.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
|
SatelliteAsyncTcpClient(events),
|
||||||
|
) as mock_client,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||||
|
return_value=("wav", get_test_wav()),
|
||||||
|
),
|
||||||
|
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||||
|
):
|
||||||
|
entry = await setup_config_entry(hass)
|
||||||
|
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
|
||||||
|
entry.entry_id
|
||||||
|
].satellite
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.connect_event.wait()
|
||||||
|
await mock_client.run_satellite_event.wait()
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await run_pipeline_called.wait()
|
||||||
|
|
||||||
|
assert pipeline_event_callback is not None
|
||||||
|
|
||||||
|
# A conversation id should have been generated
|
||||||
|
conversation_id = pipeline_kwargs.get("conversation_id")
|
||||||
|
assert conversation_id
|
||||||
|
|
||||||
|
# Reset and run again
|
||||||
|
run_pipeline_called.clear()
|
||||||
|
pipeline_kwargs.clear()
|
||||||
|
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await run_pipeline_called.wait()
|
||||||
|
|
||||||
|
# Should be the same conversation id
|
||||||
|
assert pipeline_kwargs.get("conversation_id") == conversation_id
|
||||||
|
|
||||||
|
# Reset and run again, but this time "time out"
|
||||||
|
satellite._conversation_id_time = None
|
||||||
|
run_pipeline_called.clear()
|
||||||
|
pipeline_kwargs.clear()
|
||||||
|
|
||||||
|
pipeline_event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await run_pipeline_called.wait()
|
||||||
|
|
||||||
|
# Should be a different conversation id
|
||||||
|
new_conversation_id = pipeline_kwargs.get("conversation_id")
|
||||||
|
assert new_conversation_id
|
||||||
|
assert new_conversation_id != conversation_id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user