Introduce async_add_assistant_content to conversation chat log (#137273)

introduce async_add_assistant_content_without_tools to conversation chat log
This commit is contained in:
Paulus Schoutsen 2025-02-03 15:27:55 -05:00 committed by GitHub
parent 282560acf8
commit 649319f4ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 29 deletions

View File

@ -1106,13 +1106,12 @@ class PipelineRun:
speech: str = intent_response.speech.get("plain", {}).get( speech: str = intent_response.speech.get("plain", {}).get(
"speech", "" "speech", ""
) )
async for _ in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent( conversation.AssistantContent(
agent_id=agent_id, agent_id=agent_id,
content=speech, content=speech,
) )
): )
pass
conversation_result = conversation.ConversationResult( conversation_result = conversation.ConversationResult(
response=intent_response, response=intent_response,
conversation_id=session.conversation_id, conversation_id=session.conversation_id,

View File

@ -265,12 +265,11 @@ class AssistSatelliteEntity(entity.Entity):
self._conversation_id = session.conversation_id self._conversation_id = session.conversation_id
if start_message: if start_message:
async for _tool_response in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent( conversation.AssistantContent(
agent_id=self.entity_id, content=start_message agent_id=self.entity_id, content=start_message
) )
): )
pass # no tool responses.
try: try:
await self.async_start_conversation(announcement) await self.async_start_conversation(announcement)

View File

@ -143,6 +143,15 @@ class ChatLog:
"""Add user content to the log.""" """Add user content to the log."""
self.content.append(content) self.content.append(content)
@callback
def async_add_assistant_content_without_tools(
self, content: AssistantContent
) -> None:
"""Add assistant content to the log."""
if content.tool_calls is not None:
raise ValueError("Tool calls not allowed")
self.content.append(content)
async def async_add_assistant_content( async def async_add_assistant_content(
self, content: AssistantContent self, content: AssistantContent
) -> AsyncGenerator[ToolResultContent]: ) -> AsyncGenerator[ToolResultContent]:

View File

@ -379,13 +379,12 @@ class DefaultAgent(ConversationEntity):
) )
speech: str = response.speech.get("plain", {}).get("speech", "") speech: str = response.speech.get("plain", {}).get("speech", "")
async for _tool_result in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
agent_id=user_input.agent_id, # type: ignore[arg-type] agent_id=user_input.agent_id, # type: ignore[arg-type]
content=speech, content=speech,
) )
): )
pass
return ConversationResult( return ConversationResult(
response=response, conversation_id=session.conversation_id response=response, conversation_id=session.conversation_id

View File

@ -56,13 +56,12 @@ async def test_cleanup(
): ):
conversation_id = session.conversation_id conversation_id = session.conversation_id
# Add message so it persists # Add message so it persists
async for _tool_result in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
): )
pytest.fail("should not reach here")
assert conversation_id in hass.data[DATA_CHAT_HISTORY] assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@ -210,13 +209,12 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
async for _tool_result in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
): )
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.content[0].content.endswith(extra_system_prompt) assert chat_log.content[0].content.endswith(extra_system_prompt)
@ -252,13 +250,12 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
async for _tool_result in chat_log.async_add_assistant_content( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
): )
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.content[0].content.endswith(extra_system_prompt2) assert chat_log.content[0].content.endswith(extra_system_prompt2)
@ -311,19 +308,24 @@ async def test_tool_call(
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
content = AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
with pytest.raises(ValueError):
chat_log.async_add_assistant_content_without_tools(content)
result = None result = None
async for tool_result_content in chat_log.async_add_assistant_content( async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent( content
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
): ):
assert result is None assert result is None
result = tool_result_content result = tool_result_content