mirror of https://github.com/home-assistant/core.git (synced 2025-04-23 16:57:53 +00:00)

Stream OpenAI messages into the chat log (#137400)

parent a526baa831
commit df307aeb6d
@@ -32,6 +32,7 @@ from .agent_manager import (
 )
 from .chat_log import (
     AssistantContent,
+    AssistantContentDeltaDict,
     ChatLog,
     Content,
     ConverseError,
@@ -65,6 +66,7 @@ __all__ = [
     "HOME_ASSISTANT_AGENT",
     "OLD_HOME_ASSISTANT_AGENT",
     "AssistantContent",
+    "AssistantContentDeltaDict",
     "ChatLog",
     "Content",
     "ConversationEntity",
@@ -3,11 +3,12 @@
 from __future__ import annotations

+import asyncio
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator, AsyncIterable, Generator
 from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass, field, replace
 import logging
 from typing import Literal, TypedDict

 import voluptuous as vol
@@ -145,6 +146,14 @@ class ToolResultContent:
 type Content = SystemContent | UserContent | AssistantContent | ToolResultContent


+class AssistantContentDeltaDict(TypedDict, total=False):
+    """Partial content to define an AssistantContent."""
+
+    role: Literal["assistant"]
+    content: str | None
+    tool_calls: list[llm.ToolInput] | None
+
+
 @dataclass
 class ChatLog:
     """Class holding the chat history of a specific conversation."""
@@ -155,6 +164,11 @@ class ChatLog:
     extra_system_prompt: str | None = None
     llm_api: llm.APIInstance | None = None

+    @property
+    def unresponded_tool_results(self) -> bool:
+        """Return if there are unresponded tool results."""
+        return self.content[-1].role == "tool_result"
+
     @callback
     def async_add_user_content(self, content: UserContent) -> None:
         """Add user content to the log."""
@@ -223,6 +237,77 @@ class ChatLog:
         self.content.append(response_content)
         yield response_content

+    async def async_add_delta_content_stream(
+        self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
+    ) -> AsyncGenerator[AssistantContent | ToolResultContent]:
+        """Stream content into the chat log.
+
+        Returns a generator with all content that was added to the chat log.
+
+        stream iterates over dictionaries with optional keys role, content and tool_calls.
+
+        When a delta contains a role key, the current message is considered complete and
+        a new message is started.
+
+        The keys content and tool_calls will be concatenated if they appear multiple times.
+        """
+        current_content = ""
+        current_tool_calls: list[llm.ToolInput] = []
+        tool_call_tasks: dict[str, asyncio.Task] = {}
+
+        async for delta in stream:
+            LOGGER.debug("Received delta: %s", delta)
+
+            # Indicates update to current message
+            if "role" not in delta:
+                if delta_content := delta.get("content"):
+                    current_content += delta_content
+                if delta_tool_calls := delta.get("tool_calls"):
+                    if self.llm_api is None:
+                        raise ValueError("No LLM API configured")
+                    current_tool_calls += delta_tool_calls
+
+                    # Start processing the tool calls as soon as we know about them
+                    for tool_call in delta_tool_calls:
+                        tool_call_tasks[tool_call.id] = self.hass.async_create_task(
+                            self.llm_api.async_call_tool(tool_call),
+                            name=f"llm_tool_{tool_call.id}",
+                        )
+                continue
+
+            # Starting a new message
+
+            if delta["role"] != "assistant":
+                raise ValueError(f"Only assistant role expected. Got {delta['role']}")
+
+            # Yield the previous message if it has content
+            if current_content or current_tool_calls:
+                content = AssistantContent(
+                    agent_id=agent_id,
+                    content=current_content or None,
+                    tool_calls=current_tool_calls or None,
+                )
+                yield content
+                async for tool_result in self.async_add_assistant_content(
+                    content, tool_call_tasks=tool_call_tasks
+                ):
+                    yield tool_result
+
+            current_content = delta.get("content") or ""
+            current_tool_calls = delta.get("tool_calls") or []
+
+        if current_content or current_tool_calls:
+            content = AssistantContent(
+                agent_id=agent_id,
+                content=current_content or None,
+                tool_calls=current_tool_calls or None,
+            )
+            yield content
+            async for tool_result in self.async_add_assistant_content(
+                content, tool_call_tasks=tool_call_tasks
+            ):
+                yield tool_result
+
     async def async_update_llm_data(
         self,
         conversing_domain: str,
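As a usage sketch (assuming a chat_log whose llm_api is configured and an agent_id string), the new method is drained with async for, and tool results produced by the scheduled tasks come back through the same generator:

    async def collect(chat_log, agent_id, deltas):
        """Gather everything one delta stream adds to the chat log."""

        async def stream():
            for delta in deltas:
                yield delta

        return [
            content  # AssistantContent and ToolResultContent items
            async for content in chat_log.async_add_delta_content_stream(
                agent_id, stream()
            )
        ]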
@@ -1,14 +1,15 @@
 """Conversation support for OpenAI."""

-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 from typing import Any, Literal, cast

 import openai
+from openai._streaming import AsyncStream
 from openai._types import NOT_GIVEN
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam,
-    ChatCompletionMessage,
+    ChatCompletionChunk,
     ChatCompletionMessageParam,
     ChatCompletionMessageToolCallParam,
     ChatCompletionToolMessageParam,
@@ -70,32 +71,6 @@ def _format_tool(
     return ChatCompletionToolParam(type="function", function=tool_spec)


-def _convert_message_to_param(
-    message: ChatCompletionMessage,
-) -> ChatCompletionMessageParam:
-    """Convert from class to TypedDict."""
-    tool_calls: list[ChatCompletionMessageToolCallParam] = []
-    if message.tool_calls:
-        tool_calls = [
-            ChatCompletionMessageToolCallParam(
-                id=tool_call.id,
-                function=Function(
-                    arguments=tool_call.function.arguments,
-                    name=tool_call.function.name,
-                ),
-                type=tool_call.type,
-            )
-            for tool_call in message.tool_calls
-        ]
-    param = ChatCompletionAssistantMessageParam(
-        role=message.role,
-        content=message.content,
-    )
-    if tool_calls:
-        param["tool_calls"] = tool_calls
-    return param
-
-
 def _convert_content_to_param(
     content: conversation.Content,
 ) -> ChatCompletionMessageParam:
@@ -135,6 +110,74 @@ def _convert_content_to_param(
     )


+async def _transform_stream(
+    result: AsyncStream[ChatCompletionChunk],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform an OpenAI delta stream into HA format."""
+    current_tool_call: dict | None = None
+
+    async for chunk in result:
+        LOGGER.debug("Received chunk: %s", chunk)
+        choice = chunk.choices[0]
+
+        if choice.finish_reason:
+            if current_tool_call:
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=current_tool_call["id"],
+                            tool_name=current_tool_call["tool_name"],
+                            tool_args=json.loads(current_tool_call["tool_args"]),
+                        )
+                    ]
+                }
+
+            break
+
+        delta = chunk.choices[0].delta
+
+        # We can yield delta messages not continuing or starting tool calls
+        if current_tool_call is None and not delta.tool_calls:
+            yield {  # type: ignore[misc]
+                key: value
+                for key in ("role", "content")
+                if (value := getattr(delta, key)) is not None
+            }
+            continue
+
+        # When doing tool calls, we should always have a tool call
+        # object or we have gotten stopped above with a finish_reason set.
+        if (
+            not delta.tool_calls
+            or not (delta_tool_call := delta.tool_calls[0])
+            or not delta_tool_call.function
+        ):
+            raise ValueError("Expected delta with tool call")
+
+        if current_tool_call and delta_tool_call.index == current_tool_call["index"]:
+            current_tool_call["tool_args"] += delta_tool_call.function.arguments or ""
+            continue
+
+        # We got tool call with new index, so we need to yield the previous
+        if current_tool_call:
+            yield {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id=current_tool_call["id"],
+                        tool_name=current_tool_call["tool_name"],
+                        tool_args=json.loads(current_tool_call["tool_args"]),
+                    )
+                ]
+            }
+
+        current_tool_call = {
+            "index": delta_tool_call.index,
+            "id": delta_tool_call.id,
+            "tool_name": delta_tool_call.function.name,
+            "tool_args": delta_tool_call.function.arguments or "",
+        }
+
+
 class OpenAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
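To trace the transformation on a toy input (illustrative, not diff content): a chunk whose delta carries role or content is yielded directly as a dict, while tool-call argument fragments are buffered in current_tool_call until the index changes or finish_reason arrives, because the JSON arguments only parse once complete:

    import json

    fragments = ['{"para', 'm1":"call1"}']  # arguments split across two chunks
    json.loads("".join(fragments))  # {'param1': 'call1'} once the call is flushed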
@@ -234,6 +277,7 @@ class OpenAIConversationEntity(
             "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
             "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
             "user": chat_log.conversation_id,
+            "stream": True,
         }

         if model.startswith("o"):
@@ -247,39 +291,21 @@ class OpenAIConversationEntity(
                 LOGGER.error("Error talking to OpenAI: %s", err)
                 raise HomeAssistantError("Error talking to OpenAI") from err

-            LOGGER.debug("Response %s", result)
-            response = result.choices[0].message
-            messages.append(_convert_message_to_param(response))
-
-            tool_calls: list[llm.ToolInput] | None = None
-            if response.tool_calls:
-                tool_calls = [
-                    llm.ToolInput(
-                        id=tool_call.id,
-                        tool_name=tool_call.function.name,
-                        tool_args=json.loads(tool_call.function.arguments),
-                    )
-                    for tool_call in response.tool_calls
-                ]
-
             messages.extend(
                 [
-                    _convert_content_to_param(tool_response)
-                    async for tool_response in chat_log.async_add_assistant_content(
-                        conversation.AssistantContent(
-                            agent_id=user_input.agent_id,
-                            content=response.content or "",
-                            tool_calls=tool_calls,
-                        )
+                    _convert_content_to_param(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_stream(result)
                     )
                 ]
             )

-            if not tool_calls:
+            if not chat_log.unresponded_tool_results:
                 break

         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response.content or "")
+        assert type(chat_log.content[-1]) is conversation.AssistantContent
+        intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
             response=intent_response, conversation_id=chat_log.conversation_id
         )
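Sketch of the resulting agent loop (paraphrased from the hunk above; MAX_TOOL_ITERATIONS and the client call are assumptions, not diff content): each iteration streams one model response into the chat log, and the loop only continues while tool results still await a follow-up reply.

    for _iteration in range(MAX_TOOL_ITERATIONS):  # hypothetical loop bound
        result = await client.chat.completions.create(**model_args)  # stream=True
        messages.extend(
            [
                _convert_content_to_param(content)
                async for content in chat_log.async_add_delta_content_stream(
                    user_input.agent_id, _transform_stream(result)
                )
            ]
        )
        if not chat_log.unresponded_tool_results:
            break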
@@ -7,6 +7,7 @@ from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
+import logging

 from homeassistant.const import EVENT_HOMEASSISTANT_STOP
 from homeassistant.core import (
@@ -27,6 +28,7 @@ DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
 DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")

 CONVERSATION_TIMEOUT = timedelta(minutes=5)
+LOGGER = logging.getLogger(__name__)

 current_session: ContextVar[ChatSession | None] = ContextVar(
     "current_session", default=None
@@ -100,6 +102,7 @@ class SessionCleanup:
         # yielding session based on it.
         for conversation_id, session in list(all_sessions.items()):
             if session.last_updated + CONVERSATION_TIMEOUT < now:
+                LOGGER.debug("Cleaning up session %s", conversation_id)
                 del all_sessions[conversation_id]
                 session.async_cleanup()

@@ -150,6 +153,7 @@ def async_get_chat_session(
         pass

     if session is None:
+        LOGGER.debug("Creating new session %s", conversation_id)
         session = ChatSession(conversation_id)

     current_session.set(session)
@@ -1,4 +1,154 @@
 # serializer version: 1
+# name: test_add_delta_content_stream[deltas0]
+  list([
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas1]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas2]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test 2',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas3]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': None,
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'mock-tool-call-id',
+          'tool_args': dict({
+            'param1': 'Test Param 1',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'role': 'tool_result',
+      'tool_call_id': 'mock-tool-call-id',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test Param 1',
+    }),
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas4]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test',
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'mock-tool-call-id',
+          'tool_args': dict({
+            'param1': 'Test Param 1',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'role': 'tool_result',
+      'tool_call_id': 'mock-tool-call-id',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test Param 1',
+    }),
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas5]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test',
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'mock-tool-call-id',
+          'tool_args': dict({
+            'param1': 'Test Param 1',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'role': 'tool_result',
+      'tool_call_id': 'mock-tool-call-id',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test Param 1',
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': 'Test 2',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+  ])
+# ---
+# name: test_add_delta_content_stream[deltas6]
+  list([
+    dict({
+      'agent_id': 'mock-agent-id',
+      'content': None,
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'mock-tool-call-id',
+          'tool_args': dict({
+            'param1': 'Test Param 1',
+          }),
+          'tool_name': 'test_tool',
+        }),
+        dict({
+          'id': 'mock-tool-call-id-2',
+          'tool_args': dict({
+            'param1': 'Test Param 2',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'role': 'tool_result',
+      'tool_call_id': 'mock-tool-call-id',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test Param 1',
+    }),
+    dict({
+      'agent_id': 'mock-agent-id',
+      'role': 'tool_result',
+      'tool_call_id': 'mock-tool-call-id-2',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test Param 2',
+    }),
+  ])
+# ---
 # name: test_template_error
   dict({
     'conversation_id': <ANY>,
@@ -282,7 +282,7 @@ async def test_extra_systen_prompt(
 @pytest.mark.parametrize(
     "prerun_tool_tasks",
     [
-        None,
+        (),
         ("mock-tool-call-id",),
         ("mock-tool-call-id", "mock-tool-call-id-2"),
     ],
@@ -290,7 +290,7 @@ async def test_extra_systen_prompt(
 async def test_tool_call(
     hass: HomeAssistant,
     mock_conversation_input: ConversationInput,
-    prerun_tool_tasks: tuple[str] | None,
+    prerun_tool_tasks: tuple[str],
 ) -> None:
     """Test using the session tool calling API."""

@@ -334,15 +334,13 @@ async def test_tool_call(
         ],
     )

-    tool_call_tasks = None
-    if prerun_tool_tasks:
-        tool_call_tasks = {
-            tool_call_id: hass.async_create_task(
-                chat_log.llm_api.async_call_tool(content.tool_calls[0]),
-                tool_call_id,
-            )
-            for tool_call_id in prerun_tool_tasks
-        }
+    tool_call_tasks = {
+        tool_call_id: hass.async_create_task(
+            chat_log.llm_api.async_call_tool(content.tool_calls[0]),
+            tool_call_id,
+        )
+        for tool_call_id in prerun_tool_tasks
+    }

     with pytest.raises(ValueError):
         chat_log.async_add_assistant_content_without_tools(content)
@@ -350,7 +348,7 @@ async def test_tool_call(
     results = [
         tool_result_content
         async for tool_result_content in chat_log.async_add_assistant_content(
-            content, tool_call_tasks=tool_call_tasks
+            content, tool_call_tasks=tool_call_tasks or None
         )
     ]

@@ -382,37 +380,36 @@ async def test_tool_call_exception(
     )
     mock_tool.async_call.side_effect = HomeAssistantError("Test error")

-    with patch(
-        "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
-    ) as mock_get_tools:
+    with (
+        patch(
+            "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
+        ) as mock_get_tools,
+        chat_session.async_get_chat_session(hass) as session,
+        async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
+    ):
         mock_get_tools.return_value = [mock_tool]

-        with (
-            chat_session.async_get_chat_session(hass) as session,
-            async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
-        ):
-            await chat_log.async_update_llm_data(
-                conversing_domain="test",
-                user_input=mock_conversation_input,
-                user_llm_hass_api="assist",
-                user_llm_prompt=None,
-            )
-            result = None
-            async for tool_result_content in chat_log.async_add_assistant_content(
-                AssistantContent(
-                    agent_id=mock_conversation_input.agent_id,
-                    content="",
-                    tool_calls=[
-                        llm.ToolInput(
-                            id="mock-tool-call-id",
-                            tool_name="test_tool",
-                            tool_args={"param1": "Test Param"},
-                        )
-                    ],
-                )
-            ):
-                assert result is None
-                result = tool_result_content
+        await chat_log.async_update_llm_data(
+            conversing_domain="test",
+            user_input=mock_conversation_input,
+            user_llm_hass_api="assist",
+            user_llm_prompt=None,
+        )
+        result = None
+        async for tool_result_content in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=mock_conversation_input.agent_id,
+                content="",
+                tool_calls=[
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param"},
+                    )
+                ],
+            )
+        ):
+            assert result is None
+            result = tool_result_content

     assert result == ToolResultContent(
         agent_id=mock_conversation_input.agent_id,
@@ -420,3 +417,188 @@ async def test_tool_call_exception(
         tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
         tool_name="test_tool",
     )
+
+
+@pytest.mark.parametrize(
+    "deltas",
+    [
+        [],
+        # With content
+        [
+            {"role": "assistant"},
+            {"content": "Test"},
+        ],
+        # With 2 content
+        [
+            {"role": "assistant"},
+            {"content": "Test"},
+            {"role": "assistant"},
+            {"content": "Test 2"},
+        ],
+        # With 1 tool call
+        [
+            {"role": "assistant"},
+            {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param 1"},
+                    )
+                ]
+            },
+        ],
+        # With content and 1 tool call
+        [
+            {"role": "assistant"},
+            {"content": "Test"},
+            {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param 1"},
+                    )
+                ]
+            },
+        ],
+        # With 2 contents and 1 tool call
+        [
+            {"role": "assistant"},
+            {"content": "Test"},
+            {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param 1"},
+                    )
+                ]
+            },
+            {"role": "assistant"},
+            {"content": "Test 2"},
+        ],
+        # With 2 tool calls
+        [
+            {"role": "assistant"},
+            {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param 1"},
+                    )
+                ]
+            },
+            {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id="mock-tool-call-id-2",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param 2"},
+                    )
+                ]
+            },
+        ],
+    ],
+)
+async def test_add_delta_content_stream(
+    hass: HomeAssistant,
+    mock_conversation_input: ConversationInput,
+    snapshot: SnapshotAssertion,
+    deltas: list[dict],
+) -> None:
+    """Test streaming deltas."""
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+
+    async def tool_call(
+        hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
+    ) -> str:
+        """Call the tool."""
+        return tool_input.tool_args["param1"]
+
+    mock_tool.async_call.side_effect = tool_call
+
+    async def stream():
+        """Yield deltas."""
+        for d in deltas:
+            yield d
+
+    with (
+        patch(
+            "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
+        ) as mock_get_tools,
+        chat_session.async_get_chat_session(hass) as session,
+        async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
+    ):
+        mock_get_tools.return_value = [mock_tool]
+        await chat_log.async_update_llm_data(
+            conversing_domain="test",
+            user_input=mock_conversation_input,
+            user_llm_hass_api="assist",
+            user_llm_prompt=None,
+        )
+
+        results = [
+            tool_result_content
+            async for tool_result_content in chat_log.async_add_delta_content_stream(
+                "mock-agent-id", stream()
+            )
+        ]
+
+    assert results == snapshot
+    assert chat_log.content[2:] == results
+
+
+async def test_add_delta_content_stream_errors(
+    hass: HomeAssistant,
+    mock_conversation_input: ConversationInput,
+) -> None:
+    """Test streaming deltas error handling."""
+
+    async def stream(deltas):
+        """Yield deltas."""
+        for d in deltas:
+            yield d
+
+    with (
+        chat_session.async_get_chat_session(hass) as session,
+        async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
+    ):
+        # Stream content without LLM API set
+        with pytest.raises(ValueError):  # noqa: PT012
+            async for _tool_result_content in chat_log.async_add_delta_content_stream(
+                "mock-agent-id",
+                stream(
+                    [
+                        {"role": "assistant"},
+                        {
+                            "tool_calls": [
+                                llm.ToolInput(
+                                    id="mock-tool-call-id",
+                                    tool_name="test_tool",
+                                    tool_args={},
+                                )
+                            ]
+                        },
+                    ]
+                ),
+            ):
+                pass
+
+        # Non assistant role
+        for role in "system", "user":
+            with pytest.raises(ValueError):  # noqa: PT012
+                async for (
+                    _tool_result_content
+                ) in chat_log.async_add_delta_content_stream(
+                    "mock-agent-id",
+                    stream([{"role": role}]),
+                ):
+                    pass
@@ -1,34 +1,64 @@
 # serializer version: 1
-# name: test_unknown_hass_api
-  dict({
-    'conversation_id': 'my-conversation-id',
-    'response': IntentResponse(
-      card=dict({
-      }),
-      error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
-      failed_results=list([
-      ]),
-      intent=None,
-      intent_targets=list([
-      ]),
-      language='en',
-      matched_states=list([
-      ]),
-      reprompt=dict({
-      }),
-      response_type=<IntentResponseType.ERROR: 'error'>,
-      speech=dict({
-        'plain': dict({
-          'extra_data': None,
-          'speech': 'Error preparing LLM API',
-        }),
-      }),
-      speech_slots=dict({
-      }),
-      success_results=list([
-      ]),
-      unmatched_states=list([
-      ]),
-    ),
-  })
+# name: test_function_call
+  list([
+    dict({
+      'content': '''
+        Current time is 16:00:00. Today's date is 2024-06-03.
+        You are a voice assistant for Home Assistant.
+        Answer questions about the world truthfully.
+        Answer in plain text. Keep it simple and to the point.
+        Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    dict({
+      'content': 'Please call the test function',
+      'role': 'user',
+    }),
+    dict({
+      'agent_id': 'conversation.openai',
+      'content': None,
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'call_call_1',
+          'tool_args': dict({
+            'param1': 'call1',
+          }),
+          'tool_name': 'test_tool',
+        }),
+        dict({
+          'id': 'call_call_2',
+          'tool_args': dict({
+            'param1': 'call2',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'conversation.openai',
+      'role': 'tool_result',
+      'tool_call_id': 'call_call_1',
+      'tool_name': 'test_tool',
+      'tool_result': 'value1',
+    }),
+    dict({
+      'agent_id': 'conversation.openai',
+      'role': 'tool_result',
+      'tool_call_id': 'call_call_2',
+      'tool_name': 'test_tool',
+      'tool_result': 'value2',
+    }),
+    dict({
+      'agent_id': 'conversation.openai',
+      'content': 'Cool',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+  ])
 # ---
@@ -1,29 +1,130 @@
 """Tests for the OpenAI integration."""

+from collections.abc import Generator
+from dataclasses import dataclass, field
 from unittest.mock import AsyncMock, patch

 from freezegun import freeze_time
 from httpx import Response
 from openai import RateLimitError
-from openai.types.chat.chat_completion import ChatCompletion, Choice
-from openai.types.chat.chat_completion_message import ChatCompletionMessage
-from openai.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-    Function,
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk,
+    Choice,
+    ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
 )
 from openai.types.completion_usage import CompletionUsage
 import voluptuous as vol
+import pytest
 from syrupy.assertion import SnapshotAssertion

 from homeassistant.components import conversation
-from homeassistant.components.conversation import trace
+from homeassistant.components.conversation import chat_log
 from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import intent, llm
+from homeassistant.helpers import chat_session, intent
 from homeassistant.setup import async_setup_component

 from tests.common import MockConfigEntry

+ASSIST_RESPONSE_FINISH = (
+    # Assistant message
+    ChatCompletionChunk(
+        id="chatcmpl-B",
+        created=1700000000,
+        model="gpt-4-1106-preview",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
+    ),
+    # Finish stream
+    ChatCompletionChunk(
+        id="chatcmpl-B",
+        created=1700000000,
+        model="gpt-4-1106-preview",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, finish_reason="stop", delta=ChoiceDelta())],
+    ),
+)
+
+
+@pytest.fixture
+def mock_create_stream() -> Generator[AsyncMock]:
+    """Mock stream response."""
+
+    async def mock_generator(stream):
+        for value in stream:
+            yield value
+
+    with patch(
+        "openai.resources.chat.completions.AsyncCompletions.create",
+        AsyncMock(),
+    ) as mock_create:
+        mock_create.side_effect = lambda **kwargs: mock_generator(
+            mock_create.return_value.pop(0)
+        )
+
+        yield mock_create
+
+
+@dataclass
+class MockChatLog(chat_log.ChatLog):
+    """Mock chat log."""
+
+    _mock_tool_results: dict = field(default_factory=dict)
+
+    def mock_tool_results(self, results: dict) -> None:
+        """Set tool results."""
+        self._mock_tool_results = results
+
+    @property
+    def llm_api(self):
+        """Return LLM API."""
+        return self._llm_api
+
+    @llm_api.setter
+    def llm_api(self, value):
+        """Set LLM API."""
+        self._llm_api = value
+
+        if not value:
+            return
+
+        async def async_call_tool(tool_input):
+            """Call tool."""
+            if tool_input.id not in self._mock_tool_results:
+                raise ValueError(f"Tool {tool_input.id} not found")
+            return self._mock_tool_results[tool_input.id]
+
+        self._llm_api.async_call_tool = async_call_tool
+
+    def latest_content(self) -> list[conversation.Content]:
+        """Return content from latest version chat log.
+
+        The chat log makes copies until it's committed. Helper to get latest content.
+        """
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, self.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session) as chat_log,
+        ):
+            return chat_log.content
+
+
+@pytest.fixture
+async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
+    """Return mock chat logs."""
+    with (
+        patch(
+            "homeassistant.components.conversation.chat_log.ChatLog",
+            MockChatLog,
+        ),
+        chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
+        conversation.async_get_chat_log(hass, session) as chat_log,
+    ):
+        chat_log.async_add_user_content(conversation.UserContent("hello"))
+        return chat_log


 async def test_entity(
     hass: HomeAssistant,
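Note how mock_create_stream scripts multiple model turns: its side_effect pops one prepared chunk sequence per AsyncCompletions.create call, so a test queues responses in order (sketch; chunk_1a and chunk_1b are hypothetical ChatCompletionChunk objects):

    mock_create_stream.return_value = [
        (chunk_1a, chunk_1b),    # consumed by the first model call
        ASSIST_RESPONSE_FINISH,  # consumed by the follow-up call
    ]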
@@ -83,348 +184,299 @@ async def test_conversation_agent(
     assert agent.supported_languages == "*"


 @patch(
     "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
 )
 async def test_function_call(
     mock_get_tools,
     hass: HomeAssistant,
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
+    mock_create_stream: AsyncMock,
+    mock_chat_log: MockChatLog,
+    snapshot: SnapshotAssertion,
 ) -> None:
     """Test function call from the assistant."""
-    agent_id = mock_config_entry_with_assist.entry_id
-    context = Context()
-
     mock_tool = AsyncMock()
     mock_tool.name = "test_tool"
     mock_tool.description = "Test function"
     mock_tool.parameters = vol.Schema(
         {vol.Optional("param1", description="Test parameters"): str}
     )
+    mock_create_stream.return_value = [
+        # Initial conversation
+        (
+            # First tool call
+            ChatCompletionChunk(
+                id="chatcmpl-A",
+                created=1700000000,
+                model="gpt-4-1106-preview",
+                object="chat.completion.chunk",
+                choices=[
+                    Choice(
+                        index=0,
+                        delta=ChoiceDelta(
+                            tool_calls=[
+                                ChoiceDeltaToolCall(
+                                    id="call_call_1",
+                                    index=0,
+                                    function=ChoiceDeltaToolCallFunction(
+                                        name="test_tool",
+                                        arguments=None,
+                                    ),
+                                )
+                            ]
+                        ),
+                    )
+                ],
+            ),
+            ChatCompletionChunk(
+                id="chatcmpl-A",
+                created=1700000000,
+                model="gpt-4-1106-preview",
+                object="chat.completion.chunk",
+                choices=[
+                    Choice(
+                        index=0,
+                        delta=ChoiceDelta(
+                            tool_calls=[
+                                ChoiceDeltaToolCall(
+                                    index=0,
+                                    function=ChoiceDeltaToolCallFunction(
+                                        name=None,
+                                        arguments='{"para',
+                                    ),
+                                )
+                            ]
+                        ),
+                    )
+                ],
+            ),
+            ChatCompletionChunk(
+                id="chatcmpl-A",
+                created=1700000000,
+                model="gpt-4-1106-preview",
+                object="chat.completion.chunk",
+                choices=[
+                    Choice(
+                        index=0,
+                        delta=ChoiceDelta(
+                            tool_calls=[
+                                ChoiceDeltaToolCall(
+                                    index=0,
+                                    function=ChoiceDeltaToolCallFunction(
+                                        name=None,
+                                        arguments='m1":"call1"}',
+                                    ),
+                                )
+                            ]
+                        ),
+                    )
+                ],
+            ),
+            # Second tool call
+            ChatCompletionChunk(
+                id="chatcmpl-A",
+                created=1700000000,
+                model="gpt-4-1106-preview",
+                object="chat.completion.chunk",
+                choices=[
+                    Choice(
+                        index=0,
+                        delta=ChoiceDelta(
+                            tool_calls=[
+                                ChoiceDeltaToolCall(
+                                    id="call_call_2",
+                                    index=1,
+                                    function=ChoiceDeltaToolCallFunction(
+                                        name="test_tool",
+                                        arguments='{"param1":"call2"}',
+                                    ),
+                                )
+                            ]
+                        ),
+                    )
+                ],
+            ),
+            # Finish stream
+            ChatCompletionChunk(
+                id="chatcmpl-A",
+                created=1700000000,
+                model="gpt-4-1106-preview",
+                object="chat.completion.chunk",
+                choices=[
+                    Choice(index=0, finish_reason="tool_calls", delta=ChoiceDelta())
+                ],
+            ),
+        ),
+        # Response after tool responses
+        ASSIST_RESPONSE_FINISH,
+    ]
+    mock_chat_log.mock_tool_results(
+        {
+            "call_call_1": "value1",
+            "call_call_2": "value2",
+        }
+    )
-    mock_tool.async_call.return_value = "Test response"

     mock_get_tools.return_value = [mock_tool]
+    with freeze_time("2024-06-03 23:00:00"):
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            "mock-conversation-id",
+            Context(),
+            agent_id="conversation.openai",
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert mock_chat_log.latest_content() == snapshot

-    def completion_result(*args, messages, **kwargs):
-        for message in messages:
-            role = message["role"] if isinstance(message, dict) else message.role
-            if role == "tool":
-                return ChatCompletion(
-                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
-                    choices=[
-                        Choice(
-                            finish_reason="stop",
-                            index=0,
-                            message=ChatCompletionMessage(
-                                content="I have successfully called the function",
-                                role="assistant",
-                                function_call=None,
-                                tool_calls=None,
-                            ),
-                        )
-                    ],
-                    created=1700000000,
-                    model="gpt-4-1106-preview",
-                    object="chat.completion",
-                    system_fingerprint=None,
-                    usage=CompletionUsage(
-                        completion_tokens=9, prompt_tokens=8, total_tokens=17
-                    ),
-                )
-
-        return ChatCompletion(
-            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
-            choices=[
-                Choice(
-                    finish_reason="tool_calls",
-                    index=0,
-                    message=ChatCompletionMessage(
-                        content=None,
-                        role="assistant",
-                        function_call=None,
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
-                                function=Function(
-                                    arguments='{"param1":"test_value"}',
-                                    name="test_tool",
-                                ),
-                                type="function",
-                            )
-                        ],
-                    ),
-                )
-            ],
-            created=1700000000,
-            model="gpt-4-1106-preview",
-            object="chat.completion",
-            system_fingerprint=None,
-            usage=CompletionUsage(
-                completion_tokens=9, prompt_tokens=8, total_tokens=17
-            ),
-        )
-
-    with (
-        patch(
-            "openai.resources.chat.completions.AsyncCompletions.create",
-            new_callable=AsyncMock,
-            side_effect=completion_result,
-        ) as mock_create,
-        freeze_time("2024-06-03 23:00:00"),
-    ):
-        result = await conversation.async_converse(
-            hass,
-            "Please call the test function",
-            None,
-            context,
-            agent_id=agent_id,
-        )
-
-    assert (
-        "Today's date is 2024-06-03."
-        in mock_create.mock_calls[1][2]["messages"][0]["content"]
-    )
-
-    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
-    assert mock_create.mock_calls[1][2]["messages"][3] == {
-        "role": "tool",
-        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
-        "content": '"Test response"',
-    }
-    mock_tool.async_call.assert_awaited_once_with(
-        hass,
-        llm.ToolInput(
-            id="call_AbCdEfGhIjKlMnOpQrStUvWx",
-            tool_name="test_tool",
-            tool_args={"param1": "test_value"},
-        ),
-        llm.LLMContext(
-            platform="openai_conversation",
-            context=context,
-            user_prompt="Please call the test function",
-            language="en",
-            assistant="conversation",
-            device_id=None,
-        ),
-    )
-
-    # Test Conversation tracing
-    traces = trace.async_get_traces()
-    assert traces
-    last_trace = traces[-1].as_dict()
-    trace_events = last_trace.get("events", [])
-    assert [event["event_type"] for event in trace_events] == [
-        trace.ConversationTraceEventType.ASYNC_PROCESS,
-        trace.ConversationTraceEventType.AGENT_DETAIL,
-        trace.ConversationTraceEventType.TOOL_CALL,
-    ]
-    # AGENT_DETAIL event contains the raw prompt passed to the model
-    detail_event = trace_events[1]
-    assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
-    assert (
-        "Today's date is 2024-06-03."
-        in trace_events[1]["data"]["messages"][0]["content"]
-    )
-    assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
-
-    # Call it again, make sure we have updated prompt
-    with (
-        patch(
-            "openai.resources.chat.completions.AsyncCompletions.create",
-            new_callable=AsyncMock,
-            side_effect=completion_result,
-        ) as mock_create,
-        freeze_time("2024-06-04 23:00:00"),
-    ):
-        result = await conversation.async_converse(
-            hass,
-            "Please call the test function",
-            None,
-            context,
-            agent_id=agent_id,
-        )
-
-    assert (
-        "Today's date is 2024-06-04."
-        in mock_create.mock_calls[1][2]["messages"][0]["content"]
-    )
-    # Test old assert message not updated
-    assert (
-        "Today's date is 2024-06-03."
-        in trace_events[1]["data"]["messages"][0]["content"]
-    )


+@pytest.mark.parametrize(
+    ("description", "messages"),
+    [
+        (
+            "Test function call started with missing arguments",
+            (
+                ChatCompletionChunk(
+                    id="chatcmpl-A",
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion.chunk",
+                    choices=[
+                        Choice(
+                            index=0,
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        id="call_call_1",
+                                        index=0,
+                                        function=ChoiceDeltaToolCallFunction(
+                                            name="test_tool",
+                                            arguments=None,
+                                        ),
+                                    )
+                                ]
+                            ),
+                        )
+                    ],
+                ),
+                ChatCompletionChunk(
+                    id="chatcmpl-B",
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion.chunk",
+                    choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
+                ),
+            ),
+        ),
+        (
+            "Test invalid JSON",
+            (
+                ChatCompletionChunk(
+                    id="chatcmpl-A",
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion.chunk",
+                    choices=[
+                        Choice(
+                            index=0,
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        id="call_call_1",
+                                        index=0,
+                                        function=ChoiceDeltaToolCallFunction(
+                                            name="test_tool",
+                                            arguments=None,
+                                        ),
+                                    )
+                                ]
+                            ),
+                        )
+                    ],
+                ),
+                ChatCompletionChunk(
+                    id="chatcmpl-A",
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion.chunk",
+                    choices=[
+                        Choice(
+                            index=0,
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        function=ChoiceDeltaToolCallFunction(
+                                            name=None,
+                                            arguments='{"para',
+                                        ),
+                                    )
+                                ]
+                            ),
+                        )
+                    ],
+                ),
+                ChatCompletionChunk(
+                    id="chatcmpl-B",
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion.chunk",
+                    choices=[
+                        Choice(
+                            index=0,
+                            delta=ChoiceDelta(content="Cool"),
+                            finish_reason="tool_calls",
+                        )
+                    ],
+                ),
+            ),
+        ),
+    ],
+)
-@patch(
-    "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
-)
-async def test_function_exception(
-    mock_get_tools,
+async def test_function_call_invalid(
     hass: HomeAssistant,
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
+    mock_create_stream: AsyncMock,
+    mock_chat_log: MockChatLog,
+    description: str,
+    messages: tuple[ChatCompletionChunk],
 ) -> None:
-    """Test function call with exception."""
-    agent_id = mock_config_entry_with_assist.entry_id
-    context = Context()
-
-    mock_tool = AsyncMock()
-    mock_tool.name = "test_tool"
-    mock_tool.description = "Test function"
-    mock_tool.parameters = vol.Schema(
-        {vol.Optional("param1", description="Test parameters"): str}
-    )
-    mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
-
-    mock_get_tools.return_value = [mock_tool]
-
-    def completion_result(*args, messages, **kwargs):
-        for message in messages:
-            role = message["role"] if isinstance(message, dict) else message.role
-            if role == "tool":
-                return ChatCompletion(
-                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
-                    choices=[
-                        Choice(
-                            finish_reason="stop",
-                            index=0,
-                            message=ChatCompletionMessage(
-                                content="There was an error calling the function",
-                                role="assistant",
-                                function_call=None,
-                                tool_calls=None,
-                            ),
-                        )
-                    ],
-                    created=1700000000,
-                    model="gpt-4-1106-preview",
-                    object="chat.completion",
-                    system_fingerprint=None,
-                    usage=CompletionUsage(
-                        completion_tokens=9, prompt_tokens=8, total_tokens=17
-                    ),
-                )
-
-        return ChatCompletion(
-            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
-            choices=[
-                Choice(
-                    finish_reason="tool_calls",
-                    index=0,
-                    message=ChatCompletionMessage(
-                        content=None,
-                        role="assistant",
-                        function_call=None,
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
-                                function=Function(
-                                    arguments='{"param1":"test_value"}',
-                                    name="test_tool",
-                                ),
-                                type="function",
-                            )
-                        ],
-                    ),
-                )
-            ],
-            created=1700000000,
-            model="gpt-4-1106-preview",
-            object="chat.completion",
-            system_fingerprint=None,
-            usage=CompletionUsage(
-                completion_tokens=9, prompt_tokens=8, total_tokens=17
-            ),
-        )
-
-    with patch(
-        "openai.resources.chat.completions.AsyncCompletions.create",
-        new_callable=AsyncMock,
-        side_effect=completion_result,
-    ) as mock_create:
-        result = await conversation.async_converse(
+    """Test function call containing invalid data."""
+    mock_create_stream.return_value = [messages]
+
+    with pytest.raises(ValueError):
+        await conversation.async_converse(
             hass,
             "Please call the test function",
-            None,
-            context,
-            agent_id=agent_id,
+            "mock-conversation-id",
+            Context(),
+            agent_id="conversation.openai",
         )
-
-    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
-    assert mock_create.mock_calls[1][2]["messages"][3] == {
-        "role": "tool",
-        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
-        "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
-    }
-    mock_tool.async_call.assert_awaited_once_with(
-        hass,
-        llm.ToolInput(
-            id="call_AbCdEfGhIjKlMnOpQrStUvWx",
-            tool_name="test_tool",
-            tool_args={"param1": "test_value"},
-        ),
-        llm.LLMContext(
-            platform="openai_conversation",
-            context=context,
-            user_prompt="Please call the test function",
-            language="en",
-            assistant="conversation",
-            device_id=None,
-        ),
-    )


 async def test_assist_api_tools_conversion(
     hass: HomeAssistant,
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
+    mock_create_stream,
 ) -> None:
     """Test that we are able to convert actual tools from Assist API."""
     for component in (
-        "intent",
-        "todo",
-        "light",
-        "shopping_list",
-        "humidifier",
-        "calendar",
-        "climate",
-        "media_player",
-        "vacuum",
+        "calendar",
+        "climate",
+        "cover",
+        "humidifier",
+        "intent",
+        "light",
+        "media_player",
+        "script",
+        "shopping_list",
+        "todo",
+        "vacuum",
+        "weather",
     ):
         assert await async_setup_component(hass, component, {})
         hass.states.async_set(f"{component}.test", "on")
        async_expose_entity(hass, "conversation", f"{component}.test", True)

-    agent_id = mock_config_entry_with_assist.entry_id
-    with patch(
-        "openai.resources.chat.completions.AsyncCompletions.create",
-        new_callable=AsyncMock,
-        return_value=ChatCompletion(
-            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
-            choices=[
-                Choice(
-                    finish_reason="stop",
-                    index=0,
-                    message=ChatCompletionMessage(
-                        content="Hello, how can I help you?",
-                        role="assistant",
-                        function_call=None,
-                        tool_calls=None,
-                    ),
-                )
-            ],
-            created=1700000000,
-            model="gpt-3.5-turbo-0613",
-            object="chat.completion",
-            system_fingerprint=None,
-            usage=CompletionUsage(
-                completion_tokens=9, prompt_tokens=8, total_tokens=17
-            ),
-        ),
-    ) as mock_create:
-        await conversation.async_converse(
-            hass, "hello", None, Context(), agent_id=agent_id
-        )
+    mock_create_stream.return_value = [ASSIST_RESPONSE_FINISH]

-    tools = mock_create.mock_calls[0][2]["tools"]
+    await conversation.async_converse(
+        hass, "hello", None, Context(), agent_id="conversation.openai"
+    )
+
+    tools = mock_create_stream.mock_calls[0][2]["tools"]
     assert tools