def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
    """Build the OpenAI function-tool parameter for an LLM tool.

    The function definition carries the tool's name and its parameters
    passed through ``convert``; a description is attached only when the
    tool provides a truthy one.
    """
    definition = FunctionDefinition(
        name=tool.name,
        parameters=convert(tool.parameters),
    )
    description = tool.description
    if description:
        definition["description"] = description
    return ChatCompletionToolParam(type="function", function=definition)
def message_convert(
    message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
    """Convert an assistant response message into its TypedDict param form.

    The returned mapping is what gets appended to the message history and
    sent back on the next completion request.

    Args:
        message: The assistant message taken from a completion choice.

    Returns:
        An assistant message param with ``role`` and ``content`` copied
        from ``message``. ``tool_calls`` is present only when the message
        actually contains tool calls.
    """
    param = ChatCompletionAssistantMessageParam(
        role=message.role,
        content=message.content,
    )
    # Omit the key entirely for plain-text replies instead of sending an
    # empty list: tool_calls is optional, and an empty array may be
    # rejected as invalid when the history is replayed to the API.
    if message.tool_calls:
        param["tool_calls"] = [
            ChatCompletionMessageToolCallParam(
                id=tool_call.id,
                function=Function(
                    arguments=tool_call.function.arguments,
                    name=tool_call.function.name,
                ),
                type=tool_call.type,
            )
            for tool_call in message.tool_calls
        ]
    return param