mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Introduce async_add_assistant_content to conversation chat log (#137273)
introduce async_add_assistant_content_without_tools to conversation chat log
This commit is contained in:
parent
282560acf8
commit
649319f4ee
@ -1106,13 +1106,12 @@ class PipelineRun:
|
||||
speech: str = intent_response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
async for _ in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content=speech,
|
||||
)
|
||||
):
|
||||
pass
|
||||
)
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=session.conversation_id,
|
||||
|
@ -265,12 +265,11 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
self._conversation_id = session.conversation_id
|
||||
|
||||
if start_message:
|
||||
async for _tool_response in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=self.entity_id, content=start_message
|
||||
)
|
||||
):
|
||||
pass # no tool responses.
|
||||
)
|
||||
|
||||
try:
|
||||
await self.async_start_conversation(announcement)
|
||||
|
@ -143,6 +143,15 @@ class ChatLog:
|
||||
"""Add user content to the log."""
|
||||
self.content.append(content)
|
||||
|
||||
@callback
|
||||
def async_add_assistant_content_without_tools(
|
||||
self, content: AssistantContent
|
||||
) -> None:
|
||||
"""Add assistant content to the log."""
|
||||
if content.tool_calls is not None:
|
||||
raise ValueError("Tool calls not allowed")
|
||||
self.content.append(content)
|
||||
|
||||
async def async_add_assistant_content(
|
||||
self, content: AssistantContent
|
||||
) -> AsyncGenerator[ToolResultContent]:
|
||||
|
@ -379,13 +379,12 @@ class DefaultAgent(ConversationEntity):
|
||||
)
|
||||
|
||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id=user_input.agent_id, # type: ignore[arg-type]
|
||||
content=speech,
|
||||
)
|
||||
):
|
||||
pass
|
||||
)
|
||||
|
||||
return ConversationResult(
|
||||
response=response, conversation_id=session.conversation_id
|
||||
|
@ -56,13 +56,12 @@ async def test_cleanup(
|
||||
):
|
||||
conversation_id = session.conversation_id
|
||||
# Add message so it persists
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
)
|
||||
|
||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
@ -210,13 +209,12 @@ async def test_extra_systen_prompt(
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt)
|
||||
@ -252,13 +250,12 @@ async def test_extra_systen_prompt(
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||
@ -311,19 +308,24 @@ async def test_tool_call(
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
content = AssistantContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_assistant_content_without_tools(content)
|
||||
|
||||
result = None
|
||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
],
|
||||
)
|
||||
content
|
||||
):
|
||||
assert result is None
|
||||
result = tool_result_content
|
||||
|
Loading…
x
Reference in New Issue
Block a user