mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Stream OpenAI messages into the chat log (#137400)
This commit is contained in:
parent
a526baa831
commit
df307aeb6d
@ -32,6 +32,7 @@ from .agent_manager import (
|
|||||||
)
|
)
|
||||||
from .chat_log import (
|
from .chat_log import (
|
||||||
AssistantContent,
|
AssistantContent,
|
||||||
|
AssistantContentDeltaDict,
|
||||||
ChatLog,
|
ChatLog,
|
||||||
Content,
|
Content,
|
||||||
ConverseError,
|
ConverseError,
|
||||||
@ -65,6 +66,7 @@ __all__ = [
|
|||||||
"HOME_ASSISTANT_AGENT",
|
"HOME_ASSISTANT_AGENT",
|
||||||
"OLD_HOME_ASSISTANT_AGENT",
|
"OLD_HOME_ASSISTANT_AGENT",
|
||||||
"AssistantContent",
|
"AssistantContent",
|
||||||
|
"AssistantContentDeltaDict",
|
||||||
"ChatLog",
|
"ChatLog",
|
||||||
"Content",
|
"Content",
|
||||||
"ConversationEntity",
|
"ConversationEntity",
|
||||||
|
@ -3,11 +3,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, AsyncIterable, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -145,6 +146,14 @@ class ToolResultContent:
|
|||||||
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
|
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantContentDeltaDict(TypedDict, total=False):
|
||||||
|
"""Partial content to define an AssistantContent."""
|
||||||
|
|
||||||
|
role: Literal["assistant"]
|
||||||
|
content: str | None
|
||||||
|
tool_calls: list[llm.ToolInput] | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatLog:
|
class ChatLog:
|
||||||
"""Class holding the chat history of a specific conversation."""
|
"""Class holding the chat history of a specific conversation."""
|
||||||
@ -155,6 +164,11 @@ class ChatLog:
|
|||||||
extra_system_prompt: str | None = None
|
extra_system_prompt: str | None = None
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unresponded_tool_results(self) -> bool:
|
||||||
|
"""Return if there are unresponded tool results."""
|
||||||
|
return self.content[-1].role == "tool_result"
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_add_user_content(self, content: UserContent) -> None:
|
def async_add_user_content(self, content: UserContent) -> None:
|
||||||
"""Add user content to the log."""
|
"""Add user content to the log."""
|
||||||
@ -223,6 +237,77 @@ class ChatLog:
|
|||||||
self.content.append(response_content)
|
self.content.append(response_content)
|
||||||
yield response_content
|
yield response_content
|
||||||
|
|
||||||
|
async def async_add_delta_content_stream(
|
||||||
|
self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
|
||||||
|
) -> AsyncGenerator[AssistantContent | ToolResultContent]:
|
||||||
|
"""Stream content into the chat log.
|
||||||
|
|
||||||
|
Returns a generator with all content that was added to the chat log.
|
||||||
|
|
||||||
|
stream iterates over dictionaries with optional keys role, content and tool_calls.
|
||||||
|
|
||||||
|
When a delta contains a role key, the current message is considered complete and
|
||||||
|
a new message is started.
|
||||||
|
|
||||||
|
The keys content and tool_calls will be concatenated if they appear multiple times.
|
||||||
|
"""
|
||||||
|
current_content = ""
|
||||||
|
current_tool_calls: list[llm.ToolInput] = []
|
||||||
|
tool_call_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
async for delta in stream:
|
||||||
|
LOGGER.debug("Received delta: %s", delta)
|
||||||
|
|
||||||
|
# Indicates update to current message
|
||||||
|
if "role" not in delta:
|
||||||
|
if delta_content := delta.get("content"):
|
||||||
|
current_content += delta_content
|
||||||
|
if delta_tool_calls := delta.get("tool_calls"):
|
||||||
|
if self.llm_api is None:
|
||||||
|
raise ValueError("No LLM API configured")
|
||||||
|
current_tool_calls += delta_tool_calls
|
||||||
|
|
||||||
|
# Start processing the tool calls as soon as we know about them
|
||||||
|
for tool_call in delta_tool_calls:
|
||||||
|
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
|
||||||
|
self.llm_api.async_call_tool(tool_call),
|
||||||
|
name=f"llm_tool_{tool_call.id}",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Starting a new message
|
||||||
|
|
||||||
|
if delta["role"] != "assistant":
|
||||||
|
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
|
||||||
|
|
||||||
|
# Yield the previous message if it has content
|
||||||
|
if current_content or current_tool_calls:
|
||||||
|
content = AssistantContent(
|
||||||
|
agent_id=agent_id,
|
||||||
|
content=current_content or None,
|
||||||
|
tool_calls=current_tool_calls or None,
|
||||||
|
)
|
||||||
|
yield content
|
||||||
|
async for tool_result in self.async_add_assistant_content(
|
||||||
|
content, tool_call_tasks=tool_call_tasks
|
||||||
|
):
|
||||||
|
yield tool_result
|
||||||
|
|
||||||
|
current_content = delta.get("content") or ""
|
||||||
|
current_tool_calls = delta.get("tool_calls") or []
|
||||||
|
|
||||||
|
if current_content or current_tool_calls:
|
||||||
|
content = AssistantContent(
|
||||||
|
agent_id=agent_id,
|
||||||
|
content=current_content or None,
|
||||||
|
tool_calls=current_tool_calls or None,
|
||||||
|
)
|
||||||
|
yield content
|
||||||
|
async for tool_result in self.async_add_assistant_content(
|
||||||
|
content, tool_call_tasks=tool_call_tasks
|
||||||
|
):
|
||||||
|
yield tool_result
|
||||||
|
|
||||||
async def async_update_llm_data(
|
async def async_update_llm_data(
|
||||||
self,
|
self,
|
||||||
conversing_domain: str,
|
conversing_domain: str,
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from openai._streaming import AsyncStream
|
||||||
from openai._types import NOT_GIVEN
|
from openai._types import NOT_GIVEN
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
ChatCompletionMessage,
|
ChatCompletionChunk,
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ChatCompletionMessageToolCallParam,
|
ChatCompletionMessageToolCallParam,
|
||||||
ChatCompletionToolMessageParam,
|
ChatCompletionToolMessageParam,
|
||||||
@ -70,32 +71,6 @@ def _format_tool(
|
|||||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_param(
|
|
||||||
message: ChatCompletionMessage,
|
|
||||||
) -> ChatCompletionMessageParam:
|
|
||||||
"""Convert from class to TypedDict."""
|
|
||||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
|
||||||
if message.tool_calls:
|
|
||||||
tool_calls = [
|
|
||||||
ChatCompletionMessageToolCallParam(
|
|
||||||
id=tool_call.id,
|
|
||||||
function=Function(
|
|
||||||
arguments=tool_call.function.arguments,
|
|
||||||
name=tool_call.function.name,
|
|
||||||
),
|
|
||||||
type=tool_call.type,
|
|
||||||
)
|
|
||||||
for tool_call in message.tool_calls
|
|
||||||
]
|
|
||||||
param = ChatCompletionAssistantMessageParam(
|
|
||||||
role=message.role,
|
|
||||||
content=message.content,
|
|
||||||
)
|
|
||||||
if tool_calls:
|
|
||||||
param["tool_calls"] = tool_calls
|
|
||||||
return param
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_content_to_param(
|
def _convert_content_to_param(
|
||||||
content: conversation.Content,
|
content: conversation.Content,
|
||||||
) -> ChatCompletionMessageParam:
|
) -> ChatCompletionMessageParam:
|
||||||
@ -135,6 +110,74 @@ def _convert_content_to_param(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _transform_stream(
|
||||||
|
result: AsyncStream[ChatCompletionChunk],
|
||||||
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
|
"""Transform an OpenAI delta stream into HA format."""
|
||||||
|
current_tool_call: dict | None = None
|
||||||
|
|
||||||
|
async for chunk in result:
|
||||||
|
LOGGER.debug("Received chunk: %s", chunk)
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
|
||||||
|
if choice.finish_reason:
|
||||||
|
if current_tool_call:
|
||||||
|
yield {
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id=current_tool_call["id"],
|
||||||
|
tool_name=current_tool_call["tool_name"],
|
||||||
|
tool_args=json.loads(current_tool_call["tool_args"]),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
|
||||||
|
# We can yield delta messages not continuing or starting tool calls
|
||||||
|
if current_tool_call is None and not delta.tool_calls:
|
||||||
|
yield { # type: ignore[misc]
|
||||||
|
key: value
|
||||||
|
for key in ("role", "content")
|
||||||
|
if (value := getattr(delta, key)) is not None
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
# When doing tool calls, we should always have a tool call
|
||||||
|
# object or we have gotten stopped above with a finish_reason set.
|
||||||
|
if (
|
||||||
|
not delta.tool_calls
|
||||||
|
or not (delta_tool_call := delta.tool_calls[0])
|
||||||
|
or not delta_tool_call.function
|
||||||
|
):
|
||||||
|
raise ValueError("Expected delta with tool call")
|
||||||
|
|
||||||
|
if current_tool_call and delta_tool_call.index == current_tool_call["index"]:
|
||||||
|
current_tool_call["tool_args"] += delta_tool_call.function.arguments or ""
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We got tool call with new index, so we need to yield the previous
|
||||||
|
if current_tool_call:
|
||||||
|
yield {
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id=current_tool_call["id"],
|
||||||
|
tool_name=current_tool_call["tool_name"],
|
||||||
|
tool_args=json.loads(current_tool_call["tool_args"]),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
current_tool_call = {
|
||||||
|
"index": delta_tool_call.index,
|
||||||
|
"id": delta_tool_call.id,
|
||||||
|
"tool_name": delta_tool_call.function.name,
|
||||||
|
"tool_args": delta_tool_call.function.arguments or "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConversationEntity(
|
class OpenAIConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
@ -234,6 +277,7 @@ class OpenAIConversationEntity(
|
|||||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
"user": chat_log.conversation_id,
|
"user": chat_log.conversation_id,
|
||||||
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.startswith("o"):
|
if model.startswith("o"):
|
||||||
@ -247,39 +291,21 @@ class OpenAIConversationEntity(
|
|||||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||||
|
|
||||||
LOGGER.debug("Response %s", result)
|
|
||||||
response = result.choices[0].message
|
|
||||||
messages.append(_convert_message_to_param(response))
|
|
||||||
|
|
||||||
tool_calls: list[llm.ToolInput] | None = None
|
|
||||||
if response.tool_calls:
|
|
||||||
tool_calls = [
|
|
||||||
llm.ToolInput(
|
|
||||||
id=tool_call.id,
|
|
||||||
tool_name=tool_call.function.name,
|
|
||||||
tool_args=json.loads(tool_call.function.arguments),
|
|
||||||
)
|
|
||||||
for tool_call in response.tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
messages.extend(
|
messages.extend(
|
||||||
[
|
[
|
||||||
_convert_content_to_param(tool_response)
|
_convert_content_to_param(content)
|
||||||
async for tool_response in chat_log.async_add_assistant_content(
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
conversation.AssistantContent(
|
user_input.agent_id, _transform_stream(result)
|
||||||
agent_id=user_input.agent_id,
|
|
||||||
content=response.content or "",
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if not tool_calls:
|
if not chat_log.unresponded_tool_results:
|
||||||
break
|
break
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
intent_response.async_set_speech(response.content or "")
|
assert type(chat_log.content[-1]) is conversation.AssistantContent
|
||||||
|
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=chat_log.conversation_id
|
response=intent_response, conversation_id=chat_log.conversation_id
|
||||||
)
|
)
|
||||||
|
@ -7,6 +7,7 @@ from contextlib import contextmanager
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
@ -27,6 +28,7 @@ DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
|
|||||||
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
|
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
|
||||||
|
|
||||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
current_session: ContextVar[ChatSession | None] = ContextVar(
|
current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
"current_session", default=None
|
"current_session", default=None
|
||||||
@ -100,6 +102,7 @@ class SessionCleanup:
|
|||||||
# yielding session based on it.
|
# yielding session based on it.
|
||||||
for conversation_id, session in list(all_sessions.items()):
|
for conversation_id, session in list(all_sessions.items()):
|
||||||
if session.last_updated + CONVERSATION_TIMEOUT < now:
|
if session.last_updated + CONVERSATION_TIMEOUT < now:
|
||||||
|
LOGGER.debug("Cleaning up session %s", conversation_id)
|
||||||
del all_sessions[conversation_id]
|
del all_sessions[conversation_id]
|
||||||
session.async_cleanup()
|
session.async_cleanup()
|
||||||
|
|
||||||
@ -150,6 +153,7 @@ def async_get_chat_session(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
|
LOGGER.debug("Creating new session %s", conversation_id)
|
||||||
session = ChatSession(conversation_id)
|
session = ChatSession(conversation_id)
|
||||||
|
|
||||||
current_session.set(session)
|
current_session.set(session)
|
||||||
|
@ -1,4 +1,154 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_add_delta_content_stream[deltas0]
|
||||||
|
list([
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas1]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas2]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test 2',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas3]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'mock-tool-call-id',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'mock-tool-call-id',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas4]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'mock-tool-call-id',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'mock-tool-call-id',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas5]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'mock-tool-call-id',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'mock-tool-call-id',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test 2',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas6]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'mock-tool-call-id',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'mock-tool-call-id-2',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'Test Param 2',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'mock-tool-call-id',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test Param 1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'mock-tool-call-id-2',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test Param 2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_template_error
|
# name: test_template_error
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': <ANY>,
|
'conversation_id': <ANY>,
|
||||||
|
@ -282,7 +282,7 @@ async def test_extra_systen_prompt(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"prerun_tool_tasks",
|
"prerun_tool_tasks",
|
||||||
[
|
[
|
||||||
None,
|
(),
|
||||||
("mock-tool-call-id",),
|
("mock-tool-call-id",),
|
||||||
("mock-tool-call-id", "mock-tool-call-id-2"),
|
("mock-tool-call-id", "mock-tool-call-id-2"),
|
||||||
],
|
],
|
||||||
@ -290,7 +290,7 @@ async def test_extra_systen_prompt(
|
|||||||
async def test_tool_call(
|
async def test_tool_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
prerun_tool_tasks: tuple[str] | None,
|
prerun_tool_tasks: tuple[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test using the session tool calling API."""
|
"""Test using the session tool calling API."""
|
||||||
|
|
||||||
@ -334,15 +334,13 @@ async def test_tool_call(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call_tasks = None
|
tool_call_tasks = {
|
||||||
if prerun_tool_tasks:
|
tool_call_id: hass.async_create_task(
|
||||||
tool_call_tasks = {
|
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
|
||||||
tool_call_id: hass.async_create_task(
|
tool_call_id,
|
||||||
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
|
)
|
||||||
tool_call_id,
|
for tool_call_id in prerun_tool_tasks
|
||||||
)
|
}
|
||||||
for tool_call_id in prerun_tool_tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_log.async_add_assistant_content_without_tools(content)
|
chat_log.async_add_assistant_content_without_tools(content)
|
||||||
@ -350,7 +348,7 @@ async def test_tool_call(
|
|||||||
results = [
|
results = [
|
||||||
tool_result_content
|
tool_result_content
|
||||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||||
content, tool_call_tasks=tool_call_tasks
|
content, tool_call_tasks=tool_call_tasks or None
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -382,37 +380,36 @@ async def test_tool_call_exception(
|
|||||||
)
|
)
|
||||||
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
|
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
patch(
|
||||||
) as mock_get_tools:
|
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||||
|
) as mock_get_tools,
|
||||||
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
):
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
with (
|
conversing_domain="test",
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
user_input=mock_conversation_input,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
user_llm_hass_api="assist",
|
||||||
):
|
user_llm_prompt=None,
|
||||||
await chat_log.async_update_llm_data(
|
)
|
||||||
conversing_domain="test",
|
result = None
|
||||||
user_input=mock_conversation_input,
|
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||||
user_llm_hass_api="assist",
|
AssistantContent(
|
||||||
user_llm_prompt=None,
|
agent_id=mock_conversation_input.agent_id,
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param"},
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
result = None
|
):
|
||||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
assert result is None
|
||||||
AssistantContent(
|
result = tool_result_content
|
||||||
agent_id=mock_conversation_input.agent_id,
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
llm.ToolInput(
|
|
||||||
id="mock-tool-call-id",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={"param1": "Test Param"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
):
|
|
||||||
assert result is None
|
|
||||||
result = tool_result_content
|
|
||||||
|
|
||||||
assert result == ToolResultContent(
|
assert result == ToolResultContent(
|
||||||
agent_id=mock_conversation_input.agent_id,
|
agent_id=mock_conversation_input.agent_id,
|
||||||
@ -420,3 +417,188 @@ async def test_tool_call_exception(
|
|||||||
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
|
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"deltas",
|
||||||
|
[
|
||||||
|
[],
|
||||||
|
# With content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test"},
|
||||||
|
],
|
||||||
|
# With 2 content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test"},
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test 2"},
|
||||||
|
],
|
||||||
|
# With 1 tool call
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param 1"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
# With content and 1 tool call
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test"},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param 1"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
# With 2 contents and 1 tool call
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test"},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param 1"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test 2"},
|
||||||
|
],
|
||||||
|
# With 2 tool calls
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param 1"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id-2",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param 2"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_add_delta_content_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: ConversationInput,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
deltas: list[dict],
|
||||||
|
) -> None:
|
||||||
|
"""Test streaming deltas."""
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def tool_call(
|
||||||
|
hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
|
||||||
|
) -> str:
|
||||||
|
"""Call the tool."""
|
||||||
|
return tool_input.tool_args["param1"]
|
||||||
|
|
||||||
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
"""Yield deltas."""
|
||||||
|
for d in deltas:
|
||||||
|
yield d
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||||
|
) as mock_get_tools,
|
||||||
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
):
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
|
conversing_domain="test",
|
||||||
|
user_input=mock_conversation_input,
|
||||||
|
user_llm_hass_api="assist",
|
||||||
|
user_llm_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [
|
||||||
|
tool_result_content
|
||||||
|
async for tool_result_content in chat_log.async_add_delta_content_stream(
|
||||||
|
"mock-agent-id", stream()
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert results == snapshot
|
||||||
|
assert chat_log.content[2:] == results
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_delta_content_stream_errors(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: ConversationInput,
|
||||||
|
) -> None:
|
||||||
|
"""Test streaming deltas error handling."""
|
||||||
|
|
||||||
|
async def stream(deltas):
|
||||||
|
"""Yield deltas."""
|
||||||
|
for d in deltas:
|
||||||
|
yield d
|
||||||
|
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
):
|
||||||
|
# Stream content without LLM API set
|
||||||
|
with pytest.raises(ValueError): # noqa: PT012
|
||||||
|
async for _tool_result_content in chat_log.async_add_delta_content_stream(
|
||||||
|
"mock-agent-id",
|
||||||
|
stream(
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Non assistant role
|
||||||
|
for role in "system", "user":
|
||||||
|
with pytest.raises(ValueError): # noqa: PT012
|
||||||
|
async for (
|
||||||
|
_tool_result_content
|
||||||
|
) in chat_log.async_add_delta_content_stream(
|
||||||
|
"mock-agent-id",
|
||||||
|
stream([{"role": role}]),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
@ -1,34 +1,64 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_unknown_hass_api
|
# name: test_function_call
|
||||||
dict({
|
list([
|
||||||
'conversation_id': 'my-conversation-id',
|
dict({
|
||||||
'response': IntentResponse(
|
'content': '''
|
||||||
card=dict({
|
Current time is 16:00:00. Today's date is 2024-06-03.
|
||||||
}),
|
You are a voice assistant for Home Assistant.
|
||||||
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
|
Answer questions about the world truthfully.
|
||||||
failed_results=list([
|
Answer in plain text. Keep it simple and to the point.
|
||||||
]),
|
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
|
||||||
intent=None,
|
''',
|
||||||
intent_targets=list([
|
'role': 'system',
|
||||||
]),
|
}),
|
||||||
language='en',
|
dict({
|
||||||
matched_states=list([
|
'content': 'hello',
|
||||||
]),
|
'role': 'user',
|
||||||
reprompt=dict({
|
}),
|
||||||
}),
|
dict({
|
||||||
response_type=<IntentResponseType.ERROR: 'error'>,
|
'content': 'Please call the test function',
|
||||||
speech=dict({
|
'role': 'user',
|
||||||
'plain': dict({
|
}),
|
||||||
'extra_data': None,
|
dict({
|
||||||
'speech': 'Error preparing LLM API',
|
'agent_id': 'conversation.openai',
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'call_call_1',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'call1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'call_call_2',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'call2',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
}),
|
}),
|
||||||
}),
|
|
||||||
speech_slots=dict({
|
|
||||||
}),
|
|
||||||
success_results=list([
|
|
||||||
]),
|
]),
|
||||||
unmatched_states=list([
|
}),
|
||||||
]),
|
dict({
|
||||||
),
|
'agent_id': 'conversation.openai',
|
||||||
})
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'call_call_1',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'value1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'conversation.openai',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'call_call_2',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'value2',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'conversation.openai',
|
||||||
|
'content': 'Cool',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
# ---
|
# ---
|
||||||
|
@ -1,29 +1,130 @@
|
|||||||
"""Tests for the OpenAI integration."""
|
"""Tests for the OpenAI integration."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
from openai import RateLimitError
|
from openai import RateLimitError
|
||||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
from openai.types.chat.chat_completion_chunk import (
|
||||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
ChatCompletionChunk,
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
Choice,
|
||||||
ChatCompletionMessageToolCall,
|
ChoiceDelta,
|
||||||
Function,
|
ChoiceDeltaToolCall,
|
||||||
|
ChoiceDeltaToolCallFunction,
|
||||||
)
|
)
|
||||||
from openai.types.completion_usage import CompletionUsage
|
import pytest
|
||||||
import voluptuous as vol
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import chat_log
|
||||||
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.helpers import chat_session, intent
|
||||||
from homeassistant.helpers import intent, llm
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
ASSIST_RESPONSE_FINISH = (
|
||||||
|
# Assistant message
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-B",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
|
||||||
|
),
|
||||||
|
# Finish stream
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-B",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[Choice(index=0, finish_reason="stop", delta=ChoiceDelta())],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_create_stream() -> Generator[AsyncMock]:
|
||||||
|
"""Mock stream response."""
|
||||||
|
|
||||||
|
async def mock_generator(stream):
|
||||||
|
for value in stream:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||||
|
AsyncMock(),
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.side_effect = lambda **kwargs: mock_generator(
|
||||||
|
mock_create.return_value.pop(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield mock_create
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockChatLog(chat_log.ChatLog):
|
||||||
|
"""Mock chat log."""
|
||||||
|
|
||||||
|
_mock_tool_results: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def mock_tool_results(self, results: dict) -> None:
|
||||||
|
"""Set tool results."""
|
||||||
|
self._mock_tool_results = results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_api(self):
|
||||||
|
"""Return LLM API."""
|
||||||
|
return self._llm_api
|
||||||
|
|
||||||
|
@llm_api.setter
|
||||||
|
def llm_api(self, value):
|
||||||
|
"""Set LLM API."""
|
||||||
|
self._llm_api = value
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def async_call_tool(tool_input):
|
||||||
|
"""Call tool."""
|
||||||
|
if tool_input.id not in self._mock_tool_results:
|
||||||
|
raise ValueError(f"Tool {tool_input.id} not found")
|
||||||
|
return self._mock_tool_results[tool_input.id]
|
||||||
|
|
||||||
|
self._llm_api.async_call_tool = async_call_tool
|
||||||
|
|
||||||
|
def latest_content(self) -> list[conversation.Content]:
|
||||||
|
"""Return content from latest version chat log.
|
||||||
|
|
||||||
|
The chat log makes copies until it's committed. Helper to get latest content.
|
||||||
|
"""
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(
|
||||||
|
self.hass, self.conversation_id
|
||||||
|
) as session,
|
||||||
|
conversation.async_get_chat_log(self.hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
return chat_log.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
|
||||||
|
"""Return mock chat logs."""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.conversation.chat_log.ChatLog",
|
||||||
|
MockChatLog,
|
||||||
|
),
|
||||||
|
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
|
||||||
|
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
chat_log.async_add_user_content(conversation.UserContent("hello"))
|
||||||
|
return chat_log
|
||||||
|
|
||||||
|
|
||||||
async def test_entity(
|
async def test_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -83,348 +184,299 @@ async def test_conversation_agent(
|
|||||||
assert agent.supported_languages == "*"
|
assert agent.supported_languages == "*"
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
|
||||||
)
|
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
mock_chat_log: MockChatLog,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function call from the assistant."""
|
"""Test function call from the assistant."""
|
||||||
agent_id = mock_config_entry_with_assist.entry_id
|
mock_create_stream.return_value = [
|
||||||
context = Context()
|
# Initial conversation
|
||||||
|
(
|
||||||
mock_tool = AsyncMock()
|
# First tool call
|
||||||
mock_tool.name = "test_tool"
|
ChatCompletionChunk(
|
||||||
mock_tool.description = "Test function"
|
id="chatcmpl-A",
|
||||||
mock_tool.parameters = vol.Schema(
|
created=1700000000,
|
||||||
{vol.Optional("param1", description="Test parameters"): str}
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
id="call_call_1",
|
||||||
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="test_tool",
|
||||||
|
arguments=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name=None,
|
||||||
|
arguments='{"para',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name=None,
|
||||||
|
arguments='m1":"call1"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# Second tool call
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
id="call_call_2",
|
||||||
|
index=1,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="test_tool",
|
||||||
|
arguments='{"param1":"call2"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# Finish stream
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(index=0, finish_reason="tool_calls", delta=ChoiceDelta())
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# Response after tool responses
|
||||||
|
ASSIST_RESPONSE_FINISH,
|
||||||
|
]
|
||||||
|
mock_chat_log.mock_tool_results(
|
||||||
|
{
|
||||||
|
"call_call_1": "value1",
|
||||||
|
"call_call_2": "value2",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
mock_tool.async_call.return_value = "Test response"
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
with freeze_time("2024-06-03 23:00:00"):
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
"mock-conversation-id",
|
||||||
|
Context(),
|
||||||
|
agent_id="conversation.openai",
|
||||||
|
)
|
||||||
|
|
||||||
def completion_result(*args, messages, **kwargs):
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
for message in messages:
|
assert mock_chat_log.latest_content() == snapshot
|
||||||
role = message["role"] if isinstance(message, dict) else message.role
|
|
||||||
if role == "tool":
|
|
||||||
return ChatCompletion(
|
@pytest.mark.parametrize(
|
||||||
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
|
("description", "messages"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"Test function call started with missing arguments",
|
||||||
|
(
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
index=0,
|
||||||
message=ChatCompletionMessage(
|
delta=ChoiceDelta(
|
||||||
content="I have successfully called the function",
|
tool_calls=[
|
||||||
role="assistant",
|
ChoiceDeltaToolCall(
|
||||||
function_call=None,
|
id="call_call_1",
|
||||||
tool_calls=None,
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="test_tool",
|
||||||
|
arguments=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
),
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-B",
|
||||||
created=1700000000,
|
created=1700000000,
|
||||||
model="gpt-4-1106-preview",
|
model="gpt-4-1106-preview",
|
||||||
object="chat.completion",
|
object="chat.completion.chunk",
|
||||||
system_fingerprint=None,
|
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
|
||||||
usage=CompletionUsage(
|
),
|
||||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatCompletion(
|
|
||||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
|
||||||
choices=[
|
|
||||||
Choice(
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
index=0,
|
|
||||||
message=ChatCompletionMessage(
|
|
||||||
content=None,
|
|
||||||
role="assistant",
|
|
||||||
function_call=None,
|
|
||||||
tool_calls=[
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
function=Function(
|
|
||||||
arguments='{"param1":"test_value"}',
|
|
||||||
name="test_tool",
|
|
||||||
),
|
|
||||||
type="function",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1700000000,
|
|
||||||
model="gpt-4-1106-preview",
|
|
||||||
object="chat.completion",
|
|
||||||
system_fingerprint=None,
|
|
||||||
usage=CompletionUsage(
|
|
||||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
|
||||||
),
|
),
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=completion_result,
|
|
||||||
) as mock_create,
|
|
||||||
freeze_time("2024-06-03 23:00:00"),
|
|
||||||
):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
"Today's date is 2024-06-03."
|
|
||||||
in mock_create.mock_calls[1][2]["messages"][0]["content"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
"content": '"Test response"',
|
|
||||||
}
|
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
|
||||||
hass,
|
|
||||||
llm.ToolInput(
|
|
||||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={"param1": "test_value"},
|
|
||||||
),
|
),
|
||||||
llm.LLMContext(
|
(
|
||||||
platform="openai_conversation",
|
"Test invalid JSON",
|
||||||
context=context,
|
(
|
||||||
user_prompt="Please call the test function",
|
ChatCompletionChunk(
|
||||||
language="en",
|
id="chatcmpl-A",
|
||||||
assistant="conversation",
|
created=1700000000,
|
||||||
device_id=None,
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
id="call_call_1",
|
||||||
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="test_tool",
|
||||||
|
arguments=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-A",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name=None,
|
||||||
|
arguments='{"para',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatCompletionChunk(
|
||||||
|
id="chatcmpl-B",
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(content="Cool"),
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
],
|
||||||
|
|
||||||
# Test Conversation tracing
|
|
||||||
traces = trace.async_get_traces()
|
|
||||||
assert traces
|
|
||||||
last_trace = traces[-1].as_dict()
|
|
||||||
trace_events = last_trace.get("events", [])
|
|
||||||
assert [event["event_type"] for event in trace_events] == [
|
|
||||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
|
||||||
trace.ConversationTraceEventType.TOOL_CALL,
|
|
||||||
]
|
|
||||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
|
||||||
detail_event = trace_events[1]
|
|
||||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
|
||||||
assert (
|
|
||||||
"Today's date is 2024-06-03."
|
|
||||||
in trace_events[1]["data"]["messages"][0]["content"]
|
|
||||||
)
|
|
||||||
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
|
|
||||||
|
|
||||||
# Call it again, make sure we have updated prompt
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=completion_result,
|
|
||||||
) as mock_create,
|
|
||||||
freeze_time("2024-06-04 23:00:00"),
|
|
||||||
):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
"Today's date is 2024-06-04."
|
|
||||||
in mock_create.mock_calls[1][2]["messages"][0]["content"]
|
|
||||||
)
|
|
||||||
# Test old assert message not updated
|
|
||||||
assert (
|
|
||||||
"Today's date is 2024-06-03."
|
|
||||||
in trace_events[1]["data"]["messages"][0]["content"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
|
||||||
)
|
)
|
||||||
async def test_function_exception(
|
async def test_function_call_invalid(
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
mock_chat_log: MockChatLog,
|
||||||
|
description: str,
|
||||||
|
messages: tuple[ChatCompletionChunk],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function call with exception."""
|
"""Test function call containing invalid data."""
|
||||||
agent_id = mock_config_entry_with_assist.entry_id
|
mock_create_stream.return_value = [messages]
|
||||||
context = Context()
|
|
||||||
|
|
||||||
mock_tool = AsyncMock()
|
with pytest.raises(ValueError):
|
||||||
mock_tool.name = "test_tool"
|
await conversation.async_converse(
|
||||||
mock_tool.description = "Test function"
|
|
||||||
mock_tool.parameters = vol.Schema(
|
|
||||||
{vol.Optional("param1", description="Test parameters"): str}
|
|
||||||
)
|
|
||||||
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
|
||||||
|
|
||||||
def completion_result(*args, messages, **kwargs):
|
|
||||||
for message in messages:
|
|
||||||
role = message["role"] if isinstance(message, dict) else message.role
|
|
||||||
if role == "tool":
|
|
||||||
return ChatCompletion(
|
|
||||||
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
|
|
||||||
choices=[
|
|
||||||
Choice(
|
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
|
||||||
message=ChatCompletionMessage(
|
|
||||||
content="There was an error calling the function",
|
|
||||||
role="assistant",
|
|
||||||
function_call=None,
|
|
||||||
tool_calls=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1700000000,
|
|
||||||
model="gpt-4-1106-preview",
|
|
||||||
object="chat.completion",
|
|
||||||
system_fingerprint=None,
|
|
||||||
usage=CompletionUsage(
|
|
||||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatCompletion(
|
|
||||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
|
||||||
choices=[
|
|
||||||
Choice(
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
index=0,
|
|
||||||
message=ChatCompletionMessage(
|
|
||||||
content=None,
|
|
||||||
role="assistant",
|
|
||||||
function_call=None,
|
|
||||||
tool_calls=[
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
function=Function(
|
|
||||||
arguments='{"param1":"test_value"}',
|
|
||||||
name="test_tool",
|
|
||||||
),
|
|
||||||
type="function",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1700000000,
|
|
||||||
model="gpt-4-1106-preview",
|
|
||||||
object="chat.completion",
|
|
||||||
system_fingerprint=None,
|
|
||||||
usage=CompletionUsage(
|
|
||||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=completion_result,
|
|
||||||
) as mock_create:
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
hass,
|
||||||
"Please call the test function",
|
"Please call the test function",
|
||||||
None,
|
"mock-conversation-id",
|
||||||
context,
|
Context(),
|
||||||
agent_id=agent_id,
|
agent_id="conversation.openai",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
"content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
|
|
||||||
}
|
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
|
||||||
hass,
|
|
||||||
llm.ToolInput(
|
|
||||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={"param1": "test_value"},
|
|
||||||
),
|
|
||||||
llm.LLMContext(
|
|
||||||
platform="openai_conversation",
|
|
||||||
context=context,
|
|
||||||
user_prompt="Please call the test function",
|
|
||||||
language="en",
|
|
||||||
assistant="conversation",
|
|
||||||
device_id=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_assist_api_tools_conversion(
|
async def test_assist_api_tools_conversion(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
mock_create_stream,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that we are able to convert actual tools from Assist API."""
|
"""Test that we are able to convert actual tools from Assist API."""
|
||||||
for component in (
|
for component in (
|
||||||
"intent",
|
"calendar",
|
||||||
"todo",
|
|
||||||
"light",
|
|
||||||
"shopping_list",
|
|
||||||
"humidifier",
|
|
||||||
"climate",
|
"climate",
|
||||||
"media_player",
|
|
||||||
"vacuum",
|
|
||||||
"cover",
|
"cover",
|
||||||
|
"humidifier",
|
||||||
|
"intent",
|
||||||
|
"light",
|
||||||
|
"media_player",
|
||||||
|
"script",
|
||||||
|
"shopping_list",
|
||||||
|
"todo",
|
||||||
|
"vacuum",
|
||||||
"weather",
|
"weather",
|
||||||
):
|
):
|
||||||
assert await async_setup_component(hass, component, {})
|
assert await async_setup_component(hass, component, {})
|
||||||
|
hass.states.async_set(f"{component}.test", "on")
|
||||||
|
async_expose_entity(hass, "conversation", f"{component}.test", True)
|
||||||
|
|
||||||
agent_id = mock_config_entry_with_assist.entry_id
|
mock_create_stream.return_value = [ASSIST_RESPONSE_FINISH]
|
||||||
with patch(
|
|
||||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=ChatCompletion(
|
|
||||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
|
||||||
choices=[
|
|
||||||
Choice(
|
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
|
||||||
message=ChatCompletionMessage(
|
|
||||||
content="Hello, how can I help you?",
|
|
||||||
role="assistant",
|
|
||||||
function_call=None,
|
|
||||||
tool_calls=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1700000000,
|
|
||||||
model="gpt-3.5-turbo-0613",
|
|
||||||
object="chat.completion",
|
|
||||||
system_fingerprint=None,
|
|
||||||
usage=CompletionUsage(
|
|
||||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
|
||||||
),
|
|
||||||
),
|
|
||||||
) as mock_create:
|
|
||||||
await conversation.async_converse(
|
|
||||||
hass, "hello", None, Context(), agent_id=agent_id
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = mock_create.mock_calls[0][2]["tools"]
|
await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id="conversation.openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = mock_create_stream.mock_calls[0][2]["tools"]
|
||||||
assert tools
|
assert tools
|
||||||
|
Loading…
x
Reference in New Issue
Block a user