From 649319f4eed68e7172c2c867d415ac307f3c6dee Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Feb 2025 15:27:55 -0500 Subject: [PATCH] Introduce async_add_assistant_content to conversation chat log (#137273) introduce async_add_assistant_content_without_tools to conversation chat log --- .../components/assist_pipeline/pipeline.py | 5 +-- .../components/assist_satellite/entity.py | 5 +-- .../components/conversation/chat_log.py | 9 ++++ .../components/conversation/default_agent.py | 5 +-- .../components/conversation/test_chat_log.py | 42 ++++++++++--------- 5 files changed, 37 insertions(+), 29 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 262f4c59687..94e2b04d7ae 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1106,13 +1106,12 @@ class PipelineRun: speech: str = intent_response.speech.get("plain", {}).get( "speech", "" ) - async for _ in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( conversation.AssistantContent( agent_id=agent_id, content=speech, ) - ): - pass + ) conversation_result = conversation.ConversationResult( response=intent_response, conversation_id=session.conversation_id, diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 902cf731a5d..e43abb4539c 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -265,12 +265,11 @@ class AssistSatelliteEntity(entity.Entity): self._conversation_id = session.conversation_id if start_message: - async for _tool_response in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( conversation.AssistantContent( agent_id=self.entity_id, content=start_message ) - ): - pass # no tool responses. + ) try: await self.async_start_conversation(announcement) diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index d053d114a11..53e248d0a98 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -143,6 +143,15 @@ class ChatLog: """Add user content to the log.""" self.content.append(content) + @callback + def async_add_assistant_content_without_tools( + self, content: AssistantContent + ) -> None: + """Add assistant content to the log.""" + if content.tool_calls is not None: + raise ValueError("Tool calls not allowed") + self.content.append(content) + async def async_add_assistant_content( self, content: AssistantContent ) -> AsyncGenerator[ToolResultContent]: diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 5e1709c0404..bd7450e5a0f 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -379,13 +379,12 @@ class DefaultAgent(ConversationEntity): ) speech: str = response.speech.get("plain", {}).get("speech", "") - async for _tool_result in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( AssistantContent( agent_id=user_input.agent_id, # type: ignore[arg-type] content=speech, ) - ): - pass + ) return ConversationResult( response=response, conversation_id=session.conversation_id diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index a37d4408756..c22a90e6928 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -56,13 +56,12 @@ async def test_cleanup( ): conversation_id = session.conversation_id # Add message so it persists - async for _tool_result in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( AssistantContent( agent_id="mock-agent-id", content="Hey!", ) - ): - pytest.fail("should not reach here") + ) assert conversation_id in hass.data[DATA_CHAT_HISTORY] @@ -210,13 +209,12 @@ async def test_extra_systen_prompt( user_llm_hass_api=None, user_llm_prompt=None, ) - async for _tool_result in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( AssistantContent( agent_id="mock-agent-id", content="Hey!", ) - ): - pytest.fail("should not reach here") + ) assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.content[0].content.endswith(extra_system_prompt) @@ -252,13 +250,12 @@ async def test_extra_systen_prompt( user_llm_hass_api=None, user_llm_prompt=None, ) - async for _tool_result in chat_log.async_add_assistant_content( + chat_log.async_add_assistant_content_without_tools( AssistantContent( agent_id="mock-agent-id", content="Hey!", ) - ): - pytest.fail("should not reach here") + ) assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.content[0].content.endswith(extra_system_prompt2) @@ -311,19 +308,24 @@ async def test_tool_call( user_llm_hass_api="assist", user_llm_prompt=None, ) + content = AssistantContent( + agent_id=mock_conversation_input.agent_id, + content="", + tool_calls=[ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ], + ) + + with pytest.raises(ValueError): + chat_log.async_add_assistant_content_without_tools(content) + result = None async for tool_result_content in chat_log.async_add_assistant_content( - AssistantContent( - agent_id=mock_conversation_input.agent_id, - content="", - tool_calls=[ - llm.ToolInput( - id="mock-tool-call-id", - tool_name="test_tool", - tool_args={"param1": "Test Param"}, - ) - ], - ) + content ): assert result is None result = tool_result_content