Assist Pipeline to use ChatSession for conversation ID (#137143)

* Assist Pipeline to use ChatSession for conversation ID

* Adjust to latest changes
This commit is contained in:
Paulus Schoutsen 2025-02-03 10:18:15 -05:00 committed by GitHub
parent 8acab6c646
commit 05ca80f4ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 143 additions and 76 deletions

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session
from homeassistant.helpers.typing import ConfigType
from .const import (
@ -114,24 +115,25 @@ async def async_pipeline_from_audio_stream(
Raises PipelineNotFound if no pipeline is found.
"""
pipeline_input = PipelineInput(
conversation_id=conversation_id,
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
conversation_extra_system_prompt=conversation_extra_system_prompt,
run=PipelineRun(
hass,
context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=start_stage,
end_stage=end_stage,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
await pipeline_input.execute()
with chat_session.async_get_chat_session(hass, conversation_id) as session:
pipeline_input = PipelineInput(
conversation_id=session.conversation_id,
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
conversation_extra_system_prompt=conversation_extra_system_prompt,
run=PipelineRun(
hass,
context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=start_stage,
end_stage=end_stage,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
await pipeline_input.execute()

View File

@ -624,7 +624,7 @@ class PipelineRun:
return
pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)
def start(self, device_id: str | None) -> None:
def start(self, conversation_id: str, device_id: str | None) -> None:
"""Emit run start event."""
self._device_id = device_id
self._start_debug_recording_thread()
@ -632,6 +632,7 @@ class PipelineRun:
data = {
"pipeline": self.pipeline.id,
"language": self.language,
"conversation_id": conversation_id,
}
if self.runner_data is not None:
data["runner_data"] = self.runner_data
@ -1015,7 +1016,7 @@ class PipelineRun:
async def recognize_intent(
self,
intent_input: str,
conversation_id: str | None,
conversation_id: str,
device_id: str | None,
conversation_extra_system_prompt: str | None,
) -> str:
@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
wav_writer.close()
@dataclass
@dataclass(kw_only=True)
class PipelineInput:
"""Input to a pipeline run."""
run: PipelineRun
conversation_id: str
"""Identifier for the conversation."""
stt_metadata: stt.SpeechMetadata | None = None
"""Metadata of stt input audio. Required when start_stage = stt."""
@ -1430,9 +1434,6 @@ class PipelineInput:
tts_input: str | None = None
"""Input for text-to-speech. Required when start_stage = tts."""
conversation_id: str | None = None
"""Identifier for the conversation."""
conversation_extra_system_prompt: str | None = None
"""Extra prompt information for the conversation agent."""
@ -1441,7 +1442,7 @@ class PipelineInput:
async def execute(self) -> None:
"""Run pipeline."""
self.run.start(device_id=self.device_id)
self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None

View File

@ -14,7 +14,11 @@ import voluptuous as vol
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers import (
chat_session,
config_validation as cv,
entity_registry as er,
)
from homeassistant.util import language as language_util
from .const import (
@ -145,7 +149,6 @@ async def websocket_run(
# Arguments to PipelineInput
input_args: dict[str, Any] = {
"conversation_id": msg.get("conversation_id"),
"device_id": msg.get("device_id"),
}
@ -233,38 +236,42 @@ async def websocket_run(
audio_settings=audio_settings or AudioSettings(),
)
pipeline_input = PipelineInput(**input_args)
with chat_session.async_get_chat_session(
hass, msg.get("conversation_id")
) as session:
input_args["conversation_id"] = session.conversation_id
pipeline_input = PipelineInput(**input_args)
try:
await pipeline_input.validate()
except PipelineError as error:
# Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message)
return
try:
await pipeline_input.validate()
except PipelineError as error:
# Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message)
return
# Confirm subscription
connection.send_result(msg["id"])
# Confirm subscription
connection.send_result(msg["id"])
run_task = hass.async_create_task(pipeline_input.execute())
run_task = hass.async_create_task(pipeline_input.execute())
# Cancel pipeline if user unsubscribes
connection.subscriptions[msg["id"]] = run_task.cancel
# Cancel pipeline if user unsubscribes
connection.subscriptions[msg["id"]] = run_task.cancel
try:
# Task contains a timeout
async with asyncio.timeout(timeout):
await run_task
except TimeoutError:
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
try:
# Task contains a timeout
async with asyncio.timeout(timeout):
await run_task
except TimeoutError:
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
)
)
)
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
@callback

View File

@ -3,6 +3,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -32,7 +33,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -94,6 +95,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -123,7 +125,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -185,6 +187,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -214,7 +217,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -276,6 +279,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -329,7 +333,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -391,6 +395,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -427,6 +432,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -434,7 +440,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -478,6 +484,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -485,7 +492,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -529,6 +536,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -536,7 +544,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -580,6 +588,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),

View File

@ -1,6 +1,7 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -31,7 +32,7 @@
# ---
# name: test_audio_pipeline.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -84,6 +85,7 @@
# ---
# name: test_audio_pipeline_debug
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -114,7 +116,7 @@
# ---
# name: test_audio_pipeline_debug.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -179,6 +181,7 @@
# ---
# name: test_audio_pipeline_with_enhancements
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -209,7 +212,7 @@
# ---
# name: test_audio_pipeline_with_enhancements.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -262,6 +265,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -314,7 +318,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -367,6 +371,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -399,6 +404,7 @@
# ---
# name: test_device_capture
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -425,6 +431,7 @@
# ---
# name: test_device_capture_override
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -473,6 +480,7 @@
# ---
# name: test_device_capture_queue_full
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -512,6 +520,7 @@
# ---
# name: test_intent_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -522,7 +531,7 @@
# ---
# name: test_intent_failed.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -535,6 +544,7 @@
# ---
# name: test_intent_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -545,7 +555,7 @@
# ---
# name: test_intent_timeout.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -564,6 +574,7 @@
# ---
# name: test_pipeline_empty_tts_output
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -574,7 +585,7 @@
# ---
# name: test_pipeline_empty_tts_output.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'never mind',
@ -611,6 +622,7 @@
# ---
# name: test_stt_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -621,6 +633,7 @@
# ---
# name: test_stt_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -631,6 +644,7 @@
# ---
# name: test_stt_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -641,6 +655,7 @@
# ---
# name: test_stt_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -651,6 +666,7 @@
# ---
# name: test_stt_stream_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -677,6 +693,7 @@
# ---
# name: test_text_only_pipeline[extra_msg0]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -723,6 +740,7 @@
# ---
# name: test_text_only_pipeline[extra_msg1]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -775,6 +793,7 @@
# ---
# name: test_tts_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -796,6 +815,7 @@
# ---
# name: test_wake_word_cooldown_different_entities
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -806,6 +826,7 @@
# ---
# name: test_wake_word_cooldown_different_entities.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -857,6 +878,7 @@
# ---
# name: test_wake_word_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -867,6 +889,7 @@
# ---
# name: test_wake_word_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -921,6 +944,7 @@
# ---
# name: test_wake_word_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -931,6 +955,7 @@
# ---
# name: test_wake_word_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({

View File

@ -1,11 +1,12 @@
"""Test Voice Assistant init."""
import asyncio
from collections.abc import Generator
from dataclasses import asdict
import itertools as it
from pathlib import Path
import tempfile
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
import wave
import hass_nabucasa
@ -41,6 +42,14 @@ from .conftest import (
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
    """Pin the ULID generated for chat sessions to a fixed value.

    Patches ``homeassistant.helpers.chat_session.ulid_now`` so that it
    returns ``"mock-ulid"`` and yields the patching mock. The fixture is
    autouse, so it is applied without tests requesting it explicitly.
    """
    # Configure the return value up front instead of assigning it afterwards.
    ulid_patcher = patch(
        "homeassistant.helpers.chat_session.ulid_now",
        return_value="mock-ulid",
    )
    with ulid_patcher as mock_ulid_now:
        yield mock_ulid_now
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
stt_metadata=stt.SpeechMetadata(
language="",
@ -771,7 +780,7 @@ async def test_tts_audio_output(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),

View File

@ -2,8 +2,9 @@
import asyncio
import base64
from collections.abc import Generator
from typing import Any
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
    """Pin the ULID generated for chat sessions to a fixed value.

    Patches ``homeassistant.helpers.chat_session.ulid_now`` so that it
    returns ``"mock-ulid"`` and yields the patching mock. The fixture is
    autouse, so it is applied without tests requesting it explicitly.
    """
    # Configure the return value up front instead of assigning it afterwards.
    ulid_patcher = patch(
        "homeassistant.helpers.chat_session.ulid_now",
        return_value="mock-ulid",
    )
    with ulid_patcher as mock_ulid_now:
        yield mock_ulid_now
@pytest.mark.parametrize(
"extra_msg",
[