mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add extra prompt to assist pipeline and conversation (#124743)
* Add extra prompt to assist pipeline and conversation
* extra_prompt -> extra_system_prompt
* Fix rebase
* Fix tests
This commit is contained in:
parent e5c5d1bcfd
commit 7a484ee0ae
@ -108,6 +108,7 @@ async def async_pipeline_from_audio_stream(
|
|||||||
device_id: str | None = None,
|
device_id: str | None = None,
|
||||||
start_stage: PipelineStage = PipelineStage.STT,
|
start_stage: PipelineStage = PipelineStage.STT,
|
||||||
end_stage: PipelineStage = PipelineStage.TTS,
|
end_stage: PipelineStage = PipelineStage.TTS,
|
||||||
|
conversation_extra_system_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream.
|
"""Create an audio pipeline from an audio stream.
|
||||||
|
|
||||||
@ -119,6 +120,7 @@ async def async_pipeline_from_audio_stream(
|
|||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
stt_stream=stt_stream,
|
stt_stream=stt_stream,
|
||||||
wake_word_phrase=wake_word_phrase,
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
conversation_extra_system_prompt=conversation_extra_system_prompt,
|
||||||
run=PipelineRun(
|
run=PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=context,
|
context=context,
|
||||||
|
@ -1010,7 +1010,11 @@ class PipelineRun:
|
|||||||
self.intent_agent = agent_info.id
|
self.intent_agent = agent_info.id
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self, intent_input: str, conversation_id: str | None, device_id: str | None
|
self,
|
||||||
|
intent_input: str,
|
||||||
|
conversation_id: str | None,
|
||||||
|
device_id: str | None,
|
||||||
|
conversation_extra_system_prompt: str | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
||||||
if self.intent_agent is None:
|
if self.intent_agent is None:
|
||||||
@ -1045,6 +1049,7 @@ class PipelineRun:
|
|||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
language=input_language,
|
language=input_language,
|
||||||
agent_id=self.intent_agent,
|
agent_id=self.intent_agent,
|
||||||
|
extra_system_prompt=conversation_extra_system_prompt,
|
||||||
)
|
)
|
||||||
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
||||||
|
|
||||||
@ -1392,8 +1397,13 @@ class PipelineInput:
|
|||||||
"""Input for text-to-speech. Required when start_stage = tts."""
|
"""Input for text-to-speech. Required when start_stage = tts."""
|
||||||
|
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
|
"""Identifier for the conversation."""
|
||||||
|
|
||||||
|
conversation_extra_system_prompt: str | None = None
|
||||||
|
"""Extra prompt information for the conversation agent."""
|
||||||
|
|
||||||
device_id: str | None = None
|
device_id: str | None = None
|
||||||
|
"""Identifier of the device that is processing the input/output of the pipeline."""
|
||||||
|
|
||||||
async def execute(self) -> None:
|
async def execute(self) -> None:
|
||||||
"""Run pipeline."""
|
"""Run pipeline."""
|
||||||
@ -1483,6 +1493,7 @@ class PipelineInput:
|
|||||||
intent_input,
|
intent_input,
|
||||||
self.conversation_id,
|
self.conversation_id,
|
||||||
self.device_id,
|
self.device_id,
|
||||||
|
self.conversation_extra_system_prompt,
|
||||||
)
|
)
|
||||||
if tts_input.strip():
|
if tts_input.strip():
|
||||||
current_stage = PipelineStage.TTS
|
current_stage = PipelineStage.TTS
|
||||||
|
@ -75,6 +75,7 @@ async def async_converse(
|
|||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
agent_id: str | None = None,
|
agent_id: str | None = None,
|
||||||
device_id: str | None = None,
|
device_id: str | None = None,
|
||||||
|
extra_system_prompt: str | None = None,
|
||||||
) -> ConversationResult:
|
) -> ConversationResult:
|
||||||
"""Process text and get intent."""
|
"""Process text and get intent."""
|
||||||
agent = async_get_agent(hass, agent_id)
|
agent = async_get_agent(hass, agent_id)
|
||||||
@ -99,6 +100,7 @@ async def async_converse(
|
|||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
language=language,
|
language=language,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
extra_system_prompt=extra_system_prompt,
|
||||||
)
|
)
|
||||||
with async_conversation_trace() as trace:
|
with async_conversation_trace() as trace:
|
||||||
trace.add_event(
|
trace.add_event(
|
||||||
|
@ -40,6 +40,9 @@ class ConversationInput:
|
|||||||
agent_id: str | None = None
|
agent_id: str | None = None
|
||||||
"""Agent to use for processing."""
|
"""Agent to use for processing."""
|
||||||
|
|
||||||
|
extra_system_prompt: str | None = None
|
||||||
|
"""Extra prompt to provide extra info to LLMs how to understand the command."""
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
"""Return input as a dict."""
|
"""Return input as a dict."""
|
||||||
return {
|
return {
|
||||||
@ -49,6 +52,7 @@ class ConversationInput:
|
|||||||
"device_id": self.device_id,
|
"device_id": self.device_id,
|
||||||
"language": self.language,
|
"language": self.language,
|
||||||
"agent_id": self.agent_id,
|
"agent_id": self.agent_id,
|
||||||
|
"extra_system_prompt": self.extra_system_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ async def test_async_converse(hass: HomeAssistant, init_components) -> None:
|
|||||||
language="test lang",
|
language="test lang",
|
||||||
agent_id="conversation.home_assistant",
|
agent_id="conversation.home_assistant",
|
||||||
device_id="test device id",
|
device_id="test device id",
|
||||||
|
extra_system_prompt="test extra prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert mock_process.called
|
assert mock_process.called
|
||||||
@ -32,3 +33,4 @@ async def test_async_converse(hass: HomeAssistant, init_components) -> None:
|
|||||||
assert conversation_input.language == "test lang"
|
assert conversation_input.language == "test lang"
|
||||||
assert conversation_input.agent_id == "conversation.home_assistant"
|
assert conversation_input.agent_id == "conversation.home_assistant"
|
||||||
assert conversation_input.device_id == "test device id"
|
assert conversation_input.device_id == "test device id"
|
||||||
|
assert conversation_input.extra_system_prompt == "test extra prompt"
|
||||||
|
@ -88,6 +88,7 @@ async def test_if_fires_on_event(
|
|||||||
"device_id": None,
|
"device_id": None,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
"text": "Ha ha ha",
|
"text": "Ha ha ha",
|
||||||
|
"extra_system_prompt": None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,6 +236,7 @@ async def test_response_same_sentence(
|
|||||||
"device_id": None,
|
"device_id": None,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
"text": "test sentence",
|
"text": "test sentence",
|
||||||
|
"extra_system_prompt": None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -412,6 +414,7 @@ async def test_same_trigger_multiple_sentences(
|
|||||||
"device_id": None,
|
"device_id": None,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
"text": "hello",
|
"text": "hello",
|
||||||
|
"extra_system_prompt": None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -639,6 +642,7 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
|
|||||||
"device_id": None,
|
"device_id": None,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
"text": "play the white album by the beatles",
|
"text": "play the white album by the beatles",
|
||||||
|
"extra_system_prompt": None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user