Add support for continue conversation in Assist Pipeline (#139480)

* Add support for continue conversation in Assist Pipeline

* Also forward to ESPHome

* Update snapshot

* And mobile app
This commit is contained in:
Paulus Schoutsen 2025-02-28 19:15:31 +00:00 committed by GitHub
parent 086c91485f
commit 90fc6ffdbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 362 additions and 46 deletions

View File

@ -117,7 +117,7 @@ async def async_pipeline_from_audio_stream(
""" """
with chat_session.async_get_chat_session(hass, conversation_id) as session: with chat_session.async_get_chat_session(hass, conversation_id) as session:
pipeline_input = PipelineInput( pipeline_input = PipelineInput(
conversation_id=session.conversation_id, session=session,
device_id=device_id, device_id=device_id,
stt_metadata=stt_metadata, stt_metadata=stt_metadata,
stt_stream=stt_stream, stt_stream=stt_stream,

View File

@ -96,6 +96,9 @@ ENGINE_LANGUAGE_PAIRS = (
) )
KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN) KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
"pipeline_conversation_data"
)
def validate_language(data: dict[str, Any]) -> Any: def validate_language(data: dict[str, Any]) -> Any:
@ -590,6 +593,12 @@ class PipelineRun:
_device_id: str | None = None _device_id: str | None = None
"""Optional device id set during run start.""" """Optional device id set during run start."""
_conversation_data: PipelineConversationData | None = None
"""Data tied to the conversation ID."""
_intent_agent_only = False
"""If request should only be handled by agent, ignoring sentence triggers and local processing."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Set language for pipeline.""" """Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language self.language = self.pipeline.language or self.hass.config.language
@ -1007,19 +1016,36 @@ class PipelineRun:
yield chunk.audio yield chunk.audio
async def prepare_recognize_intent(self) -> None: async def prepare_recognize_intent(self, session: chat_session.ChatSession) -> None:
"""Prepare recognizing an intent.""" """Prepare recognizing an intent."""
agent_info = conversation.async_get_agent_info( self._conversation_data = async_get_pipeline_conversation_data(
self.hass, self.hass, session
self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
) )
if agent_info is None: if self._conversation_data.continue_conversation_agent is not None:
engine = self.pipeline.conversation_engine or "default" agent_info = conversation.async_get_agent_info(
raise IntentRecognitionError( self.hass, self._conversation_data.continue_conversation_agent
code="intent-not-supported",
message=f"Intent recognition engine {engine} is not found",
) )
self._conversation_data.continue_conversation_agent = None
if agent_info is None:
raise IntentRecognitionError(
code="intent-agent-not-found",
message=f"Intent recognition engine {self._conversation_data.continue_conversation_agent} asked for follow-up but is no longer found",
)
self._intent_agent_only = True
else:
agent_info = conversation.async_get_agent_info(
self.hass,
self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
)
if agent_info is None:
engine = self.pipeline.conversation_engine or "default"
raise IntentRecognitionError(
code="intent-not-supported",
message=f"Intent recognition engine {engine} is not found",
)
self.intent_agent = agent_info.id self.intent_agent = agent_info.id
@ -1031,7 +1057,7 @@ class PipelineRun:
conversation_extra_system_prompt: str | None, conversation_extra_system_prompt: str | None,
) -> str: ) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak.""" """Run intent recognition portion of pipeline. Returns text to speak."""
if self.intent_agent is None: if self.intent_agent is None or self._conversation_data is None:
raise RuntimeError("Recognize intent was not prepared") raise RuntimeError("Recognize intent was not prepared")
if self.pipeline.conversation_language == MATCH_ALL: if self.pipeline.conversation_language == MATCH_ALL:
@ -1078,7 +1104,7 @@ class PipelineRun:
agent_id = self.intent_agent agent_id = self.intent_agent
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
intent_response: intent.IntentResponse | None = None intent_response: intent.IntentResponse | None = None
if not processed_locally: if not processed_locally and not self._intent_agent_only:
# Sentence triggers override conversation agent # Sentence triggers override conversation agent
if ( if (
trigger_response_text trigger_response_text
@ -1195,6 +1221,9 @@ class PipelineRun:
) )
) )
if conversation_result.continue_conversation:
self._conversation_data.continue_conversation_agent = agent_id
return speech return speech
async def prepare_text_to_speech(self) -> None: async def prepare_text_to_speech(self) -> None:
@ -1458,8 +1487,8 @@ class PipelineInput:
run: PipelineRun run: PipelineRun
conversation_id: str session: chat_session.ChatSession
"""Identifier for the conversation.""" """Session for the conversation."""
stt_metadata: stt.SpeechMetadata | None = None stt_metadata: stt.SpeechMetadata | None = None
"""Metadata of stt input audio. Required when start_stage = stt.""" """Metadata of stt input audio. Required when start_stage = stt."""
@ -1484,7 +1513,9 @@ class PipelineInput:
async def execute(self) -> None: async def execute(self) -> None:
"""Run pipeline.""" """Run pipeline."""
self.run.start(conversation_id=self.conversation_id, device_id=self.device_id) self.run.start(
conversation_id=self.session.conversation_id, device_id=self.device_id
)
current_stage: PipelineStage | None = self.run.start_stage current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[EnhancedAudioChunk] = [] stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
@ -1568,7 +1599,7 @@ class PipelineInput:
assert intent_input is not None assert intent_input is not None
tts_input = await self.run.recognize_intent( tts_input = await self.run.recognize_intent(
intent_input, intent_input,
self.conversation_id, self.session.conversation_id,
self.device_id, self.device_id,
self.conversation_extra_system_prompt, self.conversation_extra_system_prompt,
) )
@ -1652,7 +1683,7 @@ class PipelineInput:
<= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT) <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT)
<= end_stage_index <= end_stage_index
): ):
prepare_tasks.append(self.run.prepare_recognize_intent()) prepare_tasks.append(self.run.prepare_recognize_intent(self.session))
if ( if (
start_stage_index start_stage_index
@ -1931,7 +1962,7 @@ class PipelineRunDebug:
class PipelineStore(Store[SerializedPipelineStorageCollection]): class PipelineStore(Store[SerializedPipelineStorageCollection]):
"""Store entity registry data.""" """Store pipeline data."""
async def _async_migrate_func( async def _async_migrate_func(
self, self,
@ -2013,3 +2044,37 @@ async def async_run_migrations(hass: HomeAssistant) -> None:
for pipeline, attr_updates in updates: for pipeline, attr_updates in updates:
await async_update_pipeline(hass, pipeline, **attr_updates) await async_update_pipeline(hass, pipeline, **attr_updates)
@dataclass
class PipelineConversationData:
"""Hold data for the duration of a conversation."""
continue_conversation_agent: str | None = None
"""The agent that requested the conversation to be continued."""
@callback
def async_get_pipeline_conversation_data(
hass: HomeAssistant, session: chat_session.ChatSession
) -> PipelineConversationData:
"""Get the pipeline data for a specific conversation."""
all_conversation_data = hass.data.get(KEY_PIPELINE_CONVERSATION_DATA)
if all_conversation_data is None:
all_conversation_data = {}
hass.data[KEY_PIPELINE_CONVERSATION_DATA] = all_conversation_data
data = all_conversation_data.get(session.conversation_id)
if data is not None:
return data
@callback
def do_cleanup() -> None:
"""Handle cleanup."""
all_conversation_data.pop(session.conversation_id)
session.async_on_cleanup(do_cleanup)
data = all_conversation_data[session.conversation_id] = PipelineConversationData()
return data

View File

@ -239,7 +239,7 @@ async def websocket_run(
with chat_session.async_get_chat_session( with chat_session.async_get_chat_session(
hass, msg.get("conversation_id") hass, msg.get("conversation_id")
) as session: ) as session:
input_args["conversation_id"] = session.conversation_id input_args["session"] = session
pipeline_input = PipelineInput(**input_args) pipeline_input = PipelineInput(**input_args)
try: try:

View File

@ -62,12 +62,14 @@ class ConversationResult:
response: intent.IntentResponse response: intent.IntentResponse
conversation_id: str | None = None conversation_id: str | None = None
continue_conversation: bool = False
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return result as a dict.""" """Return result as a dict."""
return { return {
"response": self.response.as_dict(), "response": self.response.as_dict(),
"conversation_id": self.conversation_id, "conversation_id": self.conversation_id,
"continue_conversation": self.continue_conversation,
} }

View File

@ -284,7 +284,10 @@ class EsphomeAssistSatellite(
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
assert event.data is not None assert event.data is not None
data_to_send = { data_to_send = {
"conversation_id": event.data["intent_output"]["conversation_id"] or "", "conversation_id": event.data["intent_output"]["conversation_id"],
"continue_conversation": event.data["intent_output"][
"continue_conversation"
],
} }
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
assert event.data is not None assert event.data is not None

View File

@ -1,6 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_unknown_hass_api # name: test_unknown_hass_api
dict({ dict({
'continue_conversation': False,
'conversation_id': '1234', 'conversation_id': '1234',
'response': IntentResponse( 'response': IntentResponse(
card=dict({ card=dict({

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import AsyncIterable, Generator from collections.abc import AsyncIterable, Generator
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -24,7 +24,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import chat_session, device_registry as dr
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -379,3 +379,14 @@ def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
def make_10ms_chunk(header: bytes) -> bytes: def make_10ms_chunk(header: bytes) -> bytes:
"""Return 10ms of zeros with the given header.""" """Return 10ms of zeros with the given header."""
return header + bytes(BYTES_PER_CHUNK - len(header)) return header + bytes(BYTES_PER_CHUNK - len(header))
@pytest.fixture
def mock_chat_session(hass: HomeAssistant) -> Generator[chat_session.ChatSession]:
"""Mock the ulid of chat sessions."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
with (
patch("homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"),
chat_session.async_get_chat_session(hass) as session,
):
yield session

View File

@ -45,6 +45,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -137,6 +138,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -229,6 +231,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -345,6 +348,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -432,7 +436,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -440,7 +444,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -452,6 +456,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -484,7 +489,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -492,7 +497,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -504,6 +509,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -536,7 +542,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -544,7 +550,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -556,6 +562,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -588,7 +595,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),

View File

@ -43,6 +43,7 @@
# name: test_audio_pipeline.4 # name: test_audio_pipeline.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -127,6 +128,7 @@
# name: test_audio_pipeline_debug.4 # name: test_audio_pipeline_debug.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -223,6 +225,7 @@
# name: test_audio_pipeline_with_enhancements.4 # name: test_audio_pipeline_with_enhancements.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -329,6 +332,7 @@
# name: test_audio_pipeline_with_wake_word_no_timeout.6 # name: test_audio_pipeline_with_wake_word_no_timeout.6
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -596,6 +600,7 @@
# name: test_pipeline_empty_tts_output.2 # name: test_pipeline_empty_tts_output.2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -715,6 +720,7 @@
# name: test_text_only_pipeline[extra_msg0].2 # name: test_text_only_pipeline[extra_msg0].2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -762,6 +768,7 @@
# name: test_text_only_pipeline[extra_msg1].2 # name: test_text_only_pipeline[extra_msg1].2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({

View File

@ -27,7 +27,7 @@ from homeassistant.components.assist_pipeline.const import (
) )
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent from homeassistant.helpers import chat_session, intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .conftest import ( from .conftest import (
@ -675,6 +675,7 @@ async def test_wake_word_detection_aborted(
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_components, init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test creating a pipeline from an audio stream with wake word.""" """Test creating a pipeline from an audio stream with wake word."""
@ -693,7 +694,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id="mock-conversation-id", session=mock_chat_session,
device_id=None, device_id=None,
stt_metadata=stt.SpeechMetadata( stt_metadata=stt.SpeechMetadata(
language="", language="",
@ -766,6 +767,7 @@ async def test_tts_audio_output(
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
init_components, init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test using tts_audio_output with wav sets options correctly.""" """Test using tts_audio_output with wav sets options correctly."""
@ -780,7 +782,7 @@ async def test_tts_audio_output(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id="mock-conversation-id", session=mock_chat_session,
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -823,6 +825,7 @@ async def test_tts_wav_preferred_format(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None: ) -> None:
"""Test that preferred format options are given to the TTS system if supported.""" """Test that preferred format options are given to the TTS system if supported."""
@ -837,7 +840,7 @@ async def test_tts_wav_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id="mock-conversation-id", session=mock_chat_session,
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -891,6 +894,7 @@ async def test_tts_dict_preferred_format(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None: ) -> None:
"""Test that preferred format options are given to the TTS system if supported.""" """Test that preferred format options are given to the TTS system if supported."""
@ -905,7 +909,7 @@ async def test_tts_dict_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id="mock-conversation-id", session=mock_chat_session,
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -962,6 +966,7 @@ async def test_tts_dict_preferred_format(
async def test_sentence_trigger_overrides_conversation_agent( async def test_sentence_trigger_overrides_conversation_agent(
hass: HomeAssistant, hass: HomeAssistant,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None: ) -> None:
"""Test that sentence triggers are checked before a non-default conversation agent.""" """Test that sentence triggers are checked before a non-default conversation agent."""
@ -991,7 +996,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence", intent_input="test trigger sentence",
conversation_id="mock-conversation-id", session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1039,6 +1044,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
async def test_prefer_local_intents( async def test_prefer_local_intents(
hass: HomeAssistant, hass: HomeAssistant,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None: ) -> None:
"""Test that the default agent is checked first when local intents are preferred.""" """Test that the default agent is checked first when local intents are preferred."""
@ -1069,7 +1075,7 @@ async def test_prefer_local_intents(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please", intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id", session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1113,10 +1119,150 @@ async def test_prefer_local_intents(
) )
async def test_intent_continue_conversation(
hass: HomeAssistant,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that a conversation agent flagging continue conversation gets response."""
events: list[assist_pipeline.PipelineEvent] = []
# Fake a test agent and prefer local intents
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
await assist_pipeline.pipeline.async_update_pipeline(
hass, pipeline, conversation_engine="test-agent"
)
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="Set a timer",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
):
await pipeline_input.validate()
response = intent.IntentResponse("en")
response.async_set_speech("For how long?")
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
response=response,
conversation_id=mock_chat_session.conversation_id,
continue_conversation=True,
),
) as mock_async_converse:
await pipeline_input.execute()
mock_async_converse.assert_called()
results = [
event.data
for event in events
if event.type
in (
assist_pipeline.PipelineEventType.INTENT_START,
assist_pipeline.PipelineEventType.INTENT_END,
)
]
assert results[1]["intent_output"]["continue_conversation"] is True
# Change conversation agent to default one and register sentence trigger that should not be called
await assist_pipeline.pipeline.async_update_pipeline(
hass, pipeline, conversation_engine=None
)
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["Hello"],
},
"action": {
"set_conversation_response": "test trigger response",
},
}
},
)
# Because we did continue conversation, it should respond to the test agent again.
events.clear()
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="Hello",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
) as mock_prepare:
await pipeline_input.validate()
# It requested test agent even if that was not default agent.
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
response = intent.IntentResponse("en")
response.async_set_speech("Timer set for 20 minutes")
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
response=response,
conversation_id=mock_chat_session.conversation_id,
),
) as mock_async_converse:
await pipeline_input.execute()
mock_async_converse.assert_called()
# Snapshot will show it was still handled by the test agent and not default agent
results = [
event.data
for event in events
if event.type
in (
assist_pipeline.PipelineEventType.INTENT_START,
assist_pipeline.PipelineEventType.INTENT_END,
)
]
assert results[0]["engine"] == "test-agent"
assert results[1]["intent_output"]["continue_conversation"] is False
async def test_stt_language_used_instead_of_conversation_language( async def test_stt_language_used_instead_of_conversation_language(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that the STT language is used first when the conversation language is '*' (all languages).""" """Test that the STT language is used first when the conversation language is '*' (all languages)."""
@ -1147,7 +1293,7 @@ async def test_stt_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id", session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1192,6 +1338,7 @@ async def test_tts_language_used_instead_of_conversation_language(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages).""" """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
@ -1222,7 +1369,7 @@ async def test_tts_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id", session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1267,6 +1414,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components, init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that the pipeline language is used last when the conversation language is '*' (all languages).""" """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
@ -1297,7 +1445,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id", session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
from unittest.mock import patch from unittest.mock import patch
@ -49,7 +50,7 @@ class MockAgent(conversation.AbstractConversationAgent):
@pytest.fixture @pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog: async def mock_chat_log(hass: HomeAssistant) -> AsyncGenerator[MockChatLog]:
"""Return mock chat logs.""" """Return mock chat logs."""
# pylint: disable-next=contextmanager-generator-missing-cleanup # pylint: disable-next=contextmanager-generator-missing-cleanup
with ( with (

View File

@ -151,6 +151,7 @@
# --- # ---
# name: test_template_error # name: test_template_error
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -171,6 +172,7 @@
# --- # ---
# name: test_unknown_llm_api # name: test_unknown_llm_api
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({

View File

@ -1,6 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_custom_sentences # name: test_custom_sentences
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -26,6 +27,7 @@
# --- # ---
# name: test_custom_sentences.1 # name: test_custom_sentences.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -51,6 +53,7 @@
# --- # ---
# name: test_custom_sentences_config # name: test_custom_sentences_config
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -76,6 +79,7 @@
# --- # ---
# name: test_intent_alias_added_removed # name: test_intent_alias_added_removed
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -106,6 +110,7 @@
# --- # ---
# name: test_intent_alias_added_removed.1 # name: test_intent_alias_added_removed.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -136,6 +141,7 @@
# --- # ---
# name: test_intent_alias_added_removed.2 # name: test_intent_alias_added_removed.2
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -156,6 +162,7 @@
# --- # ---
# name: test_intent_conversion_not_expose_new # name: test_intent_conversion_not_expose_new
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -176,6 +183,7 @@
# --- # ---
# name: test_intent_conversion_not_expose_new.1 # name: test_intent_conversion_not_expose_new.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -206,6 +214,7 @@
# --- # ---
# name: test_intent_entity_added_removed # name: test_intent_entity_added_removed
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -236,6 +245,7 @@
# --- # ---
# name: test_intent_entity_added_removed.1 # name: test_intent_entity_added_removed.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -266,6 +276,7 @@
# --- # ---
# name: test_intent_entity_added_removed.2 # name: test_intent_entity_added_removed.2
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -296,6 +307,7 @@
# --- # ---
# name: test_intent_entity_added_removed.3 # name: test_intent_entity_added_removed.3
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -316,6 +328,7 @@
# --- # ---
# name: test_intent_entity_exposed # name: test_intent_entity_exposed
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -346,6 +359,7 @@
# --- # ---
# name: test_intent_entity_fail_if_unexposed # name: test_intent_entity_fail_if_unexposed
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -366,6 +380,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name # name: test_intent_entity_remove_custom_name
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -386,6 +401,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name.1 # name: test_intent_entity_remove_custom_name.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -416,6 +432,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name.2 # name: test_intent_entity_remove_custom_name.2
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -436,6 +453,7 @@
# --- # ---
# name: test_intent_entity_renamed # name: test_intent_entity_renamed
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -466,6 +484,7 @@
# --- # ---
# name: test_intent_entity_renamed.1 # name: test_intent_entity_renamed.1
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({

View File

@ -202,6 +202,7 @@
# --- # ---
# name: test_http_api_handle_failure # name: test_http_api_handle_failure
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -222,6 +223,7 @@
# --- # ---
# name: test_http_api_no_match # name: test_http_api_no_match
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -242,6 +244,7 @@
# --- # ---
# name: test_http_api_unexpected_failure # name: test_http_api_unexpected_failure
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -262,6 +265,7 @@
# --- # ---
# name: test_http_processing_intent[None] # name: test_http_processing_intent[None]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -292,6 +296,7 @@
# --- # ---
# name: test_http_processing_intent[conversation.home_assistant] # name: test_http_processing_intent[conversation.home_assistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -322,6 +327,7 @@
# --- # ---
# name: test_http_processing_intent[homeassistant] # name: test_http_processing_intent[homeassistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -352,6 +358,7 @@
# --- # ---
# name: test_ws_api[payload0] # name: test_ws_api[payload0]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -372,6 +379,7 @@
# --- # ---
# name: test_ws_api[payload1] # name: test_ws_api[payload1]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -392,6 +400,7 @@
# --- # ---
# name: test_ws_api[payload2] # name: test_ws_api[payload2]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -412,6 +421,7 @@
# --- # ---
# name: test_ws_api[payload3] # name: test_ws_api[payload3]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -432,6 +442,7 @@
# --- # ---
# name: test_ws_api[payload4] # name: test_ws_api[payload4]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -452,6 +463,7 @@
# --- # ---
# name: test_ws_api[payload5] # name: test_ws_api[payload5]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({

View File

@ -1,6 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_custom_agent # name: test_custom_agent
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -44,6 +45,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-None] # name: test_turn_on_intent[None-turn kitchen on-None]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -74,6 +76,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant] # name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -104,6 +107,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-homeassistant] # name: test_turn_on_intent[None-turn kitchen on-homeassistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -134,6 +138,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-None] # name: test_turn_on_intent[None-turn on kitchen-None]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -164,6 +169,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant] # name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -194,6 +200,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-homeassistant] # name: test_turn_on_intent[None-turn on kitchen-homeassistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -224,6 +231,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-None] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-None]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -254,6 +262,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -284,6 +293,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -314,6 +324,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-None] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-None]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -344,6 +355,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -374,6 +386,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
dict({ dict({
'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({

View File

@ -25,7 +25,7 @@ from aioesphomeapi import (
) )
import pytest import pytest
from homeassistant.components import assist_satellite, tts from homeassistant.components import assist_satellite, conversation, tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import ( from homeassistant.components.assist_satellite import (
AssistSatelliteConfiguration, AssistSatelliteConfiguration,
@ -285,12 +285,21 @@ async def test_pipeline_api_audio(
event_callback( event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.INTENT_END, type=PipelineEventType.INTENT_END,
data={"intent_output": {"conversation_id": conversation_id}}, data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
continue_conversation=True,
).as_dict()
},
) )
) )
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
{"conversation_id": conversation_id}, {
"conversation_id": conversation_id,
"continue_conversation": True,
},
) )
# TTS # TTS
@ -484,7 +493,12 @@ async def test_pipeline_udp_audio(
event_callback( event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.INTENT_END, type=PipelineEventType.INTENT_END,
data={"intent_output": {"conversation_id": conversation_id}}, data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
).as_dict()
},
) )
) )
@ -690,7 +704,12 @@ async def test_pipeline_media_player(
event_callback( event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.INTENT_END, type=PipelineEventType.INTENT_END,
data={"intent_output": {"conversation_id": conversation_id}}, data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
).as_dict()
},
) )
) )

View File

@ -1081,6 +1081,7 @@ async def test_webhook_handle_conversation_process(
}, },
}, },
"conversation_id": None, "conversation_id": None,
"continue_conversation": False,
} }

View File

@ -1,6 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_unknown_hass_api # name: test_unknown_hass_api
dict({ dict({
'continue_conversation': False,
'conversation_id': '1234', 'conversation_id': '1234',
'response': IntentResponse( 'response': IntentResponse(
card=dict({ card=dict({

View File

@ -109,7 +109,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
serializable_data = cls._serializable_issue_registry_entry(data) serializable_data = cls._serializable_issue_registry_entry(data)
elif isinstance(data, dict) and "flow_id" in data and "handler" in data: elif isinstance(data, dict) and "flow_id" in data and "handler" in data:
serializable_data = cls._serializable_flow_result(data) serializable_data = cls._serializable_flow_result(data)
elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}: elif isinstance(data, dict) and set(data) == {
"conversation_id",
"response",
"continue_conversation",
}:
serializable_data = cls._serializable_conversation_result(data) serializable_data = cls._serializable_conversation_result(data)
elif isinstance(data, vol.Schema): elif isinstance(data, vol.Schema):
serializable_data = voluptuous_serialize.convert(data) serializable_data = voluptuous_serialize.convert(data)