Track when an LLM expects to continue a conversation (#139810)

* Track when an LLM expects to continue a conversation

* Strip content

* Address comments
Paulus Schoutsen 2025-03-06 22:52:29 -05:00 committed by GitHub
parent 3dd1fadc7d
commit d47481a30e
6 changed files with 59 additions and 4 deletions
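
For orientation: the new continue_conversation flag travels from the chat log onto conversation.ConversationResult, where whatever consumes the result (for example a voice pipeline deciding whether to keep listening) can read it. A minimal, hypothetical consumer sketch follows; _handle_result and the print calls are illustrative stand-ins and are not part of this commit.

from homeassistant.components import conversation


async def _handle_result(result: conversation.ConversationResult) -> None:
    """Hypothetical consumer of a ConversationResult (illustrative only)."""
    # Deliver the agent's reply (a real pipeline would hand this to TTS).
    speech = result.response.speech.get("plain", {}).get("speech", "")
    print(speech)

    # New in this commit: the flag mirrors ChatLog.continue_conversation,
    # i.e. whether the agent's last message ended in a question.
    if result.continue_conversation:
        print("(agent expects a reply: keep the conversation open)")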

View File

@@ -305,7 +305,9 @@ class AnthropicConversationEntity(
         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_content.content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )

     async def _async_entry_update_listener(

View File

@@ -183,6 +183,25 @@ class ChatLog:
     llm_api: llm.APIInstance | None = None
     delta_listener: Callable[[ChatLog, dict], None] | None = None

+    @property
+    def continue_conversation(self) -> bool:
+        """Return whether the conversation should continue."""
+        if not self.content:
+            return False
+
+        last_msg = self.content[-1]
+
+        return (
+            last_msg.role == "assistant"
+            and last_msg.content is not None  # type: ignore[union-attr]
+            and last_msg.content.strip().endswith(  # type: ignore[union-attr]
+                (
+                    "?",
+                    ";",  # Greek question mark
+                )
+            )
+        )
+
     @property
     def unresponded_tool_results(self) -> bool:
         """Return if there are unresponded tool results."""

View File

@@ -459,7 +459,9 @@ class GoogleGenerativeAIConversationEntity(
             " ".join([part.text.strip() for part in response_parts if part.text])
         )
         return conversation.ConversationResult(
-            response=response, conversation_id=chat_log.conversation_id
+            response=response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )

     async def _async_entry_update_listener(

View File

@@ -292,7 +292,9 @@ class OllamaConversationEntity(
         )
         intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )

     def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:

View File

@@ -310,7 +310,9 @@ class OpenAIConversationEntity(
         assert type(chat_log.content[-1]) is conversation.AssistantContent
         intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )

     async def _async_entry_update_listener(

View File

@@ -14,6 +14,7 @@ from homeassistant.components.conversation import (
     ConversationInput,
     ConverseError,
     ToolResultContent,
+    UserContent,
     async_get_chat_log,
 )
 from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
@@ -643,3 +644,30 @@
         assert len(chat_log.content) == 2
         assert chat_log.content[1].role == "user"
         assert chat_log.content[1].content == mock_conversation_input.text
+
+
+async def test_chat_log_continue_conversation(
+    hass: HomeAssistant,
+    mock_conversation_input: ConversationInput,
+) -> None:
+    """Test continue conversation."""
+    with (
+        chat_session.async_get_chat_session(hass) as session,
+        async_get_chat_log(hass, session) as chat_log,
+    ):
+        assert chat_log.continue_conversation is False
+        chat_log.async_add_user_content(UserContent(mock_conversation_input.text))
+        assert chat_log.continue_conversation is False
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id="mock-agent-id",
+                content="Hey? ",
+            )
+        )
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id="mock-agent-id",
+                content="Ποιο είναι το αγαπημένο σου χρώμα στα ελληνικά;",
+            )
+        )
+        assert chat_log.continue_conversation is True