mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 23:27:37 +00:00
Introduce base entity in Open Router (#148910)
This commit is contained in:
parent
49807c9fbe
commit
e5c7e04329
@ -2,13 +2,12 @@
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
DOMAIN = "open_router"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
||||
CONF_PROMPT = "prompt"
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
|
@ -1,39 +1,16 @@
|
||||
"""Conversation support for OpenRouter."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from voluptuous_openapi import convert
|
||||
from typing import Literal
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from . import OpenRouterConfigEntry
|
||||
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
from .const import DOMAIN
|
||||
from .entity import OpenRouterEntity
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@ -49,106 +26,14 @@ async def async_setup_entry(
|
||||
)
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool,
|
||||
custom_serializer: Callable[[Any], Any] | None,
|
||||
) -> ChatCompletionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _convert_content_to_chat_message(
|
||||
content: conversation.Content,
|
||||
) -> ChatCompletionMessageParam | None:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
|
||||
role: Literal["user", "assistant", "system"] = content.role
|
||||
if role == "system" and content.content:
|
||||
return ChatCompletionSystemMessageParam(role="system", content=content.content)
|
||||
|
||||
if role == "user" and content.content:
|
||||
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
||||
|
||||
if role == "assistant":
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content.content,
|
||||
)
|
||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||
param["tool_calls"] = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
type="function",
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
name=tool_call.tool_name,
|
||||
),
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
]
|
||||
return param
|
||||
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
||||
return None
|
||||
|
||||
|
||||
def _decode_tool_arguments(arguments: str) -> Any:
|
||||
"""Decode tool call arguments."""
|
||||
try:
|
||||
return json.loads(arguments)
|
||||
except json.JSONDecodeError as err:
|
||||
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
||||
|
||||
|
||||
async def _transform_response(
|
||||
message: ChatCompletionMessage,
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform the OpenRouter message to a ChatLog format."""
|
||||
data: conversation.AssistantContentDeltaDict = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.tool_calls:
|
||||
data["tool_calls"] = [
|
||||
llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
yield data
|
||||
|
||||
|
||||
class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
class OpenRouterConversationEntity(OpenRouterEntity, conversation.ConversationEntity):
|
||||
"""OpenRouter conversation agent."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
|
||||
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.entry = entry
|
||||
self.subentry = subentry
|
||||
self.model = subentry.data[CONF_MODEL]
|
||||
self._attr_unique_id = subentry.subentry_id
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||
name=subentry.title,
|
||||
entry_type=DeviceEntryType.SERVICE,
|
||||
)
|
||||
super().__init__(entry, subentry)
|
||||
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
conversation.ConversationEntityFeature.CONTROL
|
||||
@ -164,7 +49,7 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
user_input: conversation.ConversationInput,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
"""Process the user input and call the API."""
|
||||
options = self.subentry.data
|
||||
|
||||
try:
|
||||
@ -177,49 +62,6 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
messages = [
|
||||
m
|
||||
for content in chat_log.content
|
||||
if (m := _convert_content_to_chat_message(content))
|
||||
]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
user=chat_log.conversation_id,
|
||||
extra_headers={
|
||||
"X-Title": "Home Assistant",
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
},
|
||||
)
|
||||
except openai.OpenAIError as err:
|
||||
LOGGER.error("Error talking to API: %s", err)
|
||||
raise HomeAssistantError("Error talking to API") from err
|
||||
|
||||
result_message = result.choices[0].message
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
msg
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_response(result_message)
|
||||
)
|
||||
if (msg := _convert_content_to_chat_message(content))
|
||||
]
|
||||
)
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
185
homeassistant/components/open_router/entity.py
Normal file
185
homeassistant/components/open_router/entity.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""Base entity for Open Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_MODEL
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, llm
|
||||
from homeassistant.helpers.entity import Entity
|
||||
|
||||
from . import OpenRouterConfigEntry
|
||||
from .const import DOMAIN, LOGGER
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool,
|
||||
custom_serializer: Callable[[Any], Any] | None,
|
||||
) -> ChatCompletionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _convert_content_to_chat_message(
|
||||
content: conversation.Content,
|
||||
) -> ChatCompletionMessageParam | None:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
|
||||
role: Literal["user", "assistant", "system"] = content.role
|
||||
if role == "system" and content.content:
|
||||
return ChatCompletionSystemMessageParam(role="system", content=content.content)
|
||||
|
||||
if role == "user" and content.content:
|
||||
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
||||
|
||||
if role == "assistant":
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content.content,
|
||||
)
|
||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||
param["tool_calls"] = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
type="function",
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
name=tool_call.tool_name,
|
||||
),
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
]
|
||||
return param
|
||||
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
||||
return None
|
||||
|
||||
|
||||
def _decode_tool_arguments(arguments: str) -> Any:
|
||||
"""Decode tool call arguments."""
|
||||
try:
|
||||
return json.loads(arguments)
|
||||
except json.JSONDecodeError as err:
|
||||
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
||||
|
||||
|
||||
async def _transform_response(
|
||||
message: ChatCompletionMessage,
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform the OpenRouter message to a ChatLog format."""
|
||||
data: conversation.AssistantContentDeltaDict = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.tool_calls:
|
||||
data["tool_calls"] = [
|
||||
llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
yield data
|
||||
|
||||
|
||||
class OpenRouterEntity(Entity):
|
||||
"""Base entity for Open Router."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
|
||||
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the entity."""
|
||||
self.entry = entry
|
||||
self.subentry = subentry
|
||||
self.model = subentry.data[CONF_MODEL]
|
||||
self._attr_unique_id = subentry.subentry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||
name=subentry.title,
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
|
||||
async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
messages = [
|
||||
m
|
||||
for content in chat_log.content
|
||||
if (m := _convert_content_to_chat_message(content))
|
||||
]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
user=chat_log.conversation_id,
|
||||
extra_headers={
|
||||
"X-Title": "Home Assistant",
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
},
|
||||
)
|
||||
except openai.OpenAIError as err:
|
||||
LOGGER.error("Error talking to API: %s", err)
|
||||
raise HomeAssistantError("Error talking to API") from err
|
||||
|
||||
result_message = result.choices[0].message
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
msg
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
self.entity_id, _transform_response(result_message)
|
||||
)
|
||||
if (msg := _convert_content_to_chat_message(content))
|
||||
]
|
||||
)
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
@ -25,7 +25,7 @@
|
||||
"description": "Configure the new conversation agent",
|
||||
"data": {
|
||||
"model": "Model",
|
||||
"prompt": "Instructions",
|
||||
"prompt": "[%key:common::config_flow::data::prompt%]",
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||
},
|
||||
"data_description": {
|
||||
|
Loading…
x
Reference in New Issue
Block a user