Mirror of https://github.com/home-assistant/core.git (synced 2025-07-08 13:57:10 +00:00)
Migrate Google Gen AI to ChatSession (#136779)

* Migrate Google Gen AI to ChatSession
* Remove unused method

parent 83b34c6faf
commit 8ab6bec746
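
At a glance, the commit replaces the agent's own history bookkeeping with the shared chat session helper. A condensed before/after sketch of the pattern, simplified from the diff below (not runnable on its own; all names are as they appear in the diff):

# Before: the entity minted its own conversation id, fetched the LLM API and
# rendered its prompt template by hand, and kept history on the entity.
result = conversation.ConversationResult(
    response=intent.IntentResponse(language=user_input.language),
    conversation_id=user_input.conversation_id or ulid_util.ulid_now(),
)
llm_api = await llm.async_get_api(
    self.hass, self.entry.options[CONF_LLM_HASS_API], llm_context
)
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
messages = self.history.get(result.conversation_id, [])

# After: the chat session owns the conversation id, message history, LLM API
# and prompt; the entity only converts session messages to Gemini's format.
async with conversation.async_get_chat_session(self.hass, user_input) as session:
    await session.async_update_llm_data(
        DOMAIN, user_input, options.get(CONF_LLM_HASS_API), options.get(CONF_PROMPT)
    )
    prompt, *messages = [
        _chat_message_convert(m) for m in session.async_get_messages()
    ]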
homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -11,18 +11,15 @@ import google.generativeai as genai
 from google.generativeai import protos
 import google.generativeai.types as genai_types
 from google.protobuf.json_format import MessageToDict
-import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import assist_pipeline, conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import device_registry as dr, intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util
 
 from .const import (
     CONF_CHAT_MODEL,
@@ -152,6 +149,17 @@ def _escape_decode(value: Any) -> Any:
     return value
 
 
+def _chat_message_convert(
+    message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
+) -> genai_types.ContentDict:
+    """Convert any native chat message for this agent to the native format."""
+    if message.role == "native":
+        return message.content
+
+    role = "model" if message.role == "assistant" else message.role
+    return {"role": role, "parts": message.content}
+
+
 class GoogleGenerativeAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
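
A standalone sketch of the role mapping the new _chat_message_convert helper performs; the Content dataclass below is a stand-in for conversation.Content, introduced here only to make the example self-contained and runnable:

from dataclasses import dataclass
from typing import Any


@dataclass
class Content:
    """Stand-in for conversation.Content / conversation.NativeContent."""

    role: str
    agent_id: str
    content: Any


def chat_message_convert(message: Content) -> Any:
    # "native" messages already hold a Gemini ContentDict (or a protos.Content
    # for tool responses) and pass through unchanged.
    if message.role == "native":
        return message.content
    # Gemini uses "model" where Home Assistant uses "assistant".
    role = "model" if message.role == "assistant" else message.role
    return {"role": role, "parts": message.content}


assert chat_message_convert(Content("assistant", "agent", "Hi")) == {
    "role": "model",
    "parts": "Hi",
}
assert chat_message_convert(Content("user", "agent", "hello")) == {
    "role": "user",
    "parts": "hello",
}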
@@ -163,7 +171,6 @@ class GoogleGenerativeAIConversationEntity(
     def __init__(self, entry: ConfigEntry) -> None:
         """Initialize the agent."""
         self.entry = entry
-        self.history: dict[str, list[genai_types.ContentType]] = {}
         self._attr_unique_id = entry.entry_id
         self._attr_device_info = dr.DeviceInfo(
             identifiers={(DOMAIN, entry.entry_id)},
@@ -202,50 +209,38 @@ class GoogleGenerativeAIConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
-        result = conversation.ConversationResult(
-            response=intent.IntentResponse(language=user_input.language),
-            conversation_id=user_input.conversation_id or ulid_util.ulid_now(),
-        )
-        assert result.conversation_id
+        async with conversation.async_get_chat_session(
+            self.hass, user_input
+        ) as session:
+            return await self._async_handle_message(user_input, session)
 
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )
-        llm_api: llm.APIInstance | None = None
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        session: conversation.ChatSession[genai_types.ContentDict],
+    ) -> conversation.ConversationResult:
+        """Call the API."""
+        assert user_input.agent_id
+        options = self.entry.options
+
+        try:
+            await session.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                options.get(CONF_LLM_HASS_API),
+                options.get(CONF_PROMPT),
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
         tools: list[dict[str, Any]] | None = None
-        if self.entry.options.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    self.entry.options[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                LOGGER.error("Error getting LLM API: %s", err)
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return result
+        if session.llm_api:
             tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
+                _format_tool(tool, session.llm_api.custom_serializer)
+                for tool in session.llm_api.tools
             ]
 
-        try:
-            prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
-        except TemplateError as err:
-            LOGGER.error("Error rendering prompt: %s", err)
-            result.response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem with my template: {err}",
-            )
-            return result
-
         model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
         # Gemini 1.0 doesn't support system_instruction while 1.5 does.
         # Assume future versions will support it (if not, the request fails with a
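
The inline error handling this hunk deletes (fetching the LLM API, rendering the prompt template, converting failures into error responses) is now owned by the session. A condensed sketch of the new control flow, using only names visible in this diff:

try:
    # Resolves the configured LLM API and renders the prompt; replaces the
    # removed llm.async_get_api(...) call and _async_render_prompt(...) method.
    await session.async_update_llm_data(
        DOMAIN, user_input, options.get(CONF_LLM_HASS_API), options.get(CONF_PROMPT)
    )
except conversation.ConverseError as err:
    # ConverseError already carries an error IntentResponse, so the separate
    # HomeAssistantError and TemplateError branches collapse into one line.
    return err.as_conversation_result()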
@@ -254,6 +249,9 @@ class GoogleGenerativeAIConversationEntity(
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )
 
+        prompt, *messages = [
+            _chat_message_convert(message) for message in session.async_get_messages()
+        ]
         model = genai.GenerativeModel(
             model_name=model_name,
             generation_config={
@@ -281,27 +279,15 @@ class GoogleGenerativeAIConversationEntity(
                 ),
             },
             tools=tools or None,
-            system_instruction=prompt if supports_system_instruction else None,
+            system_instruction=prompt["parts"] if supports_system_instruction else None,
         )
 
-        messages = self.history.get(result.conversation_id, [])
         if not supports_system_instruction:
-            if not messages:
-                messages = [{}, {"role": "model", "parts": "Ok"}]
-            messages[0] = {"role": "user", "parts": prompt}
-
-        LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {
-                # Make a copy to attach it to the trace event.
-                "messages": messages[:]
-                if supports_system_instruction
-                else messages[2:],
-                "prompt": prompt,
-                "tools": [*llm_api.tools] if llm_api else None,
-            },
-        )
-
+            messages = [
+                {"role": "user", "parts": prompt["parts"]},
+                {"role": "model", "parts": "Ok"},
+                *messages,
+            ]
         chat = model.start_chat(history=messages)
         chat_request = user_input.text
@@ -326,24 +312,30 @@ class GoogleGenerativeAIConversationEntity(
                     f"Sorry, I had a problem talking to Google Generative AI: {err}"
                 )
 
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    error,
-                )
-                return result
+                raise HomeAssistantError(error) from err
 
             LOGGER.debug("Response: %s", chat_response.parts)
             if not chat_response.parts:
-                result.response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    "Sorry, I had a problem getting a response from Google Generative AI.",
+                raise HomeAssistantError(
+                    "Sorry, I had a problem getting a response from Google Generative AI."
                 )
-                return result
-            self.history[result.conversation_id] = chat.history
+            content = " ".join(
+                [part.text.strip() for part in chat_response.parts if part.text]
+            )
+            if content:
+                session.async_add_message(
+                    conversation.Content(
+                        role="assistant",
+                        agent_id=user_input.agent_id,
+                        content=content,
+                    )
+                )
+
             function_calls = [
                 part.function_call for part in chat_response.parts if part.function_call
             ]
-            if not function_calls or not llm_api:
+            if not function_calls or not session.llm_api:
                 break
 
             tool_responses = []
@@ -351,16 +343,8 @@ class GoogleGenerativeAIConversationEntity(
                 tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
                 tool_name = tool_call["name"]
                 tool_args = _escape_decode(tool_call["args"])
-                LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
                 tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                try:
-                    function_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    function_response = {"error": type(e).__name__}
-                    if str(e):
-                        function_response["error_text"] = str(e)
-
-                LOGGER.debug("Tool response: %s", function_response)
+                function_response = await session.async_call_tool(tool_input)
                 tool_responses.append(
                     protos.Part(
                         function_response=protos.FunctionResponse(
@@ -369,47 +353,20 @@ class GoogleGenerativeAIConversationEntity(
                         )
                     )
             chat_request = protos.Content(parts=tool_responses)
+            session.async_add_message(
+                conversation.NativeContent(
+                    agent_id=user_input.agent_id,
+                    content=chat_request,
+                )
+            )
 
-        result.response.async_set_speech(
+        response = intent.IntentResponse(language=user_input.language)
+        response.async_set_speech(
             " ".join([part.text.strip() for part in chat_response.parts if part.text])
         )
-        return result
-
-    async def _async_render_prompt(
-        self,
-        user_input: conversation.ConversationInput,
-        llm_api: llm.APIInstance | None,
-        llm_context: llm.LLMContext,
-    ) -> str:
-        user_name: str | None = None
-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        parts = [
-            template.Template(
-                llm.BASE_PROMPT
-                + self.entry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                self.hass,
-            ).async_render(
-                {
-                    "ha_name": self.hass.config.location_name,
-                    "user_name": user_name,
-                    "llm_context": llm_context,
-                },
-                parse_result=False,
-            )
-        ]
-
-        if llm_api:
-            parts.append(llm_api.api_prompt)
-
-        return "\n".join(parts)
+        return conversation.ConversationResult(
+            response=response, conversation_id=session.conversation_id
+        )
 
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
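
Putting the tool-handling hunks together, the loop body now reads roughly as follows. This is a sketch assembled from the changes above; the protos.FunctionResponse(name=..., response=...) keyword arguments sit in unchanged lines elided from the hunks:

for function_call in function_calls:
    tool_call = MessageToDict(function_call._pb)  # protobuf -> plain dict
    tool_input = llm.ToolInput(
        tool_name=tool_call["name"],
        tool_args=_escape_decode(tool_call["args"]),
    )
    # session.async_call_tool replaces the removed try/except around
    # llm_api.async_call_tool plus the manual debug logging.
    function_response = await session.async_call_tool(tool_input)
    tool_responses.append(
        protos.Part(
            function_response=protos.FunctionResponse(
                name=tool_call["name"], response=function_response
            )
        )
    )
chat_request = protos.Content(parts=tool_responses)
# The native tool-response content is persisted in the session history:
session.async_add_message(
    conversation.NativeContent(agent_id=user_input.agent_id, content=chat_request)
)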
tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
@@ -42,6 +42,10 @@
           'parts': 'Ok',
           'role': 'model',
         }),
+        dict({
+          'parts': '1st user request',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -102,6 +106,10 @@
           'parts': '1st model response',
           'role': 'model',
         }),
+        dict({
+          'parts': '2nd user request',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -150,6 +158,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': '1st user request',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -202,6 +214,10 @@
           'parts': '1st model response',
           'role': 'model',
         }),
+        dict({
+          'parts': '2nd user request',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -250,6 +266,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -298,6 +318,10 @@
     ),
     dict({
      'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -347,6 +371,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -396,6 +424,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'hello',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -482,6 +514,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'Please call the test function',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
@@ -558,6 +594,10 @@
     ),
     dict({
       'history': list([
+        dict({
+          'parts': 'Please call the test function',
+          'role': 'user',
+        }),
       ]),
     }),
   ),
tests/components/google_generative_ai_conversation/test_conversation.py
@@ -208,6 +208,7 @@ async def test_function_call(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(
         name="test_tool",
         args={
@@ -284,8 +285,12 @@ async def test_function_call(
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
-    assert "Answer in plain text" in detail_event["data"]["prompt"]
-    assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
+    assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
+    assert [
+        p.function_response.name
+        for p in detail_event["data"]["messages"][2]["content"].parts
+        if p.function_response
+    ] == ["test_tool"]
 
 
 @patch(
@@ -315,6 +320,7 @@ async def test_function_call_without_parameters(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(name="test_tool", args={})
 
     def tool_call(
@@ -403,6 +409,7 @@ async def test_function_exception(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
+    mock_part.text = ""
     mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
 
     def tool_call(
@@ -543,7 +550,7 @@ async def test_invalid_llm_api(
     assert result.response.response_type == intent.IntentResponseType.ERROR, result
     assert result.response.error_code == "unknown", result
     assert result.response.as_dict()["speech"]["plain"]["speech"] == (
-        "Error preparing LLM API: API invalid_llm_api not found"
+        "Error preparing LLM API"
     )
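
A side note on the repeated mock_part.text = "" additions: an unset attribute on a MagicMock is itself a truthy mock, so the new " ".join(part.text.strip() ...) path would choke on it. A minimal illustration:

from unittest.mock import MagicMock

part = MagicMock()
assert part.text  # an unset mock attribute is truthy, and .strip() returns a mock
part.text = ""    # so the tests blank it out, keeping function-call-only parts
                  # out of " ".join(p.text.strip() for p in parts if p.text)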