Add support for continue conversation in Assist Pipeline (#139480)

* Add support for continue conversation in Assist Pipeline

* Also forward to ESPHome

* Update snapshot

* And mobile app
This commit is contained in:
Paulus Schoutsen
2025-02-28 19:15:31 +00:00
committed by GitHub
parent 086c91485f
commit 90fc6ffdbf
19 changed files with 362 additions and 46 deletions

View File

@@ -27,7 +27,7 @@ from homeassistant.components.assist_pipeline.const import (
)
from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers import chat_session, intent
from homeassistant.setup import async_setup_component
from .conftest import (
@@ -675,6 +675,7 @@ async def test_wake_word_detection_aborted(
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream with wake word."""
@@ -693,7 +694,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id="mock-conversation-id",
session=mock_chat_session,
device_id=None,
stt_metadata=stt.SpeechMetadata(
language="",
@@ -766,6 +767,7 @@ async def test_tts_audio_output(
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion,
) -> None:
"""Test using tts_audio_output with wav sets options correctly."""
@@ -780,7 +782,7 @@ async def test_tts_audio_output(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id="mock-conversation-id",
session=mock_chat_session,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@@ -823,6 +825,7 @@ async def test_tts_wav_preferred_format(
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
@@ -837,7 +840,7 @@ async def test_tts_wav_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id="mock-conversation-id",
session=mock_chat_session,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@@ -891,6 +894,7 @@ async def test_tts_dict_preferred_format(
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
@@ -905,7 +909,7 @@ async def test_tts_dict_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id="mock-conversation-id",
session=mock_chat_session,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@@ -962,6 +966,7 @@ async def test_tts_dict_preferred_format(
async def test_sentence_trigger_overrides_conversation_agent(
hass: HomeAssistant,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that sentence triggers are checked before a non-default conversation agent."""
@@ -991,7 +996,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence",
conversation_id="mock-conversation-id",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@@ -1039,6 +1044,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
async def test_prefer_local_intents(
hass: HomeAssistant,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that the default agent is checked first when local intents are preferred."""
@@ -1069,7 +1075,7 @@ async def test_prefer_local_intents(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@@ -1113,10 +1119,150 @@ async def test_prefer_local_intents(
)
async def test_intent_continue_conversation(
hass: HomeAssistant,
init_components,
mock_chat_session: chat_session.ChatSession,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that a conversation agent flagging continue conversation gets response."""
events: list[assist_pipeline.PipelineEvent] = []
# Fake a test agent and prefer local intents
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
await assist_pipeline.pipeline.async_update_pipeline(
hass, pipeline, conversation_engine="test-agent"
)
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="Set a timer",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
):
await pipeline_input.validate()
response = intent.IntentResponse("en")
response.async_set_speech("For how long?")
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
response=response,
conversation_id=mock_chat_session.conversation_id,
continue_conversation=True,
),
) as mock_async_converse:
await pipeline_input.execute()
mock_async_converse.assert_called()
results = [
event.data
for event in events
if event.type
in (
assist_pipeline.PipelineEventType.INTENT_START,
assist_pipeline.PipelineEventType.INTENT_END,
)
]
assert results[1]["intent_output"]["continue_conversation"] is True
# Change conversation agent to default one and register sentence trigger that should not be called
await assist_pipeline.pipeline.async_update_pipeline(
hass, pipeline, conversation_engine=None
)
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["Hello"],
},
"action": {
"set_conversation_response": "test trigger response",
},
}
},
)
# Because we did continue conversation, it should respond to the test agent again.
events.clear()
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="Hello",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
) as mock_prepare:
await pipeline_input.validate()
# It requested test agent even if that was not default agent.
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
response = intent.IntentResponse("en")
response.async_set_speech("Timer set for 20 minutes")
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
response=response,
conversation_id=mock_chat_session.conversation_id,
),
) as mock_async_converse:
await pipeline_input.execute()
mock_async_converse.assert_called()
# Snapshot will show it was still handled by the test agent and not default agent
results = [
event.data
for event in events
if event.type
in (
assist_pipeline.PipelineEventType.INTENT_START,
assist_pipeline.PipelineEventType.INTENT_END,
)
]
assert results[0]["engine"] == "test-agent"
assert results[1]["intent_output"]["continue_conversation"] is False
async def test_stt_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
@@ -1147,7 +1293,7 @@ async def test_stt_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@@ -1192,6 +1338,7 @@ async def test_tts_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
@@ -1222,7 +1369,7 @@ async def test_tts_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@@ -1267,6 +1414,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_chat_session: chat_session.ChatSession,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
@@ -1297,7 +1445,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
session=mock_chat_session,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),