mirror of
https://github.com/home-assistant/core.git
synced 2025-11-12 20:40:18 +00:00
Add support for continue conversation in Assist Pipeline (#139480)
* Add support for continue conversation in Assist Pipeline * Also forward to ESPHome * Update snapshot * And mobile app
This commit is contained in:
@@ -27,7 +27,7 @@ from homeassistant.components.assist_pipeline.const import (
|
||||
)
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .conftest import (
|
||||
@@ -675,6 +675,7 @@ async def test_wake_word_detection_aborted(
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream with wake word."""
|
||||
@@ -693,7 +694,7 @@ async def test_wake_word_detection_aborted(
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
@@ -766,6 +767,7 @@ async def test_tts_audio_output(
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test using tts_audio_output with wav sets options correctly."""
|
||||
@@ -780,7 +782,7 @@ async def test_tts_audio_output(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@@ -823,6 +825,7 @@ async def test_tts_wav_preferred_format(
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
@@ -837,7 +840,7 @@ async def test_tts_wav_preferred_format(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@@ -891,6 +894,7 @@ async def test_tts_dict_preferred_format(
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
@@ -905,7 +909,7 @@ async def test_tts_dict_preferred_format(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@@ -962,6 +966,7 @@ async def test_tts_dict_preferred_format(
|
||||
async def test_sentence_trigger_overrides_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that sentence triggers are checked before a non-default conversation agent."""
|
||||
@@ -991,7 +996,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test trigger sentence",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@@ -1039,6 +1044,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
|
||||
async def test_prefer_local_intents(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that the default agent is checked first when local intents are preferred."""
|
||||
@@ -1069,7 +1075,7 @@ async def test_prefer_local_intents(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="I'd like to order a stout please",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@@ -1113,10 +1119,150 @@ async def test_prefer_local_intents(
|
||||
)
|
||||
|
||||
|
||||
async def test_intent_continue_conversation(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that a conversation agent flagging continue conversation gets response."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Fake a test agent and prefer local intents
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine="test-agent"
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Set a timer",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("For how long?")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
continue_conversation=True,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[1]["intent_output"]["continue_conversation"] is True
|
||||
|
||||
# Change conversation agent to default one and register sentence trigger that should not be called
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine=None
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": ["Hello"],
|
||||
},
|
||||
"action": {
|
||||
"set_conversation_response": "test trigger response",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Because we did continue conversation, it should respond to the test agent again.
|
||||
events.clear()
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Hello",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
) as mock_prepare:
|
||||
await pipeline_input.validate()
|
||||
|
||||
# It requested test agent even if that was not default agent.
|
||||
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("Timer set for 20 minutes")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
# Snapshot will show it was still handled by the test agent and not default agent
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[0]["engine"] == "test-agent"
|
||||
assert results[1]["intent_output"]["continue_conversation"] is False
|
||||
|
||||
|
||||
async def test_stt_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
||||
@@ -1147,7 +1293,7 @@ async def test_stt_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@@ -1192,6 +1338,7 @@ async def test_tts_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
||||
@@ -1222,7 +1369,7 @@ async def test_tts_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@@ -1267,6 +1414,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
||||
@@ -1297,7 +1445,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
|
||||
Reference in New Issue
Block a user