mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Introduce async_add_assistant_content to conversation chat log (#137273)
introduce async_add_assistant_content_without_tools to conversation chat log
This commit is contained in:
parent
282560acf8
commit
649319f4ee
@ -1106,13 +1106,12 @@ class PipelineRun:
|
|||||||
speech: str = intent_response.speech.get("plain", {}).get(
|
speech: str = intent_response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
async for _ in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
conversation.AssistantContent(
|
conversation.AssistantContent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
content=speech,
|
content=speech,
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pass
|
|
||||||
conversation_result = conversation.ConversationResult(
|
conversation_result = conversation.ConversationResult(
|
||||||
response=intent_response,
|
response=intent_response,
|
||||||
conversation_id=session.conversation_id,
|
conversation_id=session.conversation_id,
|
||||||
|
@ -265,12 +265,11 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
self._conversation_id = session.conversation_id
|
self._conversation_id = session.conversation_id
|
||||||
|
|
||||||
if start_message:
|
if start_message:
|
||||||
async for _tool_response in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
conversation.AssistantContent(
|
conversation.AssistantContent(
|
||||||
agent_id=self.entity_id, content=start_message
|
agent_id=self.entity_id, content=start_message
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pass # no tool responses.
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.async_start_conversation(announcement)
|
await self.async_start_conversation(announcement)
|
||||||
|
@ -143,6 +143,15 @@ class ChatLog:
|
|||||||
"""Add user content to the log."""
|
"""Add user content to the log."""
|
||||||
self.content.append(content)
|
self.content.append(content)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_assistant_content_without_tools(
|
||||||
|
self, content: AssistantContent
|
||||||
|
) -> None:
|
||||||
|
"""Add assistant content to the log."""
|
||||||
|
if content.tool_calls is not None:
|
||||||
|
raise ValueError("Tool calls not allowed")
|
||||||
|
self.content.append(content)
|
||||||
|
|
||||||
async def async_add_assistant_content(
|
async def async_add_assistant_content(
|
||||||
self, content: AssistantContent
|
self, content: AssistantContent
|
||||||
) -> AsyncGenerator[ToolResultContent]:
|
) -> AsyncGenerator[ToolResultContent]:
|
||||||
|
@ -379,13 +379,12 @@ class DefaultAgent(ConversationEntity):
|
|||||||
)
|
)
|
||||||
|
|
||||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||||
async for _tool_result in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
agent_id=user_input.agent_id, # type: ignore[arg-type]
|
agent_id=user_input.agent_id, # type: ignore[arg-type]
|
||||||
content=speech,
|
content=speech,
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pass
|
|
||||||
|
|
||||||
return ConversationResult(
|
return ConversationResult(
|
||||||
response=response, conversation_id=session.conversation_id
|
response=response, conversation_id=session.conversation_id
|
||||||
|
@ -56,13 +56,12 @@ async def test_cleanup(
|
|||||||
):
|
):
|
||||||
conversation_id = session.conversation_id
|
conversation_id = session.conversation_id
|
||||||
# Add message so it persists
|
# Add message so it persists
|
||||||
async for _tool_result in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pytest.fail("should not reach here")
|
|
||||||
|
|
||||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||||
|
|
||||||
@ -210,13 +209,12 @@ async def test_extra_systen_prompt(
|
|||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
async for _tool_result in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pytest.fail("should not reach here")
|
|
||||||
|
|
||||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||||
assert chat_log.content[0].content.endswith(extra_system_prompt)
|
assert chat_log.content[0].content.endswith(extra_system_prompt)
|
||||||
@ -252,13 +250,12 @@ async def test_extra_systen_prompt(
|
|||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
async for _tool_result in chat_log.async_add_assistant_content(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
pytest.fail("should not reach here")
|
|
||||||
|
|
||||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||||
@ -311,19 +308,24 @@ async def test_tool_call(
|
|||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
content = AssistantContent(
|
||||||
|
agent_id=mock_conversation_input.agent_id,
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat_log.async_add_assistant_content_without_tools(content)
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||||
AssistantContent(
|
content
|
||||||
agent_id=mock_conversation_input.agent_id,
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
llm.ToolInput(
|
|
||||||
id="mock-tool-call-id",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={"param1": "Test Param"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
assert result is None
|
assert result is None
|
||||||
result = tool_result_content
|
result = tool_result_content
|
||||||
|
Loading…
x
Reference in New Issue
Block a user