mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Introduce base entity in Open Router (#148910)
This commit is contained in:
parent
49807c9fbe
commit
e5c7e04329
@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
DOMAIN = "open_router"
|
DOMAIN = "open_router"
|
||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
|
|
||||||
CONF_PROMPT = "prompt"
|
|
||||||
CONF_RECOMMENDED = "recommended"
|
CONF_RECOMMENDED = "recommended"
|
||||||
|
|
||||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||||
|
@ -1,39 +1,16 @@
|
|||||||
"""Conversation support for OpenRouter."""
|
"""Conversation support for OpenRouter."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from typing import Literal
|
||||||
import json
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
import openai
|
|
||||||
from openai import NOT_GIVEN
|
|
||||||
from openai.types.chat import (
|
|
||||||
ChatCompletionAssistantMessageParam,
|
|
||||||
ChatCompletionMessage,
|
|
||||||
ChatCompletionMessageParam,
|
|
||||||
ChatCompletionMessageToolCallParam,
|
|
||||||
ChatCompletionSystemMessageParam,
|
|
||||||
ChatCompletionToolMessageParam,
|
|
||||||
ChatCompletionToolParam,
|
|
||||||
ChatCompletionUserMessageParam,
|
|
||||||
)
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
|
||||||
from openai.types.shared_params import FunctionDefinition
|
|
||||||
from voluptuous_openapi import convert
|
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_PROMPT, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant.helpers import llm
|
|
||||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import OpenRouterConfigEntry
|
from . import OpenRouterConfigEntry
|
||||||
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
from .const import DOMAIN
|
||||||
|
from .entity import OpenRouterEntity
|
||||||
# Max number of back and forth with the LLM to generate a response
|
|
||||||
MAX_TOOL_ITERATIONS = 10
|
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
@ -49,106 +26,14 @@ async def async_setup_entry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
class OpenRouterConversationEntity(OpenRouterEntity, conversation.ConversationEntity):
|
||||||
tool: llm.Tool,
|
|
||||||
custom_serializer: Callable[[Any], Any] | None,
|
|
||||||
) -> ChatCompletionToolParam:
|
|
||||||
"""Format tool specification."""
|
|
||||||
tool_spec = FunctionDefinition(
|
|
||||||
name=tool.name,
|
|
||||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
|
||||||
)
|
|
||||||
if tool.description:
|
|
||||||
tool_spec["description"] = tool.description
|
|
||||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_content_to_chat_message(
|
|
||||||
content: conversation.Content,
|
|
||||||
) -> ChatCompletionMessageParam | None:
|
|
||||||
"""Convert any native chat message for this agent to the native format."""
|
|
||||||
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
|
||||||
if isinstance(content, conversation.ToolResultContent):
|
|
||||||
return ChatCompletionToolMessageParam(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=content.tool_call_id,
|
|
||||||
content=json.dumps(content.tool_result),
|
|
||||||
)
|
|
||||||
|
|
||||||
role: Literal["user", "assistant", "system"] = content.role
|
|
||||||
if role == "system" and content.content:
|
|
||||||
return ChatCompletionSystemMessageParam(role="system", content=content.content)
|
|
||||||
|
|
||||||
if role == "user" and content.content:
|
|
||||||
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
|
||||||
|
|
||||||
if role == "assistant":
|
|
||||||
param = ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant",
|
|
||||||
content=content.content,
|
|
||||||
)
|
|
||||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
|
||||||
param["tool_calls"] = [
|
|
||||||
ChatCompletionMessageToolCallParam(
|
|
||||||
type="function",
|
|
||||||
id=tool_call.id,
|
|
||||||
function=Function(
|
|
||||||
arguments=json.dumps(tool_call.tool_args),
|
|
||||||
name=tool_call.tool_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for tool_call in content.tool_calls
|
|
||||||
]
|
|
||||||
return param
|
|
||||||
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_tool_arguments(arguments: str) -> Any:
|
|
||||||
"""Decode tool call arguments."""
|
|
||||||
try:
|
|
||||||
return json.loads(arguments)
|
|
||||||
except json.JSONDecodeError as err:
|
|
||||||
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
|
||||||
|
|
||||||
|
|
||||||
async def _transform_response(
|
|
||||||
message: ChatCompletionMessage,
|
|
||||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
|
||||||
"""Transform the OpenRouter message to a ChatLog format."""
|
|
||||||
data: conversation.AssistantContentDeltaDict = {
|
|
||||||
"role": message.role,
|
|
||||||
"content": message.content,
|
|
||||||
}
|
|
||||||
if message.tool_calls:
|
|
||||||
data["tool_calls"] = [
|
|
||||||
llm.ToolInput(
|
|
||||||
id=tool_call.id,
|
|
||||||
tool_name=tool_call.function.name,
|
|
||||||
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
|
||||||
)
|
|
||||||
for tool_call in message.tool_calls
|
|
||||||
]
|
|
||||||
yield data
|
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|
||||||
"""OpenRouter conversation agent."""
|
"""OpenRouter conversation agent."""
|
||||||
|
|
||||||
_attr_has_entity_name = True
|
|
||||||
_attr_name = None
|
_attr_name = None
|
||||||
|
|
||||||
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
|
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
super().__init__(entry, subentry)
|
||||||
self.subentry = subentry
|
|
||||||
self.model = subentry.data[CONF_MODEL]
|
|
||||||
self._attr_unique_id = subentry.subentry_id
|
|
||||||
self._attr_device_info = DeviceInfo(
|
|
||||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
|
||||||
name=subentry.title,
|
|
||||||
entry_type=DeviceEntryType.SERVICE,
|
|
||||||
)
|
|
||||||
if self.subentry.data.get(CONF_LLM_HASS_API):
|
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||||
self._attr_supported_features = (
|
self._attr_supported_features = (
|
||||||
conversation.ConversationEntityFeature.CONTROL
|
conversation.ConversationEntityFeature.CONTROL
|
||||||
@ -164,7 +49,7 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
user_input: conversation.ConversationInput,
|
user_input: conversation.ConversationInput,
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process the user input and call the API."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -177,49 +62,6 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
|
||||||
tools: list[ChatCompletionToolParam] | None = None
|
await self._async_handle_chat_log(chat_log)
|
||||||
if chat_log.llm_api:
|
|
||||||
tools = [
|
|
||||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
|
||||||
for tool in chat_log.llm_api.tools
|
|
||||||
]
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
m
|
|
||||||
for content in chat_log.content
|
|
||||||
if (m := _convert_content_to_chat_message(content))
|
|
||||||
]
|
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
|
||||||
|
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
|
||||||
try:
|
|
||||||
result = await client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=messages,
|
|
||||||
tools=tools or NOT_GIVEN,
|
|
||||||
user=chat_log.conversation_id,
|
|
||||||
extra_headers={
|
|
||||||
"X-Title": "Home Assistant",
|
|
||||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except openai.OpenAIError as err:
|
|
||||||
LOGGER.error("Error talking to API: %s", err)
|
|
||||||
raise HomeAssistantError("Error talking to API") from err
|
|
||||||
|
|
||||||
result_message = result.choices[0].message
|
|
||||||
|
|
||||||
messages.extend(
|
|
||||||
[
|
|
||||||
msg
|
|
||||||
async for content in chat_log.async_add_delta_content_stream(
|
|
||||||
user_input.agent_id, _transform_response(result_message)
|
|
||||||
)
|
|
||||||
if (msg := _convert_content_to_chat_message(content))
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if not chat_log.unresponded_tool_results:
|
|
||||||
break
|
|
||||||
|
|
||||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
|
185
homeassistant/components/open_router/entity.py
Normal file
185
homeassistant/components/open_router/entity.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
"""Base entity for Open Router."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai import NOT_GIVEN
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionMessage,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionToolParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
|
from openai.types.shared_params import FunctionDefinition
|
||||||
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
|
from homeassistant.const import CONF_MODEL
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import device_registry as dr, llm
|
||||||
|
from homeassistant.helpers.entity import Entity
|
||||||
|
|
||||||
|
from . import OpenRouterConfigEntry
|
||||||
|
from .const import DOMAIN, LOGGER
|
||||||
|
|
||||||
|
# Max number of back and forth with the LLM to generate a response
|
||||||
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool(
|
||||||
|
tool: llm.Tool,
|
||||||
|
custom_serializer: Callable[[Any], Any] | None,
|
||||||
|
) -> ChatCompletionToolParam:
|
||||||
|
"""Format tool specification."""
|
||||||
|
tool_spec = FunctionDefinition(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||||
|
)
|
||||||
|
if tool.description:
|
||||||
|
tool_spec["description"] = tool.description
|
||||||
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_content_to_chat_message(
|
||||||
|
content: conversation.Content,
|
||||||
|
) -> ChatCompletionMessageParam | None:
|
||||||
|
"""Convert any native chat message for this agent to the native format."""
|
||||||
|
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
||||||
|
if isinstance(content, conversation.ToolResultContent):
|
||||||
|
return ChatCompletionToolMessageParam(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=content.tool_call_id,
|
||||||
|
content=json.dumps(content.tool_result),
|
||||||
|
)
|
||||||
|
|
||||||
|
role: Literal["user", "assistant", "system"] = content.role
|
||||||
|
if role == "system" and content.content:
|
||||||
|
return ChatCompletionSystemMessageParam(role="system", content=content.content)
|
||||||
|
|
||||||
|
if role == "user" and content.content:
|
||||||
|
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
||||||
|
|
||||||
|
if role == "assistant":
|
||||||
|
param = ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant",
|
||||||
|
content=content.content,
|
||||||
|
)
|
||||||
|
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||||
|
param["tool_calls"] = [
|
||||||
|
ChatCompletionMessageToolCallParam(
|
||||||
|
type="function",
|
||||||
|
id=tool_call.id,
|
||||||
|
function=Function(
|
||||||
|
arguments=json.dumps(tool_call.tool_args),
|
||||||
|
name=tool_call.tool_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tool_call in content.tool_calls
|
||||||
|
]
|
||||||
|
return param
|
||||||
|
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_tool_arguments(arguments: str) -> Any:
|
||||||
|
"""Decode tool call arguments."""
|
||||||
|
try:
|
||||||
|
return json.loads(arguments)
|
||||||
|
except json.JSONDecodeError as err:
|
||||||
|
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
||||||
|
|
||||||
|
|
||||||
|
async def _transform_response(
|
||||||
|
message: ChatCompletionMessage,
|
||||||
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
|
"""Transform the OpenRouter message to a ChatLog format."""
|
||||||
|
data: conversation.AssistantContentDeltaDict = {
|
||||||
|
"role": message.role,
|
||||||
|
"content": message.content,
|
||||||
|
}
|
||||||
|
if message.tool_calls:
|
||||||
|
data["tool_calls"] = [
|
||||||
|
llm.ToolInput(
|
||||||
|
id=tool_call.id,
|
||||||
|
tool_name=tool_call.function.name,
|
||||||
|
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
||||||
|
)
|
||||||
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
yield data
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterEntity(Entity):
|
||||||
|
"""Base entity for Open Router."""
|
||||||
|
|
||||||
|
_attr_has_entity_name = True
|
||||||
|
|
||||||
|
def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
|
"""Initialize the entity."""
|
||||||
|
self.entry = entry
|
||||||
|
self.subentry = subentry
|
||||||
|
self.model = subentry.data[CONF_MODEL]
|
||||||
|
self._attr_unique_id = subentry.subentry_id
|
||||||
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
|
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||||
|
name=subentry.title,
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
|
||||||
|
"""Generate an answer for the chat log."""
|
||||||
|
|
||||||
|
tools: list[ChatCompletionToolParam] | None = None
|
||||||
|
if chat_log.llm_api:
|
||||||
|
tools = [
|
||||||
|
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||||
|
for tool in chat_log.llm_api.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
m
|
||||||
|
for content in chat_log.content
|
||||||
|
if (m := _convert_content_to_chat_message(content))
|
||||||
|
]
|
||||||
|
|
||||||
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
|
try:
|
||||||
|
result = await client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools or NOT_GIVEN,
|
||||||
|
user=chat_log.conversation_id,
|
||||||
|
extra_headers={
|
||||||
|
"X-Title": "Home Assistant",
|
||||||
|
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except openai.OpenAIError as err:
|
||||||
|
LOGGER.error("Error talking to API: %s", err)
|
||||||
|
raise HomeAssistantError("Error talking to API") from err
|
||||||
|
|
||||||
|
result_message = result.choices[0].message
|
||||||
|
|
||||||
|
messages.extend(
|
||||||
|
[
|
||||||
|
msg
|
||||||
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
|
self.entity_id, _transform_response(result_message)
|
||||||
|
)
|
||||||
|
if (msg := _convert_content_to_chat_message(content))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if not chat_log.unresponded_tool_results:
|
||||||
|
break
|
@ -25,7 +25,7 @@
|
|||||||
"description": "Configure the new conversation agent",
|
"description": "Configure the new conversation agent",
|
||||||
"data": {
|
"data": {
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
"prompt": "Instructions",
|
"prompt": "[%key:common::config_flow::data::prompt%]",
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||||
},
|
},
|
||||||
"data_description": {
|
"data_description": {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user