Split OpenAI entity (#147771)
This commit is contained in:
parent be6b624081
commit 70856bd92a
homeassistant/components/openai_conversation/conversation.py

@@ -1,73 +1,19 @@
 """Conversation support for OpenAI."""
 
-from collections.abc import AsyncGenerator, Callable
-import json
-from typing import Any, Literal, cast
-
-import openai
-from openai._streaming import AsyncStream
-from openai.types.responses import (
-    EasyInputMessageParam,
-    FunctionToolParam,
-    ResponseCompletedEvent,
-    ResponseErrorEvent,
-    ResponseFailedEvent,
-    ResponseFunctionCallArgumentsDeltaEvent,
-    ResponseFunctionCallArgumentsDoneEvent,
-    ResponseFunctionToolCall,
-    ResponseFunctionToolCallParam,
-    ResponseIncompleteEvent,
-    ResponseInputParam,
-    ResponseOutputItemAddedEvent,
-    ResponseOutputItemDoneEvent,
-    ResponseOutputMessage,
-    ResponseOutputMessageParam,
-    ResponseReasoningItem,
-    ResponseReasoningItemParam,
-    ResponseStreamEvent,
-    ResponseTextDeltaEvent,
-    ToolParam,
-    WebSearchToolParam,
-)
-from openai.types.responses.response_input_param import FunctionCallOutput
-from openai.types.responses.web_search_tool_param import UserLocation
-from voluptuous_openapi import convert
+from typing import Literal
 
 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import device_registry as dr, intent, llm
+from homeassistant.helpers import intent
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 
 from . import OpenAIConfigEntry
-from .const import (
-    CONF_CHAT_MODEL,
-    CONF_MAX_TOKENS,
-    CONF_PROMPT,
-    CONF_REASONING_EFFORT,
-    CONF_TEMPERATURE,
-    CONF_TOP_P,
-    CONF_WEB_SEARCH,
-    CONF_WEB_SEARCH_CITY,
-    CONF_WEB_SEARCH_CONTEXT_SIZE,
-    CONF_WEB_SEARCH_COUNTRY,
-    CONF_WEB_SEARCH_REGION,
-    CONF_WEB_SEARCH_TIMEZONE,
-    CONF_WEB_SEARCH_USER_LOCATION,
-    DOMAIN,
-    LOGGER,
-    RECOMMENDED_CHAT_MODEL,
-    RECOMMENDED_MAX_TOKENS,
-    RECOMMENDED_REASONING_EFFORT,
-    RECOMMENDED_TEMPERATURE,
-    RECOMMENDED_TOP_P,
-    RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
-)
+from .const import CONF_PROMPT, DOMAIN
+from .entity import OpenAIBaseLLMEntity
 
-# Max number of back and forth with the LLM to generate a response
-MAX_TOOL_ITERATIONS = 10
-
 
 async def async_setup_entry(
@@ -86,152 +32,10 @@ async def async_setup_entry(
     )
 
 
-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> FunctionToolParam:
-    """Format tool specification."""
-    return FunctionToolParam(
-        type="function",
-        name=tool.name,
-        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
-        description=tool.description,
-        strict=False,
-    )
-
-
-def _convert_content_to_param(
-    content: conversation.Content,
-) -> ResponseInputParam:
-    """Convert any native chat message for this agent to the native format."""
-    messages: ResponseInputParam = []
-    if isinstance(content, conversation.ToolResultContent):
-        return [
-            FunctionCallOutput(
-                type="function_call_output",
-                call_id=content.tool_call_id,
-                output=json.dumps(content.tool_result),
-            )
-        ]
-
-    if content.content:
-        role: Literal["user", "assistant", "system", "developer"] = content.role
-        if role == "system":
-            role = "developer"
-        messages.append(
-            EasyInputMessageParam(type="message", role=role, content=content.content)
-        )
-
-    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
-        messages.extend(
-            ResponseFunctionToolCallParam(
-                type="function_call",
-                name=tool_call.tool_name,
-                arguments=json.dumps(tool_call.tool_args),
-                call_id=tool_call.id,
-            )
-            for tool_call in content.tool_calls
-        )
-    return messages
-
-
-async def _transform_stream(
-    chat_log: conversation.ChatLog,
-    result: AsyncStream[ResponseStreamEvent],
-    messages: ResponseInputParam,
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform an OpenAI delta stream into HA format."""
-    async for event in result:
-        LOGGER.debug("Received event: %s", event)
-
-        if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                # OpenAI has tool calls as individual events
-                # while HA puts tool calls inside the assistant message.
-                # We turn them into individual assistant content for HA
-                # to ensure that tools are called as soon as possible.
-                yield {"role": "assistant"}
-                current_tool_call = event.item
-        elif isinstance(event, ResponseOutputItemDoneEvent):
-            item = event.item.model_dump()
-            item.pop("status", None)
-            if isinstance(event.item, ResponseReasoningItem):
-                messages.append(cast(ResponseReasoningItemParam, item))
-            elif isinstance(event.item, ResponseOutputMessage):
-                messages.append(cast(ResponseOutputMessageParam, item))
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                messages.append(cast(ResponseFunctionToolCallParam, item))
-        elif isinstance(event, ResponseTextDeltaEvent):
-            yield {"content": event.delta}
-        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
-            current_tool_call.arguments += event.delta
-        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
-            current_tool_call.status = "completed"
-            yield {
-                "tool_calls": [
-                    llm.ToolInput(
-                        id=current_tool_call.call_id,
-                        tool_name=current_tool_call.name,
-                        tool_args=json.loads(current_tool_call.arguments),
-                    )
-                ]
-            }
-        elif isinstance(event, ResponseCompletedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-        elif isinstance(event, ResponseIncompleteEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-
-            if (
-                event.response.incomplete_details
-                and event.response.incomplete_details.reason
-            ):
-                reason: str = event.response.incomplete_details.reason
-            else:
-                reason = "unknown reason"
-
-            if reason == "max_output_tokens":
-                reason = "max output tokens reached"
-            elif reason == "content_filter":
-                reason = "content filter triggered"
-
-            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
-        elif isinstance(event, ResponseFailedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-            reason = "unknown reason"
-            if event.response.error is not None:
-                reason = event.response.error.message
-            raise HomeAssistantError(f"OpenAI response failed: {reason}")
-        elif isinstance(event, ResponseErrorEvent):
-            raise HomeAssistantError(f"OpenAI response error: {event.message}")
-
-
 class OpenAIConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    OpenAIBaseLLMEntity,
 ):
     """OpenAI conversation agent."""
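The inheritance change above is the heart of the split: OpenAIConversationEntity keeps its conversation-agent surface but now also derives from the new OpenAIBaseLLMEntity, which owns the shared state and the chat-log handler. A minimal, self-contained sketch of how the cooperative super().__init__() call lands in the new base (stand-in classes, not the real HA ones):

class BaseLLMEntity:
    """Stand-in for OpenAIBaseLLMEntity: owns the shared entity state."""

    def __init__(self, entry: str, subentry: str) -> None:
        self.entry = entry
        self.subentry = subentry


class ConversationEntity:
    """Stand-in for conversation.ConversationEntity (defines no __init__)."""


class ConversationAgent(ConversationEntity, BaseLLMEntity):
    def __init__(self, entry: str, subentry: str) -> None:
        # Resolved along the MRO: ConversationEntity has no __init__,
        # so this call reaches BaseLLMEntity.__init__.
        super().__init__(entry, subentry)


agent = ConversationAgent("entry", "subentry")
print([cls.__name__ for cls in type(agent).__mro__])
# ['ConversationAgent', 'ConversationEntity', 'BaseLLMEntity', 'object']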
@@ -239,17 +43,7 @@ class OpenAIConversationEntity(
 
     def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
-        self._attr_name = subentry.title
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, subentry.subentry_id)},
-            name=subentry.title,
-            manufacturer="OpenAI",
-            model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry, subentry)
         if self.subentry.data.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
@@ -305,91 +99,6 @@ class OpenAIConversationEntity(
             continue_conversation=chat_log.continue_conversation,
         )
 
-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.subentry.data
-
-        tools: list[ToolParam] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        if options.get(CONF_WEB_SEARCH):
-            web_search = WebSearchToolParam(
-                type="web_search_preview",
-                search_context_size=options.get(
-                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
-                ),
-            )
-            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
-                web_search["user_location"] = UserLocation(
-                    type="approximate",
-                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
-                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
-                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
-                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
-                )
-            if tools is None:
-                tools = []
-            tools.append(web_search)
-
-        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        messages = [
-            m
-            for content in chat_log.content
-            for m in _convert_content_to_param(content)
-        ]
-
-        client = self.entry.runtime_data
-
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args = {
-                "model": model,
-                "input": messages,
-                "max_output_tokens": options.get(
-                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-                ),
-                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                "user": chat_log.conversation_id,
-                "stream": True,
-            }
-            if tools:
-                model_args["tools"] = tools
-
-            if model.startswith("o"):
-                model_args["reasoning"] = {
-                    "effort": options.get(
-                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                    )
-                }
-            else:
-                model_args["store"] = False
-
-            try:
-                result = await client.responses.create(**model_args)
-            except openai.RateLimitError as err:
-                LOGGER.error("Rate limited by OpenAI: %s", err)
-                raise HomeAssistantError("Rate limited or insufficient funds") from err
-            except openai.OpenAIError as err:
-                LOGGER.error("Error talking to OpenAI: %s", err)
-                raise HomeAssistantError("Error talking to OpenAI") from err
-
-            async for content in chat_log.async_add_delta_content_stream(
-                self.entity_id, _transform_stream(chat_log, result, messages)
-            ):
-                if not isinstance(content, conversation.AssistantContent):
-                    messages.extend(_convert_content_to_param(content))
-
-            if not chat_log.unresponded_tool_results:
-                break
-
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None:
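With _async_handle_chat_log removed, the conversation entity is reduced to HA plumbing around one call into the inherited handler. A toy sketch of that delegation shape (stand-ins only; the real method also renders the prompt and packages a ConversationResult):

import asyncio


class Base:
    async def _async_handle_chat_log(self, chat_log: list[str]) -> None:
        chat_log.append("assistant: hello")  # generation happens in the base


class Agent(Base):
    async def _async_handle_message(self, user_input: str, chat_log: list[str]) -> str:
        chat_log.append(f"user: {user_input}")
        await self._async_handle_chat_log(chat_log)  # inherited shared handler
        return chat_log[-1]


print(asyncio.run(Agent()._async_handle_message("hi", [])))  # assistant: hello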
homeassistant/components/openai_conversation/entity.py (new file, 314 lines)

@@ -0,0 +1,314 @@
+"""Base entity for OpenAI."""
+
+from collections.abc import AsyncGenerator, Callable
+import json
+from typing import Any, Literal, cast
+
+import openai
+from openai._streaming import AsyncStream
+from openai.types.responses import (
+    EasyInputMessageParam,
+    FunctionToolParam,
+    ResponseCompletedEvent,
+    ResponseErrorEvent,
+    ResponseFailedEvent,
+    ResponseFunctionCallArgumentsDeltaEvent,
+    ResponseFunctionCallArgumentsDoneEvent,
+    ResponseFunctionToolCall,
+    ResponseFunctionToolCallParam,
+    ResponseIncompleteEvent,
+    ResponseInputParam,
+    ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseOutputMessage,
+    ResponseOutputMessageParam,
+    ResponseReasoningItem,
+    ResponseReasoningItemParam,
+    ResponseStreamEvent,
+    ResponseTextDeltaEvent,
+    ToolParam,
+    WebSearchToolParam,
+)
+from openai.types.responses.response_input_param import FunctionCallOutput
+from openai.types.responses.web_search_tool_param import UserLocation
+from voluptuous_openapi import convert
+
+from homeassistant.components import conversation
+from homeassistant.config_entries import ConfigSubentry
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr, llm
+from homeassistant.helpers.entity import Entity
+
+from . import OpenAIConfigEntry
+from .const import (
+    CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_REASONING_EFFORT,
+    CONF_TEMPERATURE,
+    CONF_TOP_P,
+    CONF_WEB_SEARCH,
+    CONF_WEB_SEARCH_CITY,
+    CONF_WEB_SEARCH_CONTEXT_SIZE,
+    CONF_WEB_SEARCH_COUNTRY,
+    CONF_WEB_SEARCH_REGION,
+    CONF_WEB_SEARCH_TIMEZONE,
+    CONF_WEB_SEARCH_USER_LOCATION,
+    DOMAIN,
+    LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
+    RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
+)
+
+# Max number of back and forth with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10
+
+
+def _format_tool(
+    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
+) -> FunctionToolParam:
+    """Format tool specification."""
+    return FunctionToolParam(
+        type="function",
+        name=tool.name,
+        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
+        description=tool.description,
+        strict=False,
+    )
+
+
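_format_tool leans on voluptuous_openapi.convert to turn a tool's voluptuous schema into the JSON-schema-style parameters dict the Responses API expects. A small sketch of that conversion (the exact output keys depend on the library version):

import voluptuous as vol
from voluptuous_openapi import convert

schema = vol.Schema(
    {
        vol.Required("name"): str,
        vol.Optional("brightness"): int,
    }
)
print(convert(schema))
# Roughly: {"type": "object",
#           "properties": {"name": {"type": "string"},
#                          "brightness": {"type": "integer"}},
#           "required": ["name"]}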
+def _convert_content_to_param(
+    content: conversation.Content,
+) -> ResponseInputParam:
+    """Convert any native chat message for this agent to the native format."""
+    messages: ResponseInputParam = []
+    if isinstance(content, conversation.ToolResultContent):
+        return [
+            FunctionCallOutput(
+                type="function_call_output",
+                call_id=content.tool_call_id,
+                output=json.dumps(content.tool_result),
+            )
+        ]
+
+    if content.content:
+        role: Literal["user", "assistant", "system", "developer"] = content.role
+        if role == "system":
+            role = "developer"
+        messages.append(
+            EasyInputMessageParam(type="message", role=role, content=content.content)
+        )
+
+    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
+        messages.extend(
+            ResponseFunctionToolCallParam(
+                type="function_call",
+                name=tool_call.tool_name,
+                arguments=json.dumps(tool_call.tool_args),
+                call_id=tool_call.id,
+            )
+            for tool_call in content.tool_calls
+        )
+    return messages
+
+
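For orientation, here is the kind of input list the converter yields for one tool round trip; the values and tool name are made up, but the three message shapes (EasyInputMessageParam, ResponseFunctionToolCallParam, FunctionCallOutput) match the code above:

messages = [
    {"type": "message", "role": "user", "content": "Turn on the kitchen light"},
    {
        "type": "function_call",
        "name": "HassTurnOn",  # hypothetical tool name
        "arguments": '{"name": "kitchen light"}',
        "call_id": "call_1",
    },
    {"type": "function_call_output", "call_id": "call_1", "output": '"done"'},
]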
+async def _transform_stream(
+    chat_log: conversation.ChatLog,
+    result: AsyncStream[ResponseStreamEvent],
+    messages: ResponseInputParam,
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform an OpenAI delta stream into HA format."""
+    async for event in result:
+        LOGGER.debug("Received event: %s", event)
+
+        if isinstance(event, ResponseOutputItemAddedEvent):
+            if isinstance(event.item, ResponseOutputMessage):
+                yield {"role": event.item.role}
+            elif isinstance(event.item, ResponseFunctionToolCall):
+                # OpenAI has tool calls as individual events
+                # while HA puts tool calls inside the assistant message.
+                # We turn them into individual assistant content for HA
+                # to ensure that tools are called as soon as possible.
+                yield {"role": "assistant"}
+                current_tool_call = event.item
+        elif isinstance(event, ResponseOutputItemDoneEvent):
+            item = event.item.model_dump()
+            item.pop("status", None)
+            if isinstance(event.item, ResponseReasoningItem):
+                messages.append(cast(ResponseReasoningItemParam, item))
+            elif isinstance(event.item, ResponseOutputMessage):
+                messages.append(cast(ResponseOutputMessageParam, item))
+            elif isinstance(event.item, ResponseFunctionToolCall):
+                messages.append(cast(ResponseFunctionToolCallParam, item))
+        elif isinstance(event, ResponseTextDeltaEvent):
+            yield {"content": event.delta}
+        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
+            current_tool_call.arguments += event.delta
+        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
+            current_tool_call.status = "completed"
+            yield {
+                "tool_calls": [
+                    llm.ToolInput(
+                        id=current_tool_call.call_id,
+                        tool_name=current_tool_call.name,
+                        tool_args=json.loads(current_tool_call.arguments),
+                    )
+                ]
+            }
+        elif isinstance(event, ResponseCompletedEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
+                    }
+                )
+        elif isinstance(event, ResponseIncompleteEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
+                    }
+                )
+
+            if (
+                event.response.incomplete_details
+                and event.response.incomplete_details.reason
+            ):
+                reason: str = event.response.incomplete_details.reason
+            else:
+                reason = "unknown reason"
+
+            if reason == "max_output_tokens":
+                reason = "max output tokens reached"
+            elif reason == "content_filter":
+                reason = "content filter triggered"
+
+            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
+        elif isinstance(event, ResponseFailedEvent):
+            if event.response.usage is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": event.response.usage.input_tokens,
+                            "output_tokens": event.response.usage.output_tokens,
+                        }
+                    }
+                )
+            reason = "unknown reason"
+            if event.response.error is not None:
+                reason = event.response.error.message
+            raise HomeAssistantError(f"OpenAI response failed: {reason}")
+        elif isinstance(event, ResponseErrorEvent):
+            raise HomeAssistantError(f"OpenAI response error: {event.message}")
+
+
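_transform_stream is an adapter: typed OpenAI stream events in, plain HA delta dicts out, with completed items appended to messages so the next request sees the full history. A runnable toy showing how such a delta stream is consumed downstream (the real consumer is chat_log.async_add_delta_content_stream):

import asyncio


async def fake_delta_stream():
    # The shapes _transform_stream yields for a plain text answer:
    yield {"role": "assistant"}
    yield {"content": "The light "}
    yield {"content": "is on."}


async def main() -> None:
    text = ""
    async for delta in fake_delta_stream():
        text += delta.get("content", "")
    print(text)  # The light is on.


asyncio.run(main())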
+class OpenAIBaseLLMEntity(Entity):
+    """OpenAI conversation agent."""
+
+    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
+        """Initialize the entity."""
+        self.entry = entry
+        self.subentry = subentry
+        self._attr_name = subentry.title
+        self._attr_unique_id = subentry.subentry_id
+        self._attr_device_info = dr.DeviceInfo(
+            identifiers={(DOMAIN, subentry.subentry_id)},
+            name=subentry.title,
+            manufacturer="OpenAI",
+            model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+
+    async def _async_handle_chat_log(
+        self,
+        chat_log: conversation.ChatLog,
+    ) -> None:
+        """Generate an answer for the chat log."""
+        options = self.subentry.data
+
+        tools: list[ToolParam] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
+
+        if options.get(CONF_WEB_SEARCH):
+            web_search = WebSearchToolParam(
+                type="web_search_preview",
+                search_context_size=options.get(
+                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
+                ),
+            )
+            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
+                web_search["user_location"] = UserLocation(
+                    type="approximate",
+                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
+                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
+                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
+                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
+                )
+            if tools is None:
+                tools = []
+            tools.append(web_search)
+
+        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+        messages = [
+            m
+            for content in chat_log.content
+            for m in _convert_content_to_param(content)
+        ]
+
+        client = self.entry.runtime_data
+
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            model_args = {
+                "model": model,
+                "input": messages,
+                "max_output_tokens": options.get(
+                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
+                ),
+                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+                "user": chat_log.conversation_id,
+                "stream": True,
+            }
+            if tools:
+                model_args["tools"] = tools
+
+            if model.startswith("o"):
+                model_args["reasoning"] = {
+                    "effort": options.get(
+                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
+                    )
+                }
+            else:
+                model_args["store"] = False
+
+            try:
+                result = await client.responses.create(**model_args)
+            except openai.RateLimitError as err:
+                LOGGER.error("Rate limited by OpenAI: %s", err)
+                raise HomeAssistantError("Rate limited or insufficient funds") from err
+            except openai.OpenAIError as err:
+                LOGGER.error("Error talking to OpenAI: %s", err)
+                raise HomeAssistantError("Error talking to OpenAI") from err
+
+            async for content in chat_log.async_add_delta_content_stream(
+                self.entity_id, _transform_stream(chat_log, result, messages)
+            ):
+                if not isinstance(content, conversation.AssistantContent):
+                    messages.extend(_convert_content_to_param(content))
+
+            if not chat_log.unresponded_tool_results:
+                break
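The loop in _async_handle_chat_log bounds the model/tool ping-pong at MAX_TOOL_ITERATIONS and exits as soon as no tool results are left unanswered. A self-contained sketch of that control flow (hypothetical callables standing in for the streaming client and the chat log):

import asyncio

MAX_TOOL_ITERATIONS = 10  # same hard cap as the code above


async def run_until_answered(call_model, run_tools) -> str:
    """Generic shape of the tool loop (simplified)."""
    for _ in range(MAX_TOOL_ITERATIONS):  # bounded, so tool calls cannot loop forever
        reply = await call_model()
        if not reply.get("tool_calls"):
            return reply["content"]  # no unresponded tool results -> done
        await run_tools(reply["tool_calls"])  # results feed the next request
    raise RuntimeError("too many tool iterations")


async def demo() -> None:
    replies = [
        {"tool_calls": [{"name": "get_state", "args": {}}]},
        {"content": "The light is on."},
    ]

    async def call_model():
        return replies.pop(0)

    async def run_tools(calls):
        pass  # a real runner executes each call and records its output

    print(await run_until_answered(call_model, run_tools))  # The light is on.


asyncio.run(demo())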