Introduce base entity in Open Router (#148910)

This commit is contained in:
Joost Lekkerkerker 2025-07-22 13:43:41 +02:00 committed by GitHub
parent 49807c9fbe
commit e5c7e04329
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 169 deletions

View File

@ -2,13 +2,12 @@
import logging import logging
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT
from homeassistant.helpers import llm from homeassistant.helpers import llm
DOMAIN = "open_router" DOMAIN = "open_router"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"
RECOMMENDED_CONVERSATION_OPTIONS = { RECOMMENDED_CONVERSATION_OPTIONS = {

View File

@ -1,39 +1,16 @@
"""Conversation support for OpenRouter.""" """Conversation support for OpenRouter."""
from collections.abc import AsyncGenerator, Callable from typing import Literal
import json
from typing import Any, Literal
import openai
from openai import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from voluptuous_openapi import convert
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenRouterConfigEntry from . import OpenRouterConfigEntry
from .const import CONF_PROMPT, DOMAIN, LOGGER from .const import DOMAIN
from .entity import OpenRouterEntity
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry( async def async_setup_entry(
@ -49,106 +26,14 @@ async def async_setup_entry(
) )
def _format_tool( class OpenRouterConversationEntity(OpenRouterEntity, conversation.ConversationEntity):
tool: llm.Tool,
custom_serializer: Callable[[Any], Any] | None,
) -> ChatCompletionToolParam:
"""Format tool specification."""
tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description:
tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec)
def _convert_content_to_chat_message(
content: conversation.Content,
) -> ChatCompletionMessageParam | None:
"""Convert any native chat message for this agent to the native format."""
LOGGER.debug("_convert_content_to_chat_message=%s", content)
if isinstance(content, conversation.ToolResultContent):
return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
role: Literal["user", "assistant", "system"] = content.role
if role == "system" and content.content:
return ChatCompletionSystemMessageParam(role="system", content=content.content)
if role == "user" and content.content:
return ChatCompletionUserMessageParam(role="user", content=content.content)
if role == "assistant":
param = ChatCompletionAssistantMessageParam(
role="assistant",
content=content.content,
)
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
param["tool_calls"] = [
ChatCompletionMessageToolCallParam(
type="function",
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
)
for tool_call in content.tool_calls
]
return param
LOGGER.warning("Could not convert message to Completions API: %s", content)
return None
def _decode_tool_arguments(arguments: str) -> Any:
"""Decode tool call arguments."""
try:
return json.loads(arguments)
except json.JSONDecodeError as err:
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
async def _transform_response(
message: ChatCompletionMessage,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the OpenRouter message to a ChatLog format."""
data: conversation.AssistantContentDeltaDict = {
"role": message.role,
"content": message.content,
}
if message.tool_calls:
data["tool_calls"] = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=_decode_tool_arguments(tool_call.function.arguments),
)
for tool_call in message.tool_calls
]
yield data
class OpenRouterConversationEntity(conversation.ConversationEntity):
"""OpenRouter conversation agent.""" """OpenRouter conversation agent."""
_attr_has_entity_name = True
_attr_name = None _attr_name = None
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None: def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry super().__init__(entry, subentry)
self.subentry = subentry
self.model = subentry.data[CONF_MODEL]
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
entry_type=DeviceEntryType.SERVICE,
)
if self.subentry.data.get(CONF_LLM_HASS_API): if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = ( self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL conversation.ConversationEntityFeature.CONTROL
@ -164,7 +49,7 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
user_input: conversation.ConversationInput, user_input: conversation.ConversationInput,
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process the user input and call the API."""
options = self.subentry.data options = self.subentry.data
try: try:
@ -177,49 +62,6 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None await self._async_handle_chat_log(chat_log)
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
messages = [
m
for content in chat_log.content
if (m := _convert_content_to_chat_message(content))
]
client = self.entry.runtime_data
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=self.model,
messages=messages,
tools=tools or NOT_GIVEN,
user=chat_log.conversation_id,
extra_headers={
"X-Title": "Home Assistant",
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
},
)
except openai.OpenAIError as err:
LOGGER.error("Error talking to API: %s", err)
raise HomeAssistantError("Error talking to API") from err
result_message = result.choices[0].message
messages.extend(
[
msg
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_response(result_message)
)
if (msg := _convert_content_to_chat_message(content))
]
)
if not chat_log.unresponded_tool_results:
break
return conversation.async_get_result_from_chat_log(user_input, chat_log) return conversation.async_get_result_from_chat_log(user_input, chat_log)

View File

@ -0,0 +1,185 @@
"""Base entity for Open Router."""
from __future__ import annotations
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal
import openai
from openai import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_MODEL
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OpenRouterConfigEntry
from .const import DOMAIN, LOGGER
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
def _format_tool(
    tool: llm.Tool,
    custom_serializer: Callable[[Any], Any] | None,
) -> ChatCompletionToolParam:
    """Translate a Home Assistant LLM tool into the OpenAI function-tool format."""
    schema = convert(tool.parameters, custom_serializer=custom_serializer)
    function = FunctionDefinition(name=tool.name, parameters=schema)
    # The description key is optional in the OpenAI spec; only set it when present.
    if tool.description:
        function["description"] = tool.description
    return ChatCompletionToolParam(type="function", function=function)
def _convert_content_to_chat_message(
    content: conversation.Content,
) -> ChatCompletionMessageParam | None:
    """Map a Home Assistant chat-log entry onto an OpenAI message param.

    Returns None (after logging a warning) when the entry has no
    representation in the Completions API.
    """
    LOGGER.debug("_convert_content_to_chat_message=%s", content)

    # Tool results always map onto "tool" role messages.
    if isinstance(content, conversation.ToolResultContent):
        return ChatCompletionToolMessageParam(
            role="tool",
            tool_call_id=content.tool_call_id,
            content=json.dumps(content.tool_result),
        )

    role: Literal["user", "assistant", "system"] = content.role
    if role == "assistant":
        # Assistant messages may carry tool calls even without text content.
        message = ChatCompletionAssistantMessageParam(
            role="assistant",
            content=content.content,
        )
        tool_calls = (
            content.tool_calls
            if isinstance(content, conversation.AssistantContent)
            else None
        )
        if tool_calls:
            message["tool_calls"] = [
                ChatCompletionMessageToolCallParam(
                    type="function",
                    id=call.id,
                    function=Function(
                        arguments=json.dumps(call.tool_args),
                        name=call.tool_name,
                    ),
                )
                for call in tool_calls
            ]
        return message

    # System and user entries are only representable when they have text.
    if content.content:
        if role == "system":
            return ChatCompletionSystemMessageParam(
                role="system", content=content.content
            )
        if role == "user":
            return ChatCompletionUserMessageParam(role="user", content=content.content)

    LOGGER.warning("Could not convert message to Completions API: %s", content)
    return None
def _decode_tool_arguments(arguments: str) -> Any:
"""Decode tool call arguments."""
try:
return json.loads(arguments)
except json.JSONDecodeError as err:
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
async def _transform_response(
message: ChatCompletionMessage,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the OpenRouter message to a ChatLog format."""
data: conversation.AssistantContentDeltaDict = {
"role": message.role,
"content": message.content,
}
if message.tool_calls:
data["tool_calls"] = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=_decode_tool_arguments(tool_call.function.arguments),
)
for tool_call in message.tool_calls
]
yield data
class OpenRouterEntity(Entity):
    """Base entity for Open Router."""

    # All Open Router entities use the title of their config subentry as the
    # device name; individual entities then contribute their own entity name.
    _attr_has_entity_name = True

    def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity.

        Stores the config entry/subentry, reads the model to query from the
        subentry data, and registers one service device per subentry.
        """
        self.entry = entry
        self.subentry = subentry
        # Model identifier passed verbatim to the chat completions endpoint.
        self.model = subentry.data[CONF_MODEL]
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, subentry.subentry_id)},
            name=subentry.title,
            entry_type=dr.DeviceEntryType.SERVICE,
        )

    async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
        """Generate an answer for the chat log.

        Converts the chat log to Completions API messages, queries the model,
        and keeps iterating while the model issues tool calls, streaming each
        response back into the chat log.

        Raises HomeAssistantError when the OpenAI client reports an error.
        """
        # Expose Home Assistant LLM tools to the model when an LLM API is configured.
        tools: list[ChatCompletionToolParam] | None = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]
        # Convert the chat log history, dropping entries that have no
        # representation in the Completions API (converter returns None).
        messages = [
            m
            for content in chat_log.content
            if (m := _convert_content_to_chat_message(content))
        ]
        client = self.entry.runtime_data
        # Allow the model to call tools and read their results, but cap the
        # back-and-forth at MAX_TOOL_ITERATIONS to avoid an endless loop.
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                result = await client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    tools=tools or NOT_GIVEN,
                    user=chat_log.conversation_id,
                    # Identify Home Assistant to OpenRouter for app attribution.
                    extra_headers={
                        "X-Title": "Home Assistant",
                        "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
                    },
                )
            except openai.OpenAIError as err:
                LOGGER.error("Error talking to API: %s", err)
                raise HomeAssistantError("Error talking to API") from err

            result_message = result.choices[0].message
            # Stream the response through the chat log (which executes any tool
            # calls), then mirror the resulting content back into the
            # Completions message list for the next iteration.
            messages.extend(
                [
                    msg
                    async for content in chat_log.async_add_delta_content_stream(
                        self.entity_id, _transform_response(result_message)
                    )
                    if (msg := _convert_content_to_chat_message(content))
                ]
            )
            # Stop once every tool call has been answered.
            if not chat_log.unresponded_tool_results:
                break

View File

@ -25,7 +25,7 @@
"description": "Configure the new conversation agent", "description": "Configure the new conversation agent",
"data": { "data": {
"model": "Model", "model": "Model",
"prompt": "Instructions", "prompt": "[%key:common::config_flow::data::prompt%]",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
}, },
"data_description": { "data_description": {