Migrate Google Gen AI to ChatSession (#136779)

* Migrate Google Gen AI to ChatSession

* Remove unused method
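This replaces the per-entity history dict with the conversation.ChatSession
helper, which now owns prompt rendering, LLM API setup, message history, tool
calls, and tracing. A minimal sketch of the new entry point, condensed from
this diff (error handling omitted):

    async def async_process(self, user_input):
        """Process a sentence."""
        async with conversation.async_get_chat_session(
            self.hass, user_input
        ) as session:
            return await self._async_handle_message(user_input, session)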
Paulus Schoutsen 2025-01-29 10:42:39 -05:00 committed by GitHub
parent 83b34c6faf
commit 8ab6bec746
3 changed files with 128 additions and 124 deletions

View File

@@ -11,18 +11,15 @@ import google.generativeai as genai
 from google.generativeai import protos
 import google.generativeai.types as genai_types
 from google.protobuf.json_format import MessageToDict
-import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import assist_pipeline, conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import device_registry as dr, intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util
 
 from .const import (
     CONF_CHAT_MODEL,
@@ -152,6 +149,17 @@ def _escape_decode(value: Any) -> Any:
     return value
 
 
+def _chat_message_convert(
+    message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
+) -> genai_types.ContentDict:
+    """Convert any native chat message for this agent to the native format."""
+    if message.role == "native":
+        return message.content
+
+    role = "model" if message.role == "assistant" else message.role
+    return {"role": role, "parts": message.content}
+
+
 class GoogleGenerativeAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -163,7 +171,6 @@ class GoogleGenerativeAIConversationEntity(
     def __init__(self, entry: ConfigEntry) -> None:
         """Initialize the agent."""
         self.entry = entry
-        self.history: dict[str, list[genai_types.ContentType]] = {}
         self._attr_unique_id = entry.entry_id
         self._attr_device_info = dr.DeviceInfo(
             identifiers={(DOMAIN, entry.entry_id)},
@@ -202,50 +209,38 @@ class GoogleGenerativeAIConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
-        result = conversation.ConversationResult(
-            response=intent.IntentResponse(language=user_input.language),
-            conversation_id=user_input.conversation_id or ulid_util.ulid_now(),
-        )
-        assert result.conversation_id
+        async with conversation.async_get_chat_session(
+            self.hass, user_input
+        ) as session:
+            return await self._async_handle_message(user_input, session)
 
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )
-        llm_api: llm.APIInstance | None = None
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        session: conversation.ChatSession[genai_types.ContentDict],
+    ) -> conversation.ConversationResult:
+        """Call the API."""
+        assert user_input.agent_id
+        options = self.entry.options
+
+        try:
+            await session.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                options.get(CONF_LLM_HASS_API),
+                options.get(CONF_PROMPT),
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
         tools: list[dict[str, Any]] | None = None
-        if self.entry.options.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    self.entry.options[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                LOGGER.error("Error getting LLM API: %s", err)
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return result
+        if session.llm_api:
             tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
+                _format_tool(tool, session.llm_api.custom_serializer)
+                for tool in session.llm_api.tools
             ]
 
-        try:
-            prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
-        except TemplateError as err:
-            LOGGER.error("Error rendering prompt: %s", err)
-            result.response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem with my template: {err}",
-            )
-            return result
-
         model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
         # Gemini 1.0 doesn't support system_instruction while 1.5 does.
         # Assume future versions will support it (if not, the request fails with a
@@ -254,6 +249,9 @@ class GoogleGenerativeAIConversationEntity(
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )
 
+        prompt, *messages = [
+            _chat_message_convert(message) for message in session.async_get_messages()
+        ]
         model = genai.GenerativeModel(
             model_name=model_name,
             generation_config={
@@ -281,27 +279,15 @@ class GoogleGenerativeAIConversationEntity(
                 ),
             },
             tools=tools or None,
-            system_instruction=prompt if supports_system_instruction else None,
+            system_instruction=prompt["parts"] if supports_system_instruction else None,
         )
 
-        messages = self.history.get(result.conversation_id, [])
         if not supports_system_instruction:
-            if not messages:
-                messages = [{}, {"role": "model", "parts": "Ok"}]
-            messages[0] = {"role": "user", "parts": prompt}
-
-        LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {
-                # Make a copy to attach it to the trace event.
-                "messages": messages[:]
-                if supports_system_instruction
-                else messages[2:],
-                "prompt": prompt,
-                "tools": [*llm_api.tools] if llm_api else None,
-            },
-        )
+            messages = [
+                {"role": "user", "parts": prompt["parts"]},
+                {"role": "model", "parts": "Ok"},
+                *messages,
+            ]
 
         chat = model.start_chat(history=messages)
         chat_request = user_input.text
@@ -326,24 +312,30 @@ class GoogleGenerativeAIConversationEntity(
                     f"Sorry, I had a problem talking to Google Generative AI: {err}"
                 )
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    error,
-                )
-                return result
+                raise HomeAssistantError(error) from err
 
             LOGGER.debug("Response: %s", chat_response.parts)
             if not chat_response.parts:
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    "Sorry, I had a problem getting a response from Google Generative AI.",
+                raise HomeAssistantError(
+                    "Sorry, I had a problem getting a response from Google Generative AI."
                 )
-                return result
-            self.history[result.conversation_id] = chat.history
+
+            content = " ".join(
+                [part.text.strip() for part in chat_response.parts if part.text]
+            )
+            if content:
+                session.async_add_message(
+                    conversation.Content(
+                        role="assistant",
+                        agent_id=user_input.agent_id,
+                        content=content,
+                    )
+                )
+
             function_calls = [
                 part.function_call for part in chat_response.parts if part.function_call
             ]
-            if not function_calls or not llm_api:
+            if not function_calls or not session.llm_api:
                 break
 
             tool_responses = []
@@ -351,16 +343,8 @@ class GoogleGenerativeAIConversationEntity(
                 tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
                 tool_name = tool_call["name"]
                 tool_args = _escape_decode(tool_call["args"])
-                LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
                 tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                try:
-                    function_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    function_response = {"error": type(e).__name__}
-                    if str(e):
-                        function_response["error_text"] = str(e)
-                LOGGER.debug("Tool response: %s", function_response)
-
+                function_response = await session.async_call_tool(tool_input)
                 tool_responses.append(
                     protos.Part(
                         function_response=protos.FunctionResponse(
@@ -369,47 +353,20 @@ class GoogleGenerativeAIConversationEntity(
                         )
                     )
 
             chat_request = protos.Content(parts=tool_responses)
+            session.async_add_message(
+                conversation.NativeContent(
+                    agent_id=user_input.agent_id,
+                    content=chat_request,
+                )
+            )
 
-        result.response.async_set_speech(
+        response = intent.IntentResponse(language=user_input.language)
+        response.async_set_speech(
             " ".join([part.text.strip() for part in chat_response.parts if part.text])
         )
-        return result
-
-    async def _async_render_prompt(
-        self,
-        user_input: conversation.ConversationInput,
-        llm_api: llm.APIInstance | None,
-        llm_context: llm.LLMContext,
-    ) -> str:
-        user_name: str | None = None
-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        parts = [
-            template.Template(
-                llm.BASE_PROMPT
-                + self.entry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                self.hass,
-            ).async_render(
-                {
-                    "ha_name": self.hass.config.location_name,
-                    "user_name": user_name,
-                    "llm_context": llm_context,
-                },
-                parse_result=False,
-            )
-        ]
-        if llm_api:
-            parts.append(llm_api.api_prompt)
-        return "\n".join(parts)
+        return conversation.ConversationResult(
+            response=response, conversation_id=session.conversation_id
+        )
 
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry

View File

@@ -42,6 +42,10 @@
             'parts': 'Ok',
             'role': 'model',
           }),
+          dict({
+            'parts': '1st user request',
+            'role': 'user',
+          }),
         ]),
       }),
     ),
@@ -102,6 +106,10 @@
             'parts': '1st model response',
             'role': 'model',
           }),
+          dict({
+            'parts': '2nd user request',
+            'role': 'user',
+          }),
         ]),
      }),
     ),
@@ -150,6 +158,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': '1st user request',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -202,6 +214,10 @@
             'parts': '1st model response',
             'role': 'model',
           }),
+          dict({
+            'parts': '2nd user request',
+            'role': 'user',
+          }),
         ]),
      }),
     ),
@@ -250,6 +266,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
      ]),
    }),
  ),
@@ -298,6 +318,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
      ]),
    }),
  ),
@@ -347,6 +371,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
      ]),
    }),
  ),
@@ -396,6 +424,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
      ]),
    }),
  ),
@@ -482,6 +514,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'Please call the test function',
+          'role': 'user',
+        }),
      ]),
    }),
  ),
@@ -558,6 +594,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'Please call the test function',
+          'role': 'user',
+        }),
      ]),
    }),
  ),

View File

@@ -208,6 +208,7 @@ async def test_function_call(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(
         name="test_tool",
         args={
@@ -284,8 +285,12 @@ async def test_function_call(
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
-    assert "Answer in plain text" in detail_event["data"]["prompt"]
-    assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
+    assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
+    assert [
+        p.function_response.name
+        for p in detail_event["data"]["messages"][2]["content"].parts
+        if p.function_response
+    ] == ["test_tool"]
 
 
 @patch(
@@ -315,6 +320,7 @@ async def test_function_call_without_parameters(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(name="test_tool", args={})
 
     def tool_call(
@@ -403,6 +409,7 @@ async def test_function_exception(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
 
     def tool_call(
@ -543,7 +550,7 @@ async def test_invalid_llm_api(
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == ( assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Error preparing LLM API: API invalid_llm_api not found" "Error preparing LLM API"
) )