Migrate Google Gen AI to ChatSession (#136779)

* Migrate Google Gen AI to ChatSession

* Remove unused method
This commit is contained in:
Paulus Schoutsen 2025-01-29 10:42:39 -05:00 committed by GitHub
parent 83b34c6faf
commit 8ab6bec746
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 128 additions and 124 deletions

View File

@ -11,18 +11,15 @@ import google.generativeai as genai
from google.generativeai import protos
import google.generativeai.types as genai_types
from google.protobuf.json_format import MessageToDict
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import device_registry as dr, intent, llm, template
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid as ulid_util
from .const import (
CONF_CHAT_MODEL,
@ -152,6 +149,17 @@ def _escape_decode(value: Any) -> Any:
return value
def _chat_message_convert(
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
) -> genai_types.ContentDict:
"""Convert any native chat message for this agent to the native format."""
if message.role == "native":
return message.content
role = "model" if message.role == "assistant" else message.role
return {"role": role, "parts": message.content}
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -163,7 +171,6 @@ class GoogleGenerativeAIConversationEntity(
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self.history: dict[str, list[genai_types.ContentType]] = {}
self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
@ -202,49 +209,37 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
result = conversation.ConversationResult(
response=intent.IntentResponse(language=user_input.language),
conversation_id=user_input.conversation_id or ulid_util.ulid_now(),
)
assert result.conversation_id
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_handle_message(user_input, session)
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
self.entry.options[CONF_LLM_HASS_API],
llm_context,
)
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return result
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[genai_types.ContentDict],
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options
try:
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
await session.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT),
)
return result
except conversation.ConverseError as err:
return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None
if session.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
]
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
@ -254,6 +249,9 @@ class GoogleGenerativeAIConversationEntity(
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
)
prompt, *messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
model = genai.GenerativeModel(
model_name=model_name,
generation_config={
@ -281,27 +279,15 @@ class GoogleGenerativeAIConversationEntity(
),
},
tools=tools or None,
system_instruction=prompt if supports_system_instruction else None,
system_instruction=prompt["parts"] if supports_system_instruction else None,
)
messages = self.history.get(result.conversation_id, [])
if not supports_system_instruction:
if not messages:
messages = [{}, {"role": "model", "parts": "Ok"}]
messages[0] = {"role": "user", "parts": prompt}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
# Make a copy to attach it to the trace event.
"messages": messages[:]
if supports_system_instruction
else messages[2:],
"prompt": prompt,
"tools": [*llm_api.tools] if llm_api else None,
},
)
messages = [
{"role": "user", "parts": prompt["parts"]},
{"role": "model", "parts": "Ok"},
*messages,
]
chat = model.start_chat(history=messages)
chat_request = user_input.text
@ -326,24 +312,30 @@ class GoogleGenerativeAIConversationEntity(
f"Sorry, I had a problem talking to Google Generative AI: {err}"
)
result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
error,
)
return result
raise HomeAssistantError(error) from err
LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem getting a response from Google Generative AI.",
raise HomeAssistantError(
"Sorry, I had a problem getting a response from Google Generative AI."
)
return result
self.history[result.conversation_id] = chat.history
content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text]
)
if content:
session.async_add_message(
conversation.Content(
role="assistant",
agent_id=user_input.agent_id,
content=content,
)
)
function_calls = [
part.function_call for part in chat_response.parts if part.function_call
]
if not function_calls or not llm_api:
if not function_calls or not session.llm_api:
break
tool_responses = []
@ -351,16 +343,8 @@ class GoogleGenerativeAIConversationEntity(
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = _escape_decode(tool_call["args"])
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
try:
function_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
function_response = {"error": type(e).__name__}
if str(e):
function_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", function_response)
function_response = await session.async_call_tool(tool_input)
tool_responses.append(
protos.Part(
function_response=protos.FunctionResponse(
@ -369,47 +353,20 @@ class GoogleGenerativeAIConversationEntity(
)
)
chat_request = protos.Content(parts=tool_responses)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=chat_request,
)
)
result.response.async_set_speech(
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text])
)
return result
async def _async_render_prompt(
self,
user_input: conversation.ConversationInput,
llm_api: llm.APIInstance | None,
llm_context: llm.LLMContext,
) -> str:
user_name: str | None = None
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
parts = [
template.Template(
llm.BASE_PROMPT
+ self.entry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
)
]
if llm_api:
parts.append(llm_api.api_prompt)
return "\n".join(parts)
return conversation.ConversationResult(
response=response, conversation_id=session.conversation_id
)
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry

View File

@ -42,6 +42,10 @@
'parts': 'Ok',
'role': 'model',
}),
dict({
'parts': '1st user request',
'role': 'user',
}),
]),
}),
),
@ -102,6 +106,10 @@
'parts': '1st model response',
'role': 'model',
}),
dict({
'parts': '2nd user request',
'role': 'user',
}),
]),
}),
),
@ -150,6 +158,10 @@
),
dict({
'history': list([
dict({
'parts': '1st user request',
'role': 'user',
}),
]),
}),
),
@ -202,6 +214,10 @@
'parts': '1st model response',
'role': 'model',
}),
dict({
'parts': '2nd user request',
'role': 'user',
}),
]),
}),
),
@ -250,6 +266,10 @@
),
dict({
'history': list([
dict({
'parts': 'hello',
'role': 'user',
}),
]),
}),
),
@ -298,6 +318,10 @@
),
dict({
'history': list([
dict({
'parts': 'hello',
'role': 'user',
}),
]),
}),
),
@ -347,6 +371,10 @@
),
dict({
'history': list([
dict({
'parts': 'hello',
'role': 'user',
}),
]),
}),
),
@ -396,6 +424,10 @@
),
dict({
'history': list([
dict({
'parts': 'hello',
'role': 'user',
}),
]),
}),
),
@ -482,6 +514,10 @@
),
dict({
'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]),
}),
),
@ -558,6 +594,10 @@
),
dict({
'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]),
}),
),

View File

@ -208,6 +208,7 @@ async def test_function_call(
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = ""
mock_part.function_call = FunctionCall(
name="test_tool",
args={
@ -284,8 +285,12 @@ async def test_function_call(
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["prompt"]
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [
p.function_response.name
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
] == ["test_tool"]
@patch(
@ -315,6 +320,7 @@ async def test_function_call_without_parameters(
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={})
def tool_call(
@ -403,6 +409,7 @@ async def test_function_exception(
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
def tool_call(
@ -543,7 +550,7 @@ async def test_invalid_llm_api(
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Error preparing LLM API: API invalid_llm_api not found"
"Error preparing LLM API"
)