diff --git a/homeassistant/components/open_router/const.py b/homeassistant/components/open_router/const.py index 9fbce10da4e..7316d45c3e5 100644 --- a/homeassistant/components/open_router/const.py +++ b/homeassistant/components/open_router/const.py @@ -2,13 +2,12 @@ import logging -from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT from homeassistant.helpers import llm DOMAIN = "open_router" LOGGER = logging.getLogger(__package__) -CONF_PROMPT = "prompt" CONF_RECOMMENDED = "recommended" RECOMMENDED_CONVERSATION_OPTIONS = { diff --git a/homeassistant/components/open_router/conversation.py b/homeassistant/components/open_router/conversation.py index 06196565aad..826931d3da7 100644 --- a/homeassistant/components/open_router/conversation.py +++ b/homeassistant/components/open_router/conversation.py @@ -1,39 +1,16 @@ """Conversation support for OpenRouter.""" -from collections.abc import AsyncGenerator, Callable -import json -from typing import Any, Literal - -import openai -from openai import NOT_GIVEN -from openai.types.chat import ( - ChatCompletionAssistantMessageParam, - ChatCompletionMessage, - ChatCompletionMessageParam, - ChatCompletionMessageToolCallParam, - ChatCompletionSystemMessageParam, - ChatCompletionToolMessageParam, - ChatCompletionToolParam, - ChatCompletionUserMessageParam, -) -from openai.types.chat.chat_completion_message_tool_call_param import Function -from openai.types.shared_params import FunctionDefinition -from voluptuous_openapi import convert +from typing import Literal from homeassistant.components import conversation from homeassistant.config_entries import ConfigSubentry -from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL +from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import llm -from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import OpenRouterConfigEntry -from .const import CONF_PROMPT, DOMAIN, LOGGER - -# Max number of back and forth with the LLM to generate a response -MAX_TOOL_ITERATIONS = 10 +from .const import DOMAIN +from .entity import OpenRouterEntity async def async_setup_entry( @@ -49,106 +26,14 @@ async def async_setup_entry( ) -def _format_tool( - tool: llm.Tool, - custom_serializer: Callable[[Any], Any] | None, -) -> ChatCompletionToolParam: - """Format tool specification.""" - tool_spec = FunctionDefinition( - name=tool.name, - parameters=convert(tool.parameters, custom_serializer=custom_serializer), - ) - if tool.description: - tool_spec["description"] = tool.description - return ChatCompletionToolParam(type="function", function=tool_spec) - - -def _convert_content_to_chat_message( - content: conversation.Content, -) -> ChatCompletionMessageParam | None: - """Convert any native chat message for this agent to the native format.""" - LOGGER.debug("_convert_content_to_chat_message=%s", content) - if isinstance(content, conversation.ToolResultContent): - return ChatCompletionToolMessageParam( - role="tool", - tool_call_id=content.tool_call_id, - content=json.dumps(content.tool_result), - ) - - role: Literal["user", "assistant", "system"] = content.role - if role == "system" and content.content: - return ChatCompletionSystemMessageParam(role="system", content=content.content) - - if role == "user" and content.content: - return ChatCompletionUserMessageParam(role="user", content=content.content) - - if role == "assistant": - param = ChatCompletionAssistantMessageParam( - role="assistant", - content=content.content, - ) - if isinstance(content, conversation.AssistantContent) and content.tool_calls: - param["tool_calls"] = [ - ChatCompletionMessageToolCallParam( - type="function", - id=tool_call.id, - function=Function( - arguments=json.dumps(tool_call.tool_args), - name=tool_call.tool_name, - ), - ) - for tool_call in content.tool_calls - ] - return param - LOGGER.warning("Could not convert message to Completions API: %s", content) - return None - - -def _decode_tool_arguments(arguments: str) -> Any: - """Decode tool call arguments.""" - try: - return json.loads(arguments) - except json.JSONDecodeError as err: - raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err - - -async def _transform_response( - message: ChatCompletionMessage, -) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: - """Transform the OpenRouter message to a ChatLog format.""" - data: conversation.AssistantContentDeltaDict = { - "role": message.role, - "content": message.content, - } - if message.tool_calls: - data["tool_calls"] = [ - llm.ToolInput( - id=tool_call.id, - tool_name=tool_call.function.name, - tool_args=_decode_tool_arguments(tool_call.function.arguments), - ) - for tool_call in message.tool_calls - ] - yield data - - -class OpenRouterConversationEntity(conversation.ConversationEntity): +class OpenRouterConversationEntity(OpenRouterEntity, conversation.ConversationEntity): """OpenRouter conversation agent.""" - _attr_has_entity_name = True _attr_name = None def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" - self.entry = entry - self.subentry = subentry - self.model = subentry.data[CONF_MODEL] - self._attr_unique_id = subentry.subentry_id - self._attr_device_info = DeviceInfo( - identifiers={(DOMAIN, subentry.subentry_id)}, - name=subentry.title, - entry_type=DeviceEntryType.SERVICE, - ) + super().__init__(entry, subentry) if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL @@ -164,7 +49,7 @@ class OpenRouterConversationEntity(conversation.ConversationEntity): user_input: conversation.ConversationInput, chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: - """Process a sentence.""" + """Process the user input and call the API.""" options = self.subentry.data try: @@ -177,49 +62,6 @@ class OpenRouterConversationEntity(conversation.ConversationEntity): except conversation.ConverseError as err: return err.as_conversation_result() - tools: list[ChatCompletionToolParam] | None = None - if chat_log.llm_api: - tools = [ - _format_tool(tool, chat_log.llm_api.custom_serializer) - for tool in chat_log.llm_api.tools - ] - - messages = [ - m - for content in chat_log.content - if (m := _convert_content_to_chat_message(content)) - ] - - client = self.entry.runtime_data - - for _iteration in range(MAX_TOOL_ITERATIONS): - try: - result = await client.chat.completions.create( - model=self.model, - messages=messages, - tools=tools or NOT_GIVEN, - user=chat_log.conversation_id, - extra_headers={ - "X-Title": "Home Assistant", - "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", - }, - ) - except openai.OpenAIError as err: - LOGGER.error("Error talking to API: %s", err) - raise HomeAssistantError("Error talking to API") from err - - result_message = result.choices[0].message - - messages.extend( - [ - msg - async for content in chat_log.async_add_delta_content_stream( - user_input.agent_id, _transform_response(result_message) - ) - if (msg := _convert_content_to_chat_message(content)) - ] - ) - if not chat_log.unresponded_tool_results: - break + await self._async_handle_chat_log(chat_log) return conversation.async_get_result_from_chat_log(user_input, chat_log) diff --git a/homeassistant/components/open_router/entity.py b/homeassistant/components/open_router/entity.py new file mode 100644 index 00000000000..e706656d377 --- /dev/null +++ b/homeassistant/components/open_router/entity.py @@ -0,0 +1,185 @@ +"""Base entity for Open Router.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Callable +import json +from typing import Any, Literal + +import openai +from openai import NOT_GIVEN +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from openai.types.chat.chat_completion_message_tool_call_param import Function +from openai.types.shared_params import FunctionDefinition +from voluptuous_openapi import convert + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigSubentry +from homeassistant.const import CONF_MODEL +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, llm +from homeassistant.helpers.entity import Entity + +from . import OpenRouterConfigEntry +from .const import DOMAIN, LOGGER + +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 + + +def _format_tool( + tool: llm.Tool, + custom_serializer: Callable[[Any], Any] | None, +) -> ChatCompletionToolParam: + """Format tool specification.""" + tool_spec = FunctionDefinition( + name=tool.name, + parameters=convert(tool.parameters, custom_serializer=custom_serializer), + ) + if tool.description: + tool_spec["description"] = tool.description + return ChatCompletionToolParam(type="function", function=tool_spec) + + +def _convert_content_to_chat_message( + content: conversation.Content, +) -> ChatCompletionMessageParam | None: + """Convert any native chat message for this agent to the native format.""" + LOGGER.debug("_convert_content_to_chat_message=%s", content) + if isinstance(content, conversation.ToolResultContent): + return ChatCompletionToolMessageParam( + role="tool", + tool_call_id=content.tool_call_id, + content=json.dumps(content.tool_result), + ) + + role: Literal["user", "assistant", "system"] = content.role + if role == "system" and content.content: + return ChatCompletionSystemMessageParam(role="system", content=content.content) + + if role == "user" and content.content: + return ChatCompletionUserMessageParam(role="user", content=content.content) + + if role == "assistant": + param = ChatCompletionAssistantMessageParam( + role="assistant", + content=content.content, + ) + if isinstance(content, conversation.AssistantContent) and content.tool_calls: + param["tool_calls"] = [ + ChatCompletionMessageToolCallParam( + type="function", + id=tool_call.id, + function=Function( + arguments=json.dumps(tool_call.tool_args), + name=tool_call.tool_name, + ), + ) + for tool_call in content.tool_calls + ] + return param + LOGGER.warning("Could not convert message to Completions API: %s", content) + return None + + +def _decode_tool_arguments(arguments: str) -> Any: + """Decode tool call arguments.""" + try: + return json.loads(arguments) + except json.JSONDecodeError as err: + raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err + + +async def _transform_response( + message: ChatCompletionMessage, +) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: + """Transform the OpenRouter message to a ChatLog format.""" + data: conversation.AssistantContentDeltaDict = { + "role": message.role, + "content": message.content, + } + if message.tool_calls: + data["tool_calls"] = [ + llm.ToolInput( + id=tool_call.id, + tool_name=tool_call.function.name, + tool_args=_decode_tool_arguments(tool_call.function.arguments), + ) + for tool_call in message.tool_calls + ] + yield data + + +class OpenRouterEntity(Entity): + """Base entity for Open Router.""" + + _attr_has_entity_name = True + + def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None: + """Initialize the entity.""" + self.entry = entry + self.subentry = subentry + self.model = subentry.data[CONF_MODEL] + self._attr_unique_id = subentry.subentry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, + entry_type=dr.DeviceEntryType.SERVICE, + ) + + async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None: + """Generate an answer for the chat log.""" + + tools: list[ChatCompletionToolParam] | None = None + if chat_log.llm_api: + tools = [ + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools + ] + + messages = [ + m + for content in chat_log.content + if (m := _convert_content_to_chat_message(content)) + ] + + client = self.entry.runtime_data + + for _iteration in range(MAX_TOOL_ITERATIONS): + try: + result = await client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools or NOT_GIVEN, + user=chat_log.conversation_id, + extra_headers={ + "X-Title": "Home Assistant", + "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", + }, + ) + except openai.OpenAIError as err: + LOGGER.error("Error talking to API: %s", err) + raise HomeAssistantError("Error talking to API") from err + + result_message = result.choices[0].message + + messages.extend( + [ + msg + async for content in chat_log.async_add_delta_content_stream( + self.entity_id, _transform_response(result_message) + ) + if (msg := _convert_content_to_chat_message(content)) + ] + ) + if not chat_log.unresponded_tool_results: + break diff --git a/homeassistant/components/open_router/strings.json b/homeassistant/components/open_router/strings.json index 6e6674dac06..91c4cc350ae 100644 --- a/homeassistant/components/open_router/strings.json +++ b/homeassistant/components/open_router/strings.json @@ -25,7 +25,7 @@ "description": "Configure the new conversation agent", "data": { "model": "Model", - "prompt": "Instructions", + "prompt": "[%key:common::config_flow::data::prompt%]", "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" }, "data_description": {