Add typing for OpenAI client and fallout (#118514)

* typing for client and consequences

* Update homeassistant/components/openai_conversation/conversation.py

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Josef Zweck 2024-05-31 04:13:18 +02:00 committed by GitHub
parent 2bd142d3a6
commit eae04bf2e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 21 deletions

View File

@ -1,9 +1,22 @@
"""Conversation support for OpenAI.""" """Conversation support for OpenAI."""
import json import json
from typing import Any, Literal from typing import Literal
import openai import openai
from openai._types import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
import voluptuous as vol import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
@ -45,13 +58,12 @@ async def async_setup_entry(
async_add_entities([agent]) async_add_entities([agent])
def _format_tool(tool: llm.Tool) -> dict[str, Any]: def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
"""Format tool specification.""" """Format tool specification."""
tool_spec = {"name": tool.name} tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
if tool.description: if tool.description:
tool_spec["description"] = tool.description tool_spec["description"] = tool.description
tool_spec["parameters"] = convert(tool.parameters) return ChatCompletionToolParam(type="function", function=tool_spec)
return {"type": "function", "function": tool_spec}
class OpenAIConversationEntity( class OpenAIConversationEntity(
@ -65,7 +77,7 @@ class OpenAIConversationEntity(
def __init__(self, entry: ConfigEntry) -> None: def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.history: dict[str, list[dict]] = {} self.history: dict[str, list[ChatCompletionMessageParam]] = {}
self._attr_unique_id = entry.entry_id self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)}, identifiers={(DOMAIN, entry.entry_id)},
@ -100,7 +112,7 @@ class OpenAIConversationEntity(
options = self.entry.options options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None tools: list[ChatCompletionToolParam] | None = None
if options.get(CONF_LLM_HASS_API): if options.get(CONF_LLM_HASS_API):
try: try:
@ -164,16 +176,18 @@ class OpenAIConversationEntity(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
messages = [{"role": "system", "content": prompt}] messages = [ChatCompletionSystemMessageParam(role="system", content=prompt)]
messages.append({"role": "user", "content": user_input.text}) messages.append(
ChatCompletionUserMessageParam(role="user", content=user_input.text)
)
LOGGER.debug("Prompt: %s", messages) LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append( trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
) )
client = self.hass.data[DOMAIN][self.entry.entry_id] client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id]
# To prevent infinite loops, we limit the number of iterations # To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
@ -181,7 +195,7 @@ class OpenAIConversationEntity(
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=messages,
tools=tools or None, tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
@ -199,7 +213,31 @@ class OpenAIConversationEntity(
LOGGER.debug("Response %s", result) LOGGER.debug("Response %s", result)
response = result.choices[0].message response = result.choices[0].message
messages.append(response)
def message_convert(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
tool_calls = [
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=tool_call.function.arguments,
name=tool_call.function.name,
),
type=tool_call.type,
)
for tool_call in message.tool_calls
]
return ChatCompletionAssistantMessageParam(
role=message.role,
tool_calls=tool_calls,
content=message.content,
)
messages.append(message_convert(response))
tool_calls = response.tool_calls tool_calls = response.tool_calls
if not tool_calls or not llm_api: if not tool_calls or not llm_api:
@ -223,18 +261,17 @@ class OpenAIConversationEntity(
LOGGER.debug("Tool response: %s", tool_response) LOGGER.debug("Tool response: %s", tool_response)
messages.append( messages.append(
{ ChatCompletionToolMessageParam(
"role": "tool", role="tool",
"tool_call_id": tool_call.id, tool_call_id=tool_call.id,
"name": tool_call.function.name, content=json.dumps(tool_response),
"content": json.dumps(tool_response), )
}
) )
self.history[conversation_id] = messages self.history[conversation_id] = messages
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content) intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )

View File

@ -184,7 +184,6 @@ async def test_function_call(
assert mock_create.mock_calls[1][2]["messages"][3] == { assert mock_create.mock_calls[1][2]["messages"][3] == {
"role": "tool", "role": "tool",
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
"name": "test_tool",
"content": '"Test response"', "content": '"Test response"',
} }
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
@ -317,7 +316,6 @@ async def test_function_exception(
assert mock_create.mock_calls[1][2]["messages"][3] == { assert mock_create.mock_calls[1][2]["messages"][3] == {
"role": "tool", "role": "tool",
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
"name": "test_tool",
"content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}', "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
} }
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(