"""Conversation support for OpenAI."""
|
|
|
|
from collections.abc import AsyncGenerator, Callable
|
|
import json
|
|
from typing import Any, Literal, cast
|
|
|
|
import openai
|
|
from openai._streaming import AsyncStream
|
|
from openai.types.responses import (
|
|
EasyInputMessageParam,
|
|
FunctionToolParam,
|
|
ResponseCompletedEvent,
|
|
ResponseErrorEvent,
|
|
ResponseFailedEvent,
|
|
ResponseFunctionCallArgumentsDeltaEvent,
|
|
ResponseFunctionCallArgumentsDoneEvent,
|
|
ResponseFunctionToolCall,
|
|
ResponseFunctionToolCallParam,
|
|
ResponseIncompleteEvent,
|
|
ResponseInputParam,
|
|
ResponseOutputItemAddedEvent,
|
|
ResponseOutputItemDoneEvent,
|
|
ResponseOutputMessage,
|
|
ResponseOutputMessageParam,
|
|
ResponseReasoningItem,
|
|
ResponseReasoningItemParam,
|
|
ResponseStreamEvent,
|
|
ResponseTextDeltaEvent,
|
|
ToolParam,
|
|
WebSearchToolParam,
|
|
)
|
|
from openai.types.responses.response_input_param import FunctionCallOutput
|
|
from openai.types.responses.web_search_tool_param import UserLocation
|
|
from voluptuous_openapi import convert
|
|
|
|
from homeassistant.components import assist_pipeline, conversation
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import device_registry as dr, intent, llm
|
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
|
|
|
from . import OpenAIConfigEntry
|
|
from .const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_MAX_TOKENS,
|
|
CONF_PROMPT,
|
|
CONF_REASONING_EFFORT,
|
|
CONF_TEMPERATURE,
|
|
CONF_TOP_P,
|
|
CONF_WEB_SEARCH,
|
|
CONF_WEB_SEARCH_CITY,
|
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
|
CONF_WEB_SEARCH_COUNTRY,
|
|
CONF_WEB_SEARCH_REGION,
|
|
CONF_WEB_SEARCH_TIMEZONE,
|
|
CONF_WEB_SEARCH_USER_LOCATION,
|
|
DOMAIN,
|
|
LOGGER,
|
|
RECOMMENDED_CHAT_MODEL,
|
|
RECOMMENDED_MAX_TOKENS,
|
|
RECOMMENDED_REASONING_EFFORT,
|
|
RECOMMENDED_TEMPERATURE,
|
|
RECOMMENDED_TOP_P,
|
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
|
)
|
|
|
|
# Max number of back and forth with the LLM to generate a response
|
|
MAX_TOOL_ITERATIONS = 10


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: OpenAIConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up conversation entities."""
    agent = OpenAIConversationEntity(config_entry)
    async_add_entities([agent])


def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam:
    """Format tool specification."""
    return FunctionToolParam(
        type="function",
        name=tool.name,
        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
        description=tool.description,
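        # Strict mode is left off: OpenAI's strict function schemas require
        # every property to be required and additionalProperties to be false,
        # which schemas converted from voluptuous are not guaranteed to satisfy.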
        strict=False,
    )


def _convert_content_to_param(
    content: conversation.Content,
) -> ResponseInputParam:
    """Convert any native chat message for this agent to the native format."""
    messages: ResponseInputParam = []
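    # Tool results map to standalone function_call_output items in the
    # Responses API rather than to message items.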
    if isinstance(content, conversation.ToolResultContent):
        return [
            FunctionCallOutput(
                type="function_call_output",
                call_id=content.tool_call_id,
                output=json.dumps(content.tool_result),
            )
        ]

    if content.content:
        role: Literal["user", "assistant", "system", "developer"] = content.role
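        # OpenAI treats "developer" as the successor to the legacy "system"
        # role for newer models.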
        if role == "system":
            role = "developer"
        messages.append(
            EasyInputMessageParam(type="message", role=role, content=content.content)
        )

    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
        messages.extend(
            ResponseFunctionToolCallParam(
                type="function_call",
                name=tool_call.tool_name,
                arguments=json.dumps(tool_call.tool_args),
                call_id=tool_call.id,
            )
            for tool_call in content.tool_calls
        )
    return messages


async def _transform_stream(
    chat_log: conversation.ChatLog,
    result: AsyncStream[ResponseStreamEvent],
    messages: ResponseInputParam,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform an OpenAI delta stream into HA format."""
    async for event in result:
        LOGGER.debug("Received event: %s", event)

        if isinstance(event, ResponseOutputItemAddedEvent):
            if isinstance(event.item, ResponseOutputMessage):
                yield {"role": event.item.role}
            elif isinstance(event.item, ResponseFunctionToolCall):
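                # Keep a reference to the tool call; the subsequent argument
                # delta events only carry text fragments to append to it.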
                current_tool_call = event.item
        elif isinstance(event, ResponseOutputItemDoneEvent):
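            # Feed the finished output item back into the request input so a
            # follow-up call after tool execution includes it; the
            # response-side "status" field is stripped before reuse.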
            item = event.item.model_dump()
            item.pop("status", None)
            if isinstance(event.item, ResponseReasoningItem):
                messages.append(cast(ResponseReasoningItemParam, item))
            elif isinstance(event.item, ResponseOutputMessage):
                messages.append(cast(ResponseOutputMessageParam, item))
            elif isinstance(event.item, ResponseFunctionToolCall):
                messages.append(cast(ResponseFunctionToolCallParam, item))
        elif isinstance(event, ResponseTextDeltaEvent):
            yield {"content": event.delta}
        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
            current_tool_call.arguments += event.delta
        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
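            # All argument fragments have arrived: parse the accumulated JSON
            # and surface the completed call as a Home Assistant ToolInput.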
            current_tool_call.status = "completed"
            yield {
                "tool_calls": [
                    llm.ToolInput(
                        id=current_tool_call.call_id,
                        tool_name=current_tool_call.name,
                        tool_args=json.loads(current_tool_call.arguments),
                    )
                ]
            }
        elif isinstance(event, ResponseCompletedEvent):
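            # Record token usage in the chat log trace for debugging/stats.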
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
        elif isinstance(event, ResponseIncompleteEvent):
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )

            if (
                event.response.incomplete_details
                and event.response.incomplete_details.reason
            ):
                reason: str = event.response.incomplete_details.reason
            else:
                reason = "unknown reason"

            if reason == "max_output_tokens":
                reason = "max output tokens reached"
            elif reason == "content_filter":
                reason = "content filter triggered"

            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
        elif isinstance(event, ResponseFailedEvent):
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
            reason = "unknown reason"
            if event.response.error is not None:
                reason = event.response.error.message
            raise HomeAssistantError(f"OpenAI response failed: {reason}")
        elif isinstance(event, ResponseErrorEvent):
            raise HomeAssistantError(f"OpenAI response error: {event.message}")


class OpenAIConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """OpenAI conversation agent."""

    _attr_has_entity_name = True
    _attr_name = None

    def __init__(self, entry: OpenAIConfigEntry) -> None:
        """Initialize the agent."""
        self.entry = entry
        self._attr_unique_id = entry.entry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            name=entry.title,
            manufacturer="OpenAI",
            model="ChatGPT",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        if self.entry.options.get(CONF_LLM_HASS_API):
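            # Only advertise device control when a Home Assistant LLM API is
            # configured, since control is performed through its tools.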
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:
        """Call the API."""
        options = self.entry.options

        try:
            await chat_log.async_update_llm_data(
                DOMAIN,
                user_input,
                options.get(CONF_LLM_HASS_API),
                options.get(CONF_PROMPT),
            )
        except conversation.ConverseError as err:
            return err.as_conversation_result()

        tools: list[ToolParam] | None = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]

        if options.get(CONF_WEB_SEARCH):
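            # Attach OpenAI's hosted web search tool alongside any LLM API tools.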
            web_search = WebSearchToolParam(
                type="web_search_preview",
                search_context_size=options.get(
                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
                ),
            )
            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
                web_search["user_location"] = UserLocation(
                    type="approximate",
                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
                )
            if tools is None:
                tools = []
            tools.append(web_search)

        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
        messages = [
            m
            for content in chat_log.content
            for m in _convert_content_to_param(content)
        ]
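
        # runtime_data is expected to hold the AsyncOpenAI client created
        # when the config entry was set up.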
        client = self.entry.runtime_data

        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            model_args = {
                "model": model,
                "input": messages,
                "max_output_tokens": options.get(
                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
                ),
                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                "user": chat_log.conversation_id,
                "stream": True,
            }
            if tools:
                model_args["tools"] = tools

            if model.startswith("o"):
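                # Heuristic: model names starting with "o" (o1, o3-mini, ...)
                # are reasoning models that accept the reasoning effort option.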
                model_args["reasoning"] = {
                    "effort": options.get(
                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
                    )
                }
            else:
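                # For non-reasoning models, skip server-side response storage;
                # presumably only reasoning models benefit from stored responses.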
                model_args["store"] = False

            try:
                result = await client.responses.create(**model_args)
            except openai.RateLimitError as err:
                LOGGER.error("Rate limited by OpenAI: %s", err)
                raise HomeAssistantError("Rate limited or insufficient funds") from err
            except openai.OpenAIError as err:
                LOGGER.error("Error talking to OpenAI: %s", err)
                raise HomeAssistantError("Error talking to OpenAI") from err

            async for content in chat_log.async_add_delta_content_stream(
                user_input.agent_id, _transform_stream(chat_log, result, messages)
            ):
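                # Assistant output was already appended to `messages` inside
                # _transform_stream; only tool results still need converting.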
                if not isinstance(content, conversation.AssistantContent):
                    messages.extend(_convert_content_to_param(content))

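            # No pending tool calls means the model's answer is final.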
            if not chat_log.unresponded_tool_results:
                break

        intent_response = intent.IntentResponse(language=user_input.language)
        assert type(chat_log.content[-1]) is conversation.AssistantContent
        intent_response.async_set_speech(chat_log.content[-1].content or "")
        return conversation.ConversationResult(
            response=intent_response,
            conversation_id=chat_log.conversation_id,
            continue_conversation=chat_log.continue_conversation,
        )

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Handle options update."""
        # Reload as we update device info + entity name + supported features
        await hass.config_entries.async_reload(entry.entry_id)