Mirror of https://github.com/home-assistant/core.git, synced 2025-11-09 02:49:40 +00:00
Add support for continue conversation in Assist Pipeline (#139480)
* Add support for continue conversation in Assist Pipeline
* Also forward to ESPHome
* Update snapshot
* And mobile app
@@ -96,6 +96,9 @@ ENGINE_LANGUAGE_PAIRS = (
 )

 KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
+KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
+    "pipeline_conversation_data"
+)


 def validate_language(data: dict[str, Any]) -> Any:
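For context (not part of the commit): the new KEY_PIPELINE_CONVERSATION_DATA constant follows the HassKey pattern, where the key carries the value type so hass.data lookups stay type-checked. A minimal sketch of that pattern with an invented key name and payload type:

from homeassistant.core import HomeAssistant
from homeassistant.util.hass_dict import HassKey

EXAMPLE_KEY: HassKey[dict[str, str]] = HassKey("example_data")  # invented key


def get_example_data(hass: HomeAssistant) -> dict[str, str]:
    """Return the shared dict stored under EXAMPLE_KEY, creating it on first use."""
    if (data := hass.data.get(EXAMPLE_KEY)) is None:
        data = {}
        hass.data[EXAMPLE_KEY] = data
    return data  # type checkers infer dict[str, str] instead of Any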
@@ -590,6 +593,12 @@ class PipelineRun:
     _device_id: str | None = None
     """Optional device id set during run start."""

+    _conversation_data: PipelineConversationData | None = None
+    """Data tied to the conversation ID."""
+
+    _intent_agent_only = False
+    """If request should only be handled by agent, ignoring sentence triggers and local processing."""
+
     def __post_init__(self) -> None:
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
@@ -1007,19 +1016,36 @@ class PipelineRun:

             yield chunk.audio

-    async def prepare_recognize_intent(self) -> None:
+    async def prepare_recognize_intent(self, session: chat_session.ChatSession) -> None:
         """Prepare recognizing an intent."""
-        agent_info = conversation.async_get_agent_info(
-            self.hass,
-            self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
+        self._conversation_data = async_get_pipeline_conversation_data(
+            self.hass, session
         )

-        if agent_info is None:
-            engine = self.pipeline.conversation_engine or "default"
-            raise IntentRecognitionError(
-                code="intent-not-supported",
-                message=f"Intent recognition engine {engine} is not found",
-            )
+        if self._conversation_data.continue_conversation_agent is not None:
+            agent_info = conversation.async_get_agent_info(
+                self.hass, self._conversation_data.continue_conversation_agent
+            )
+            self._conversation_data.continue_conversation_agent = None
+            if agent_info is None:
+                raise IntentRecognitionError(
+                    code="intent-agent-not-found",
+                    message=f"Intent recognition engine {self._conversation_data.continue_conversation_agent} asked for follow-up but is no longer found",
+                )
+            self._intent_agent_only = True
+
+        else:
+            agent_info = conversation.async_get_agent_info(
+                self.hass,
+                self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
+            )
+
+            if agent_info is None:
+                engine = self.pipeline.conversation_engine or "default"
+                raise IntentRecognitionError(
+                    code="intent-not-supported",
+                    message=f"Intent recognition engine {engine} is not found",
+                )

         self.intent_agent = agent_info.id
@@ -1031,7 +1057,7 @@ class PipelineRun:
         conversation_extra_system_prompt: str | None,
     ) -> str:
         """Run intent recognition portion of pipeline. Returns text to speak."""
-        if self.intent_agent is None:
+        if self.intent_agent is None or self._conversation_data is None:
             raise RuntimeError("Recognize intent was not prepared")

         if self.pipeline.conversation_language == MATCH_ALL:
@@ -1078,7 +1104,7 @@ class PipelineRun:
         agent_id = self.intent_agent
         processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
         intent_response: intent.IntentResponse | None = None
-        if not processed_locally:
+        if not processed_locally and not self._intent_agent_only:
             # Sentence triggers override conversation agent
             if (
                 trigger_response_text
@@ -1195,6 +1221,9 @@ class PipelineRun:
                 )
             )

+        if conversation_result.continue_conversation:
+            self._conversation_data.continue_conversation_agent = agent_id
+
         return speech

     async def prepare_text_to_speech(self) -> None:
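For context (not part of the commit): the continue_conversation flag read in this hunk is produced by the conversation agent. A hedged sketch of that producing side, assuming ConversationResult accepts the flag as a keyword argument, with an invented ExampleAgent entity:

from homeassistant.components import conversation
from homeassistant.const import MATCH_ALL
from homeassistant.helpers import intent


class ExampleAgent(conversation.ConversationEntity):
    """Invented agent that asks a follow-up question."""

    _attr_supported_languages = MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Answer with a question and keep the conversation open."""
        response = intent.IntentResponse(language=user_input.language)
        response.async_set_speech("Which room did you mean?")
        return conversation.ConversationResult(
            response=response,
            conversation_id=user_input.conversation_id,
            # Assumption: this flag is what PipelineRun reads back above.
            continue_conversation=True,
        )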
@@ -1458,8 +1487,8 @@ class PipelineInput:

     run: PipelineRun

-    conversation_id: str
-    """Identifier for the conversation."""
+    session: chat_session.ChatSession
+    """Session for the conversation."""

     stt_metadata: stt.SpeechMetadata | None = None
     """Metadata of stt input audio. Required when start_stage = stt."""
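For context (not part of the commit): PipelineInput now carries the whole chat session instead of a bare conversation id. A rough sketch of how a caller obtains that session, assuming the chat_session helper's context manager behaves as it does elsewhere in core:

from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session


def resolve_conversation_id(hass: HomeAssistant, conversation_id: str | None) -> str:
    """Return the id of the chat session a pipeline run would be tied to."""
    # Reuses an existing session when the id is known, otherwise creates one.
    with chat_session.async_get_chat_session(hass, conversation_id) as session:
        return session.conversation_id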
@@ -1484,7 +1513,9 @@ class PipelineInput:

     async def execute(self) -> None:
         """Run pipeline."""
-        self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
+        self.run.start(
+            conversation_id=self.session.conversation_id, device_id=self.device_id
+        )
         current_stage: PipelineStage | None = self.run.start_stage
         stt_audio_buffer: list[EnhancedAudioChunk] = []
         stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
@@ -1568,7 +1599,7 @@ class PipelineInput:
                 assert intent_input is not None
                 tts_input = await self.run.recognize_intent(
                     intent_input,
-                    self.conversation_id,
+                    self.session.conversation_id,
                     self.device_id,
                     self.conversation_extra_system_prompt,
                 )
@@ -1652,7 +1683,7 @@ class PipelineInput:
                 <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT)
                 <= end_stage_index
             ):
-                prepare_tasks.append(self.run.prepare_recognize_intent())
+                prepare_tasks.append(self.run.prepare_recognize_intent(self.session))

             if (
                 start_stage_index
@@ -1931,7 +1962,7 @@ class PipelineRunDebug:


 class PipelineStore(Store[SerializedPipelineStorageCollection]):
-    """Store entity registry data."""
+    """Store pipeline data."""

     async def _async_migrate_func(
         self,
@@ -2013,3 +2044,37 @@ async def async_run_migrations(hass: HomeAssistant) -> None:

     for pipeline, attr_updates in updates:
         await async_update_pipeline(hass, pipeline, **attr_updates)
+
+
+@dataclass
+class PipelineConversationData:
+    """Hold data for the duration of a conversation."""
+
+    continue_conversation_agent: str | None = None
+    """The agent that requested the conversation to be continued."""
+
+
+@callback
+def async_get_pipeline_conversation_data(
+    hass: HomeAssistant, session: chat_session.ChatSession
+) -> PipelineConversationData:
+    """Get the pipeline data for a specific conversation."""
+    all_conversation_data = hass.data.get(KEY_PIPELINE_CONVERSATION_DATA)
+    if all_conversation_data is None:
+        all_conversation_data = {}
+        hass.data[KEY_PIPELINE_CONVERSATION_DATA] = all_conversation_data
+
+    data = all_conversation_data.get(session.conversation_id)
+
+    if data is not None:
+        return data
+
+    @callback
+    def do_cleanup() -> None:
+        """Handle cleanup."""
+        all_conversation_data.pop(session.conversation_id)
+
+    session.async_on_cleanup(do_cleanup)
+
+    data = all_conversation_data[session.conversation_id] = PipelineConversationData()
+    return data
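For context (not part of the commit): a hedged usage sketch of the new helper. Within one chat session, repeated calls return the same PipelineConversationData instance, and the do_cleanup callback registered above drops it when the session expires. The import location is assumed to be the module this diff touches:

from homeassistant.components.assist_pipeline.pipeline import (
    async_get_pipeline_conversation_data,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session


def remember_follow_up_agent(
    hass: HomeAssistant, conversation_id: str, agent_id: str
) -> None:
    """Mark agent_id as the follow-up target for this conversation."""
    with chat_session.async_get_chat_session(hass, conversation_id) as session:
        data = async_get_pipeline_conversation_data(hass, session)
        # Read back (and cleared) by prepare_recognize_intent on the next turn.
        data.continue_conversation_agent = agent_id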