Introduce async_add_assistant_content to conversation chat log (#137273)

introduce async_add_assistant_content_without_tools to conversation chat log
This commit is contained in:
Paulus Schoutsen 2025-02-03 15:27:55 -05:00 committed by GitHub
parent 282560acf8
commit 649319f4ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 29 deletions

View File

@ -1106,13 +1106,12 @@ class PipelineRun:
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
async for _ in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent(
agent_id=agent_id,
content=speech,
)
):
pass
)
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=session.conversation_id,

View File

@ -265,12 +265,11 @@ class AssistSatelliteEntity(entity.Entity):
self._conversation_id = session.conversation_id
if start_message:
async for _tool_response in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent(
agent_id=self.entity_id, content=start_message
)
):
pass # no tool responses.
)
try:
await self.async_start_conversation(announcement)

View File

@ -143,6 +143,15 @@ class ChatLog:
"""Add user content to the log."""
self.content.append(content)
@callback
def async_add_assistant_content_without_tools(
self, content: AssistantContent
) -> None:
"""Add assistant content to the log."""
if content.tool_calls is not None:
raise ValueError("Tool calls not allowed")
self.content.append(content)
async def async_add_assistant_content(
self, content: AssistantContent
) -> AsyncGenerator[ToolResultContent]:

View File

@ -379,13 +379,12 @@ class DefaultAgent(ConversationEntity):
)
speech: str = response.speech.get("plain", {}).get("speech", "")
async for _tool_result in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id=user_input.agent_id, # type: ignore[arg-type]
content=speech,
)
):
pass
)
return ConversationResult(
response=response, conversation_id=session.conversation_id

View File

@ -56,13 +56,12 @@ async def test_cleanup(
):
conversation_id = session.conversation_id
# Add message so it persists
async for _tool_result in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
):
pytest.fail("should not reach here")
)
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@ -210,13 +209,12 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
async for _tool_result in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
):
pytest.fail("should not reach here")
)
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.content[0].content.endswith(extra_system_prompt)
@ -252,13 +250,12 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
async for _tool_result in chat_log.async_add_assistant_content(
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
):
pytest.fail("should not reach here")
)
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.content[0].content.endswith(extra_system_prompt2)
@ -311,19 +308,24 @@ async def test_tool_call(
user_llm_hass_api="assist",
user_llm_prompt=None,
)
content = AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
with pytest.raises(ValueError):
chat_log.async_add_assistant_content_without_tools(content)
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
content
):
assert result is None
result = tool_result_content