mirror of
https://github.com/home-assistant/core.git
synced 2025-07-31 17:18:23 +00:00
Use modern Python for OpenAI
This commit is contained in:
parent
4cc4bd3b9a
commit
099a480e57
@ -117,15 +117,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
)
|
||||
|
||||
# Get first conversation subentry for options
|
||||
conversation_subentry = next(
|
||||
(
|
||||
sub
|
||||
for sub in entry.subentries.values()
|
||||
if sub.subentry_type == "conversation"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not conversation_subentry:
|
||||
if not (
|
||||
conversation_subentry := next(
|
||||
(
|
||||
sub
|
||||
for sub in entry.subentries.values()
|
||||
if sub.subentry_type == "conversation"
|
||||
),
|
||||
None,
|
||||
)
|
||||
):
|
||||
raise ServiceValidationError("No conversation configuration found")
|
||||
|
||||
model: str = conversation_subentry.data.get(
|
||||
@ -138,13 +139,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
]
|
||||
|
||||
if filenames := call.data.get(CONF_FILENAMES):
|
||||
for filename in filenames:
|
||||
if not hass.config.is_allowed_path(filename):
|
||||
raise HomeAssistantError(
|
||||
f"Cannot read `{filename}`, no access to path; "
|
||||
"`allowlist_external_dirs` may need to be adjusted in "
|
||||
"`configuration.yaml`"
|
||||
)
|
||||
# Check access to all files first
|
||||
if blocked_file := next(
|
||||
(
|
||||
filename
|
||||
for filename in filenames
|
||||
if not hass.config.is_allowed_path(filename)
|
||||
),
|
||||
None,
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"Cannot read `{blocked_file}`, no access to path; "
|
||||
"`allowlist_external_dirs` may need to be adjusted in "
|
||||
"`configuration.yaml`"
|
||||
)
|
||||
|
||||
content.extend(
|
||||
await async_prepare_files_for_prompt(
|
||||
@ -244,11 +252,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bo
|
||||
|
||||
try:
|
||||
await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
|
||||
except openai.AuthenticationError as err:
|
||||
LOGGER.error("Invalid API key: %s", err)
|
||||
return False
|
||||
except openai.OpenAIError as err:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
match err:
|
||||
case openai.AuthenticationError():
|
||||
LOGGER.error("Invalid API key: %s", err)
|
||||
return False
|
||||
case _:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
|
||||
entry.runtime_data = client
|
||||
|
||||
|
@ -114,10 +114,15 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
self._async_abort_entries_match(user_input)
|
||||
try:
|
||||
await validate_input(self.hass, user_input)
|
||||
except openai.APIConnectionError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except openai.AuthenticationError:
|
||||
errors["base"] = "invalid_auth"
|
||||
except openai.OpenAIError as err:
|
||||
match err:
|
||||
case openai.APIConnectionError():
|
||||
errors["base"] = "cannot_connect"
|
||||
case openai.AuthenticationError():
|
||||
errors["base"] = "invalid_auth"
|
||||
case _:
|
||||
_LOGGER.exception("Unexpected OpenAI error")
|
||||
errors["base"] = "unknown"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
@ -293,8 +298,12 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
||||
|
||||
if user_input is not None:
|
||||
options.update(user_input)
|
||||
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
|
||||
errors[CONF_CHAT_MODEL] = "model_not_supported"
|
||||
|
||||
match user_input.get(CONF_CHAT_MODEL):
|
||||
case model if model in UNSUPPORTED_MODELS:
|
||||
errors[CONF_CHAT_MODEL] = "model_not_supported"
|
||||
case _:
|
||||
pass
|
||||
|
||||
if not errors:
|
||||
return await self.async_step_model()
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Constants for the OpenAI Conversation integration."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.helpers import llm
|
||||
@ -58,6 +60,50 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
|
||||
"o3-mini",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversationOptions:
|
||||
"""Configuration options for conversation."""
|
||||
|
||||
recommended: bool = True
|
||||
llm_hass_api: list[str] = field(default_factory=lambda: [llm.LLM_API_ASSIST])
|
||||
prompt: str = llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
chat_model: str = RECOMMENDED_CHAT_MODEL
|
||||
max_tokens: int = RECOMMENDED_MAX_TOKENS
|
||||
temperature: float = RECOMMENDED_TEMPERATURE
|
||||
top_p: float = RECOMMENDED_TOP_P
|
||||
reasoning_effort: str = RECOMMENDED_REASONING_EFFORT
|
||||
web_search: bool = RECOMMENDED_WEB_SEARCH
|
||||
code_interpreter: bool = RECOMMENDED_CODE_INTERPRETER
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
CONF_RECOMMENDED: self.recommended,
|
||||
CONF_LLM_HASS_API: self.llm_hass_api,
|
||||
CONF_PROMPT: self.prompt,
|
||||
CONF_CHAT_MODEL: self.chat_model,
|
||||
CONF_MAX_TOKENS: self.max_tokens,
|
||||
CONF_TEMPERATURE: self.temperature,
|
||||
CONF_TOP_P: self.top_p,
|
||||
CONF_REASONING_EFFORT: self.reasoning_effort,
|
||||
CONF_WEB_SEARCH: self.web_search,
|
||||
CONF_CODE_INTERPRETER: self.code_interpreter,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AITaskOptions:
|
||||
"""Configuration options for AI tasks."""
|
||||
|
||||
recommended: bool = True
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {CONF_RECOMMENDED: self.recommended}
|
||||
|
||||
|
||||
# Maintain backward compatibility with existing dictionary format
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
|
@ -184,91 +184,97 @@ async def _transform_stream(
|
||||
async for event in result:
|
||||
LOGGER.debug("Received event: %s", event)
|
||||
|
||||
if isinstance(event, ResponseOutputItemAddedEvent):
|
||||
if isinstance(event.item, ResponseOutputMessage):
|
||||
yield {"role": event.item.role}
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
match event:
|
||||
case ResponseOutputItemAddedEvent(item=ResponseOutputMessage() as msg):
|
||||
yield {"role": msg.role}
|
||||
case ResponseOutputItemAddedEvent(
|
||||
item=ResponseFunctionToolCall() as tool_call
|
||||
):
|
||||
# OpenAI has tool calls as individual events
|
||||
# while HA puts tool calls inside the assistant message.
|
||||
# We turn them into individual assistant content for HA
|
||||
# to ensure that tools are called as soon as possible.
|
||||
yield {"role": "assistant"}
|
||||
current_tool_call = event.item
|
||||
elif isinstance(event, ResponseOutputItemDoneEvent):
|
||||
item = event.item.model_dump()
|
||||
item.pop("status", None)
|
||||
if isinstance(event.item, ResponseReasoningItem):
|
||||
messages.append(cast(ResponseReasoningItemParam, item))
|
||||
elif isinstance(event.item, ResponseOutputMessage):
|
||||
messages.append(cast(ResponseOutputMessageParam, item))
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
messages.append(cast(ResponseFunctionToolCallParam, item))
|
||||
elif isinstance(event, ResponseTextDeltaEvent):
|
||||
yield {"content": event.delta}
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||
current_tool_call.arguments += event.delta
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
||||
current_tool_call.status = "completed"
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=current_tool_call.call_id,
|
||||
tool_name=current_tool_call.name,
|
||||
tool_args=json.loads(current_tool_call.arguments),
|
||||
)
|
||||
]
|
||||
}
|
||||
elif isinstance(event, ResponseCompletedEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
elif isinstance(event, ResponseIncompleteEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
event.response.incomplete_details
|
||||
and event.response.incomplete_details.reason
|
||||
current_tool_call = tool_call
|
||||
case ResponseOutputItemDoneEvent() as done_event:
|
||||
item = done_event.item.model_dump()
|
||||
item.pop("status", None)
|
||||
match done_event.item:
|
||||
case ResponseReasoningItem():
|
||||
messages.append(cast(ResponseReasoningItemParam, item))
|
||||
case ResponseOutputMessage():
|
||||
messages.append(cast(ResponseOutputMessageParam, item))
|
||||
case ResponseFunctionToolCall():
|
||||
messages.append(cast(ResponseFunctionToolCallParam, item))
|
||||
case ResponseTextDeltaEvent(delta=delta):
|
||||
yield {"content": delta}
|
||||
case ResponseFunctionCallArgumentsDeltaEvent(delta=delta):
|
||||
current_tool_call.arguments += delta
|
||||
case ResponseFunctionCallArgumentsDoneEvent():
|
||||
current_tool_call.status = "completed"
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=current_tool_call.call_id,
|
||||
tool_name=current_tool_call.name,
|
||||
tool_args=json.loads(current_tool_call.arguments),
|
||||
)
|
||||
]
|
||||
}
|
||||
case ResponseCompletedEvent(response=response) if (
|
||||
response.usage is not None
|
||||
):
|
||||
reason: str = event.response.incomplete_details.reason
|
||||
else:
|
||||
reason = "unknown reason"
|
||||
|
||||
if reason == "max_output_tokens":
|
||||
reason = "max output tokens reached"
|
||||
elif reason == "content_filter":
|
||||
reason = "content filter triggered"
|
||||
|
||||
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
|
||||
elif isinstance(event, ResponseFailedEvent):
|
||||
if event.response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": event.response.usage.input_tokens,
|
||||
"output_tokens": event.response.usage.output_tokens,
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
reason = "unknown reason"
|
||||
if event.response.error is not None:
|
||||
reason = event.response.error.message
|
||||
raise HomeAssistantError(f"OpenAI response failed: {reason}")
|
||||
elif isinstance(event, ResponseErrorEvent):
|
||||
raise HomeAssistantError(f"OpenAI response error: {event.message}")
|
||||
case ResponseIncompleteEvent(response=response):
|
||||
if response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
reason = (
|
||||
response.incomplete_details.reason
|
||||
if response.incomplete_details
|
||||
and response.incomplete_details.reason
|
||||
else "unknown reason"
|
||||
)
|
||||
|
||||
match reason:
|
||||
case "max_output_tokens":
|
||||
error_message = "max output tokens reached"
|
||||
case "content_filter":
|
||||
error_message = "content filter triggered"
|
||||
case _:
|
||||
error_message = reason
|
||||
|
||||
raise HomeAssistantError(f"OpenAI response incomplete: {error_message}")
|
||||
case ResponseFailedEvent(response=response):
|
||||
if response.usage is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
error_message = (
|
||||
response.error.message if response.error else "unknown reason"
|
||||
)
|
||||
raise HomeAssistantError(f"OpenAI response failed: {error_message}")
|
||||
case ResponseErrorEvent(message=message):
|
||||
raise HomeAssistantError(f"OpenAI response error: {message}")
|
||||
|
||||
|
||||
class OpenAIBaseLLMEntity(Entity):
|
||||
@ -436,32 +442,33 @@ async def async_prepare_files_for_prompt(
|
||||
if not file_path.exists():
|
||||
raise HomeAssistantError(f"`{file_path}` does not exist")
|
||||
|
||||
mime_type, _ = guess_file_type(file_path)
|
||||
|
||||
if not mime_type or not mime_type.startswith(("image/", "application/pdf")):
|
||||
if not (
|
||||
mime_type := guess_file_type(file_path)[0]
|
||||
) or not mime_type.startswith(("image/", "application/pdf")):
|
||||
raise HomeAssistantError(
|
||||
"Only images and PDF are supported by the OpenAI API,"
|
||||
f"`{file_path}` is not an image file or PDF"
|
||||
)
|
||||
|
||||
base64_file = base64.b64encode(file_path.read_bytes()).decode("utf-8")
|
||||
base64_data = f"data:{mime_type};base64,{base64.b64encode(file_path.read_bytes()).decode('utf-8')}"
|
||||
|
||||
if mime_type.startswith("image/"):
|
||||
content.append(
|
||||
ResponseInputImageParam(
|
||||
type="input_image",
|
||||
image_url=f"data:{mime_type};base64,{base64_file}",
|
||||
detail="auto",
|
||||
match mime_type.split("/")[0]:
|
||||
case "image":
|
||||
content.append(
|
||||
ResponseInputImageParam(
|
||||
type="input_image",
|
||||
image_url=base64_data,
|
||||
detail="auto",
|
||||
)
|
||||
)
|
||||
)
|
||||
elif mime_type.startswith("application/pdf"):
|
||||
content.append(
|
||||
ResponseInputFileParam(
|
||||
type="input_file",
|
||||
filename=str(file_path),
|
||||
file_data=f"data:{mime_type};base64,{base64_file}",
|
||||
case "application" if mime_type == "application/pdf":
|
||||
content.append(
|
||||
ResponseInputFileParam(
|
||||
type="input_file",
|
||||
filename=str(file_path),
|
||||
file_data=base64_data,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user