"""Base entity for Open Router.""" from __future__ import annotations from collections.abc import AsyncGenerator, Callable import json from typing import TYPE_CHECKING, Any, Literal import openai from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionMessage, ChatCompletionMessageParam, ChatCompletionMessageToolCallParam, ChatCompletionSystemMessageParam, ChatCompletionToolMessageParam, ChatCompletionToolParam, ChatCompletionUserMessageParam, ) from openai.types.chat.chat_completion_message_tool_call_param import Function from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema from openai.types.shared_params.response_format_json_schema import JSONSchema import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation from homeassistant.config_entries import ConfigSubentry from homeassistant.const import CONF_MODEL from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, llm from homeassistant.helpers.entity import Entity from . import OpenRouterConfigEntry from .const import DOMAIN, LOGGER # Max number of back and forth with the LLM to generate a response MAX_TOOL_ITERATIONS = 10 def _adjust_schema(schema: dict[str, Any]) -> None: """Adjust the schema to be compatible with OpenRouter API.""" if schema["type"] == "object": if "properties" not in schema: return if "required" not in schema: schema["required"] = [] # Ensure all properties are required for prop, prop_info in schema["properties"].items(): _adjust_schema(prop_info) if prop not in schema["required"]: prop_info["type"] = [prop_info["type"], "null"] schema["required"].append(prop) elif schema["type"] == "array": if "items" not in schema: return _adjust_schema(schema["items"]) def _format_structured_output( name: str, schema: vol.Schema, llm_api: llm.APIInstance | None ) -> JSONSchema: """Format the schema to be compatible with OpenRouter API.""" result: JSONSchema = { "name": name, "strict": True, } result_schema = convert( schema, custom_serializer=( llm_api.custom_serializer if llm_api else llm.selector_serializer ), ) _adjust_schema(result_schema) result["schema"] = result_schema return result def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None, ) -> ChatCompletionToolParam: """Format tool specification.""" tool_spec = FunctionDefinition( name=tool.name, parameters=convert(tool.parameters, custom_serializer=custom_serializer), ) if tool.description: tool_spec["description"] = tool.description return ChatCompletionToolParam(type="function", function=tool_spec) def _convert_content_to_chat_message( content: conversation.Content, ) -> ChatCompletionMessageParam | None: """Convert any native chat message for this agent to the native format.""" LOGGER.debug("_convert_content_to_chat_message=%s", content) if isinstance(content, conversation.ToolResultContent): return ChatCompletionToolMessageParam( role="tool", tool_call_id=content.tool_call_id, content=json.dumps(content.tool_result), ) role: Literal["user", "assistant", "system"] = content.role if role == "system" and content.content: return ChatCompletionSystemMessageParam(role="system", content=content.content) if role == "user" and content.content: return ChatCompletionUserMessageParam(role="user", content=content.content) if role == "assistant": param = ChatCompletionAssistantMessageParam( role="assistant", content=content.content, ) if isinstance(content, conversation.AssistantContent) and content.tool_calls: param["tool_calls"] = [ ChatCompletionMessageToolCallParam( type="function", id=tool_call.id, function=Function( arguments=json.dumps(tool_call.tool_args), name=tool_call.tool_name, ), ) for tool_call in content.tool_calls ] return param LOGGER.warning("Could not convert message to Completions API: %s", content) return None def _decode_tool_arguments(arguments: str) -> Any: """Decode tool call arguments.""" try: return json.loads(arguments) except json.JSONDecodeError as err: raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err async def _transform_response( message: ChatCompletionMessage, ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: """Transform the OpenRouter message to a ChatLog format.""" data: conversation.AssistantContentDeltaDict = { "role": message.role, "content": message.content, } if message.tool_calls: data["tool_calls"] = [ llm.ToolInput( id=tool_call.id, tool_name=tool_call.function.name, tool_args=_decode_tool_arguments(tool_call.function.arguments), ) for tool_call in message.tool_calls ] yield data class OpenRouterEntity(Entity): """Base entity for Open Router.""" _attr_has_entity_name = True def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the entity.""" self.entry = entry self.subentry = subentry self.model = subentry.data[CONF_MODEL] self._attr_unique_id = subentry.subentry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, subentry.subentry_id)}, name=subentry.title, entry_type=dr.DeviceEntryType.SERVICE, ) async def _async_handle_chat_log( self, chat_log: conversation.ChatLog, structure_name: str | None = None, structure: vol.Schema | None = None, ) -> None: """Generate an answer for the chat log.""" model_args = { "model": self.model, "user": chat_log.conversation_id, "extra_headers": { "X-Title": "Home Assistant", "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", }, "extra_body": {"require_parameters": True}, } tools: list[ChatCompletionToolParam] | None = None if chat_log.llm_api: tools = [ _format_tool(tool, chat_log.llm_api.custom_serializer) for tool in chat_log.llm_api.tools ] if tools: model_args["tools"] = tools model_args["messages"] = [ m for content in chat_log.content if (m := _convert_content_to_chat_message(content)) ] if structure: if TYPE_CHECKING: assert structure_name is not None model_args["response_format"] = ResponseFormatJSONSchema( type="json_schema", json_schema=_format_structured_output( structure_name, structure, chat_log.llm_api ), ) client = self.entry.runtime_data for _iteration in range(MAX_TOOL_ITERATIONS): try: result = await client.chat.completions.create(**model_args) except openai.OpenAIError as err: LOGGER.error("Error talking to API: %s", err) raise HomeAssistantError("Error talking to API") from err result_message = result.choices[0].message model_args["messages"].extend( [ msg async for content in chat_log.async_add_delta_content_stream( self.entity_id, _transform_response(result_message) ) if (msg := _convert_content_to_chat_message(content)) ] ) if not chat_log.unresponded_tool_results: break