mirror of
https://github.com/home-assistant/core.git
synced 2025-07-07 13:27:09 +00:00
Migrate Google Gen AI to ChatSession (#136779)
* Migrate Google Gen AI to ChatSession * Remove unused method
This commit is contained in:
parent
83b34c6faf
commit
8ab6bec746
@ -11,18 +11,15 @@ import google.generativeai as genai
|
||||
from google.generativeai import protos
|
||||
import google.generativeai.types as genai_types
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid as ulid_util
|
||||
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
@ -152,6 +149,17 @@ def _escape_decode(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _chat_message_convert(
|
||||
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
|
||||
) -> genai_types.ContentDict:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
if message.role == "native":
|
||||
return message.content
|
||||
|
||||
role = "model" if message.role == "assistant" else message.role
|
||||
return {"role": role, "parts": message.content}
|
||||
|
||||
|
||||
class GoogleGenerativeAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
@ -163,7 +171,6 @@ class GoogleGenerativeAIConversationEntity(
|
||||
def __init__(self, entry: ConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.entry = entry
|
||||
self.history: dict[str, list[genai_types.ContentType]] = {}
|
||||
self._attr_unique_id = entry.entry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, entry.entry_id)},
|
||||
@ -202,49 +209,37 @@ class GoogleGenerativeAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
result = conversation.ConversationResult(
|
||||
response=intent.IntentResponse(language=user_input.language),
|
||||
conversation_id=user_input.conversation_id or ulid_util.ulid_now(),
|
||||
)
|
||||
assert result.conversation_id
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as session:
|
||||
return await self._async_handle_message(user_input, session)
|
||||
|
||||
llm_context = llm.LLMContext(
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
llm_api: llm.APIInstance | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
self.entry.options[CONF_LLM_HASS_API],
|
||||
llm_context,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
LOGGER.error("Error getting LLM API: %s", err)
|
||||
result.response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return result
|
||||
tools = [
|
||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatSession[genai_types.ContentDict],
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
|
||||
assert user_input.agent_id
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
result.response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
await session.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
options.get(CONF_PROMPT),
|
||||
)
|
||||
return result
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if session.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, session.llm_api.custom_serializer)
|
||||
for tool in session.llm_api.tools
|
||||
]
|
||||
|
||||
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
|
||||
@ -254,6 +249,9 @@ class GoogleGenerativeAIConversationEntity(
|
||||
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
|
||||
)
|
||||
|
||||
prompt, *messages = [
|
||||
_chat_message_convert(message) for message in session.async_get_messages()
|
||||
]
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model_name,
|
||||
generation_config={
|
||||
@ -281,27 +279,15 @@ class GoogleGenerativeAIConversationEntity(
|
||||
),
|
||||
},
|
||||
tools=tools or None,
|
||||
system_instruction=prompt if supports_system_instruction else None,
|
||||
system_instruction=prompt["parts"] if supports_system_instruction else None,
|
||||
)
|
||||
|
||||
messages = self.history.get(result.conversation_id, [])
|
||||
if not supports_system_instruction:
|
||||
if not messages:
|
||||
messages = [{}, {"role": "model", "parts": "Ok"}]
|
||||
messages[0] = {"role": "user", "parts": prompt}
|
||||
|
||||
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{
|
||||
# Make a copy to attach it to the trace event.
|
||||
"messages": messages[:]
|
||||
if supports_system_instruction
|
||||
else messages[2:],
|
||||
"prompt": prompt,
|
||||
"tools": [*llm_api.tools] if llm_api else None,
|
||||
},
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "parts": prompt["parts"]},
|
||||
{"role": "model", "parts": "Ok"},
|
||||
*messages,
|
||||
]
|
||||
|
||||
chat = model.start_chat(history=messages)
|
||||
chat_request = user_input.text
|
||||
@ -326,24 +312,30 @@ class GoogleGenerativeAIConversationEntity(
|
||||
f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
||||
)
|
||||
|
||||
result.response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
error,
|
||||
)
|
||||
return result
|
||||
raise HomeAssistantError(error) from err
|
||||
|
||||
LOGGER.debug("Response: %s", chat_response.parts)
|
||||
if not chat_response.parts:
|
||||
result.response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem getting a response from Google Generative AI.",
|
||||
raise HomeAssistantError(
|
||||
"Sorry, I had a problem getting a response from Google Generative AI."
|
||||
)
|
||||
return result
|
||||
self.history[result.conversation_id] = chat.history
|
||||
content = " ".join(
|
||||
[part.text.strip() for part in chat_response.parts if part.text]
|
||||
)
|
||||
if content:
|
||||
session.async_add_message(
|
||||
conversation.Content(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
function_calls = [
|
||||
part.function_call for part in chat_response.parts if part.function_call
|
||||
]
|
||||
if not function_calls or not llm_api:
|
||||
|
||||
if not function_calls or not session.llm_api:
|
||||
break
|
||||
|
||||
tool_responses = []
|
||||
@ -351,16 +343,8 @@ class GoogleGenerativeAIConversationEntity(
|
||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = _escape_decode(tool_call["args"])
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
try:
|
||||
function_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
function_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
function_response["error_text"] = str(e)
|
||||
|
||||
LOGGER.debug("Tool response: %s", function_response)
|
||||
function_response = await session.async_call_tool(tool_input)
|
||||
tool_responses.append(
|
||||
protos.Part(
|
||||
function_response=protos.FunctionResponse(
|
||||
@ -369,47 +353,20 @@ class GoogleGenerativeAIConversationEntity(
|
||||
)
|
||||
)
|
||||
chat_request = protos.Content(parts=tool_responses)
|
||||
session.async_add_message(
|
||||
conversation.NativeContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=chat_request,
|
||||
)
|
||||
)
|
||||
|
||||
result.response.async_set_speech(
|
||||
response = intent.IntentResponse(language=user_input.language)
|
||||
response.async_set_speech(
|
||||
" ".join([part.text.strip() for part in chat_response.parts if part.text])
|
||||
)
|
||||
return result
|
||||
|
||||
async def _async_render_prompt(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
llm_api: llm.APIInstance | None,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> str:
|
||||
user_name: str | None = None
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
parts = [
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ self.entry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
]
|
||||
|
||||
if llm_api:
|
||||
parts.append(llm_api.api_prompt)
|
||||
|
||||
return "\n".join(parts)
|
||||
return conversation.ConversationResult(
|
||||
response=response, conversation_id=session.conversation_id
|
||||
)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry
|
||||
|
@ -42,6 +42,10 @@
|
||||
'parts': 'Ok',
|
||||
'role': 'model',
|
||||
}),
|
||||
dict({
|
||||
'parts': '1st user request',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -102,6 +106,10 @@
|
||||
'parts': '1st model response',
|
||||
'role': 'model',
|
||||
}),
|
||||
dict({
|
||||
'parts': '2nd user request',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -150,6 +158,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '1st user request',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -202,6 +214,10 @@
|
||||
'parts': '1st model response',
|
||||
'role': 'model',
|
||||
}),
|
||||
dict({
|
||||
'parts': '2nd user request',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -250,6 +266,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -298,6 +318,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -347,6 +371,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -396,6 +424,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -482,6 +514,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
@ -558,6 +594,10 @@
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
|
@ -208,6 +208,7 @@ async def test_function_call(
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(
|
||||
name="test_tool",
|
||||
args={
|
||||
@ -284,8 +285,12 @@ async def test_function_call(
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
||||
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||
assert [
|
||||
p.function_response.name
|
||||
for p in detail_event["data"]["messages"][2]["content"].parts
|
||||
if p.function_response
|
||||
] == ["test_tool"]
|
||||
|
||||
|
||||
@patch(
|
||||
@ -315,6 +320,7 @@ async def test_function_call_without_parameters(
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(name="test_tool", args={})
|
||||
|
||||
def tool_call(
|
||||
@ -403,6 +409,7 @@ async def test_function_exception(
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
|
||||
|
||||
def tool_call(
|
||||
@ -543,7 +550,7 @@ async def test_invalid_llm_api(
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Error preparing LLM API: API invalid_llm_api not found"
|
||||
"Error preparing LLM API"
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user