mirror of
https://github.com/home-assistant/core.git
synced 2025-06-18 12:07:06 +00:00
Add typing for OpenAI client and fallout (#118514)
* typing for client and consequences * Update homeassistant/components/openai_conversation/conversation.py --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
2bd142d3a6
commit
eae04bf2e9
@ -1,9 +1,22 @@
|
|||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from openai._types import NOT_GIVEN
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionMessage,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionToolParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
|
from openai.types.shared_params import FunctionDefinition
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
@ -45,13 +58,12 @@ async def async_setup_entry(
|
|||||||
async_add_entities([agent])
|
async_add_entities([agent])
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
|
||||||
"""Format tool specification."""
|
"""Format tool specification."""
|
||||||
tool_spec = {"name": tool.name}
|
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
|
||||||
if tool.description:
|
if tool.description:
|
||||||
tool_spec["description"] = tool.description
|
tool_spec["description"] = tool.description
|
||||||
tool_spec["parameters"] = convert(tool.parameters)
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
return {"type": "function", "function": tool_spec}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConversationEntity(
|
class OpenAIConversationEntity(
|
||||||
@ -65,7 +77,7 @@ class OpenAIConversationEntity(
|
|||||||
def __init__(self, entry: ConfigEntry) -> None:
|
def __init__(self, entry: ConfigEntry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, list[dict]] = {}
|
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
||||||
self._attr_unique_id = entry.entry_id
|
self._attr_unique_id = entry.entry_id
|
||||||
self._attr_device_info = dr.DeviceInfo(
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
identifiers={(DOMAIN, entry.entry_id)},
|
identifiers={(DOMAIN, entry.entry_id)},
|
||||||
@ -100,7 +112,7 @@ class OpenAIConversationEntity(
|
|||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[ChatCompletionToolParam] | None = None
|
||||||
|
|
||||||
if options.get(CONF_LLM_HASS_API):
|
if options.get(CONF_LLM_HASS_API):
|
||||||
try:
|
try:
|
||||||
@ -164,16 +176,18 @@ class OpenAIConversationEntity(
|
|||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [{"role": "system", "content": prompt}]
|
messages = [ChatCompletionSystemMessageParam(role="system", content=prompt)]
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_input.text})
|
messages.append(
|
||||||
|
ChatCompletionUserMessageParam(role="user", content=user_input.text)
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER.debug("Prompt: %s", messages)
|
LOGGER.debug("Prompt: %s", messages)
|
||||||
trace.async_conversation_trace_append(
|
trace.async_conversation_trace_append(
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||||
)
|
)
|
||||||
|
|
||||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||||
|
|
||||||
# To prevent infinite loops, we limit the number of iterations
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
@ -181,7 +195,7 @@ class OpenAIConversationEntity(
|
|||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools or None,
|
tools=tools or NOT_GIVEN,
|
||||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
@ -199,7 +213,31 @@ class OpenAIConversationEntity(
|
|||||||
|
|
||||||
LOGGER.debug("Response %s", result)
|
LOGGER.debug("Response %s", result)
|
||||||
response = result.choices[0].message
|
response = result.choices[0].message
|
||||||
messages.append(response)
|
|
||||||
|
def message_convert(
|
||||||
|
message: ChatCompletionMessage,
|
||||||
|
) -> ChatCompletionMessageParam:
|
||||||
|
"""Convert from class to TypedDict."""
|
||||||
|
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||||
|
if message.tool_calls:
|
||||||
|
tool_calls = [
|
||||||
|
ChatCompletionMessageToolCallParam(
|
||||||
|
id=tool_call.id,
|
||||||
|
function=Function(
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
name=tool_call.function.name,
|
||||||
|
),
|
||||||
|
type=tool_call.type,
|
||||||
|
)
|
||||||
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
return ChatCompletionAssistantMessageParam(
|
||||||
|
role=message.role,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=message.content,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(message_convert(response))
|
||||||
tool_calls = response.tool_calls
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
if not tool_calls or not llm_api:
|
if not tool_calls or not llm_api:
|
||||||
@ -223,18 +261,17 @@ class OpenAIConversationEntity(
|
|||||||
|
|
||||||
LOGGER.debug("Tool response: %s", tool_response)
|
LOGGER.debug("Tool response: %s", tool_response)
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
ChatCompletionToolMessageParam(
|
||||||
"role": "tool",
|
role="tool",
|
||||||
"tool_call_id": tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
"name": tool_call.function.name,
|
content=json.dumps(tool_response),
|
||||||
"content": json.dumps(tool_response),
|
)
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.history[conversation_id] = messages
|
self.history[conversation_id] = messages
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
intent_response.async_set_speech(response.content)
|
intent_response.async_set_speech(response.content or "")
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
|
@ -184,7 +184,6 @@ async def test_function_call(
|
|||||||
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||||
"name": "test_tool",
|
|
||||||
"content": '"Test response"',
|
"content": '"Test response"',
|
||||||
}
|
}
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
@ -317,7 +316,6 @@ async def test_function_exception(
|
|||||||
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||||
"name": "test_tool",
|
|
||||||
"content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
|
"content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
|
||||||
}
|
}
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user